mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 19:32:21 -06:00
Support chaining for agent callbacks
(before/after) agent callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 756359693
This commit is contained in:
committed by
Copybara-Service
parent
a61d20e3df
commit
d45084f311
@@ -15,12 +15,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any, Awaitable, Union
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Awaitable
|
||||
from typing import Callable
|
||||
from typing import final
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
from opentelemetry import trace
|
||||
@@ -29,6 +31,7 @@ from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from typing_extensions import override
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from ..events.event import Event
|
||||
from .callback_context import CallbackContext
|
||||
@@ -38,14 +41,19 @@ if TYPE_CHECKING:
|
||||
|
||||
tracer = trace.get_tracer('gcp.vertex.agent')
|
||||
|
||||
BeforeAgentCallback = Callable[
|
||||
_SingleAgentCallback: TypeAlias = Callable[
|
||||
[CallbackContext],
|
||||
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
|
||||
]
|
||||
|
||||
AfterAgentCallback = Callable[
|
||||
[CallbackContext],
|
||||
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
|
||||
BeforeAgentCallback: TypeAlias = Union[
|
||||
_SingleAgentCallback,
|
||||
list[_SingleAgentCallback],
|
||||
]
|
||||
|
||||
AfterAgentCallback: TypeAlias = Union[
|
||||
_SingleAgentCallback,
|
||||
list[_SingleAgentCallback],
|
||||
]
|
||||
|
||||
|
||||
@@ -85,7 +93,10 @@ class BaseAgent(BaseModel):
|
||||
"""The sub-agents of this agent."""
|
||||
|
||||
before_agent_callback: Optional[BeforeAgentCallback] = None
|
||||
"""Callback signature that is invoked before the agent run.
|
||||
"""Callback or list of callbacks to be invoked before the agent run.
|
||||
|
||||
When a list of callbacks is provided, the callbacks will be called in the
|
||||
order they are listed until a callback does not return None.
|
||||
|
||||
Args:
|
||||
callback_context: MUST be named 'callback_context' (enforced).
|
||||
@@ -96,7 +107,10 @@ class BaseAgent(BaseModel):
|
||||
provided content will be returned to user.
|
||||
"""
|
||||
after_agent_callback: Optional[AfterAgentCallback] = None
|
||||
"""Callback signature that is invoked after the agent run.
|
||||
"""Callback or list of callbacks to be invoked after the agent run.
|
||||
|
||||
When a list of callbacks is provided, the callbacks will be called in the
|
||||
order they are listed until a callback does not return None.
|
||||
|
||||
Args:
|
||||
callback_context: MUST be named 'callback_context' (enforced).
|
||||
@@ -236,6 +250,30 @@ class BaseAgent(BaseModel):
|
||||
invocation_context.branch = f'{parent_context.branch}.{self.name}'
|
||||
return invocation_context
|
||||
|
||||
@property
|
||||
def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]:
|
||||
"""The resolved self.before_agent_callback field as a list of _SingleAgentCallback.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if not self.before_agent_callback:
|
||||
return []
|
||||
if isinstance(self.before_agent_callback, list):
|
||||
return self.before_agent_callback
|
||||
return [self.before_agent_callback]
|
||||
|
||||
@property
|
||||
def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]:
|
||||
"""The resolved self.after_agent_callback field as a list of _SingleAgentCallback.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if not self.after_agent_callback:
|
||||
return []
|
||||
if isinstance(self.after_agent_callback, list):
|
||||
return self.after_agent_callback
|
||||
return [self.after_agent_callback]
|
||||
|
||||
async def __handle_before_agent_callback(
|
||||
self, ctx: InvocationContext
|
||||
) -> Optional[Event]:
|
||||
@@ -246,27 +284,27 @@ class BaseAgent(BaseModel):
|
||||
"""
|
||||
ret_event = None
|
||||
|
||||
if not isinstance(self.before_agent_callback, Callable):
|
||||
if not self.canonical_before_agent_callbacks:
|
||||
return ret_event
|
||||
|
||||
callback_context = CallbackContext(ctx)
|
||||
before_agent_callback_content = self.before_agent_callback(
|
||||
callback_context=callback_context
|
||||
)
|
||||
|
||||
if inspect.isawaitable(before_agent_callback_content):
|
||||
before_agent_callback_content = await before_agent_callback_content
|
||||
|
||||
if before_agent_callback_content:
|
||||
ret_event = Event(
|
||||
invocation_id=ctx.invocation_id,
|
||||
author=self.name,
|
||||
branch=ctx.branch,
|
||||
content=before_agent_callback_content,
|
||||
actions=callback_context._event_actions,
|
||||
for callback in self.canonical_before_agent_callbacks:
|
||||
before_agent_callback_content = callback(
|
||||
callback_context=callback_context
|
||||
)
|
||||
ctx.end_invocation = True
|
||||
return ret_event
|
||||
if inspect.isawaitable(before_agent_callback_content):
|
||||
before_agent_callback_content = await before_agent_callback_content
|
||||
if before_agent_callback_content:
|
||||
ret_event = Event(
|
||||
invocation_id=ctx.invocation_id,
|
||||
author=self.name,
|
||||
branch=ctx.branch,
|
||||
content=before_agent_callback_content,
|
||||
actions=callback_context._event_actions,
|
||||
)
|
||||
ctx.end_invocation = True
|
||||
return ret_event
|
||||
|
||||
if callback_context.state.has_delta():
|
||||
ret_event = Event(
|
||||
@@ -288,18 +326,26 @@ class BaseAgent(BaseModel):
|
||||
"""
|
||||
ret_event = None
|
||||
|
||||
if not isinstance(self.after_agent_callback, Callable):
|
||||
if not self.canonical_after_agent_callbacks:
|
||||
return ret_event
|
||||
|
||||
callback_context = CallbackContext(invocation_context)
|
||||
after_agent_callback_content = self.after_agent_callback(
|
||||
callback_context=callback_context
|
||||
)
|
||||
|
||||
if inspect.isawaitable(after_agent_callback_content):
|
||||
after_agent_callback_content = await after_agent_callback_content
|
||||
for callback in self.canonical_after_agent_callbacks:
|
||||
after_agent_callback_content = callback(callback_context=callback_context)
|
||||
if inspect.isawaitable(after_agent_callback_content):
|
||||
after_agent_callback_content = await after_agent_callback_content
|
||||
if after_agent_callback_content:
|
||||
ret_event = Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=self.name,
|
||||
branch=invocation_context.branch,
|
||||
content=after_agent_callback_content,
|
||||
actions=callback_context._event_actions,
|
||||
)
|
||||
return ret_event
|
||||
|
||||
if after_agent_callback_content or callback_context.state.has_delta():
|
||||
if callback_context.state.has_delta():
|
||||
ret_event = Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=self.name,
|
||||
|
||||
@@ -191,14 +191,19 @@ class BaseLlmFlow(ABC):
|
||||
llm_request: LlmRequest,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Receive data from model and process events using BaseLlmConnection."""
|
||||
|
||||
def get_author(llm_response):
|
||||
"""Get the author of the event.
|
||||
|
||||
When the model returns transcription, the author is "user". Otherwise, the
|
||||
author is the agent.
|
||||
"""
|
||||
if llm_response and llm_response.content and llm_response.content.role == "user":
|
||||
return "user"
|
||||
if (
|
||||
llm_response
|
||||
and llm_response.content
|
||||
and llm_response.content.role == 'user'
|
||||
):
|
||||
return 'user'
|
||||
else:
|
||||
return invocation_context.agent.name
|
||||
|
||||
|
||||
Reference in New Issue
Block a user