Support chaining for agent callbacks

(before/after) agent callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync.

PiperOrigin-RevId: 756359693
This commit is contained in:
Selcuk Gun
2025-05-08 10:08:51 -07:00
committed by Copybara-Service
parent a61d20e3df
commit d45084f311
6 changed files with 339 additions and 454 deletions

View File

@@ -15,12 +15,14 @@
from __future__ import annotations
import inspect
from typing import Any, Awaitable, Union
from typing import Any
from typing import AsyncGenerator
from typing import Awaitable
from typing import Callable
from typing import final
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from google.genai import types
from opentelemetry import trace
@@ -29,6 +31,7 @@ from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
from typing_extensions import override
from typing_extensions import TypeAlias
from ..events.event import Event
from .callback_context import CallbackContext
@@ -38,14 +41,19 @@ if TYPE_CHECKING:
tracer = trace.get_tracer('gcp.vertex.agent')
BeforeAgentCallback = Callable[
_SingleAgentCallback: TypeAlias = Callable[
[CallbackContext],
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
]
AfterAgentCallback = Callable[
[CallbackContext],
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
BeforeAgentCallback: TypeAlias = Union[
_SingleAgentCallback,
list[_SingleAgentCallback],
]
AfterAgentCallback: TypeAlias = Union[
_SingleAgentCallback,
list[_SingleAgentCallback],
]
@@ -85,7 +93,10 @@ class BaseAgent(BaseModel):
"""The sub-agents of this agent."""
before_agent_callback: Optional[BeforeAgentCallback] = None
"""Callback signature that is invoked before the agent run.
"""Callback or list of callbacks to be invoked before the agent run.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args:
callback_context: MUST be named 'callback_context' (enforced).
@@ -96,7 +107,10 @@ class BaseAgent(BaseModel):
provided content will be returned to user.
"""
after_agent_callback: Optional[AfterAgentCallback] = None
"""Callback signature that is invoked after the agent run.
"""Callback or list of callbacks to be invoked after the agent run.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args:
callback_context: MUST be named 'callback_context' (enforced).
@@ -236,6 +250,30 @@ class BaseAgent(BaseModel):
invocation_context.branch = f'{parent_context.branch}.{self.name}'
return invocation_context
@property
def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]:
"""The resolved self.before_agent_callback field as a list of _SingleAgentCallback.
This method is only for use by Agent Development Kit.
"""
if not self.before_agent_callback:
return []
if isinstance(self.before_agent_callback, list):
return self.before_agent_callback
return [self.before_agent_callback]
@property
def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]:
"""The resolved self.after_agent_callback field as a list of _SingleAgentCallback.
This method is only for use by Agent Development Kit.
"""
if not self.after_agent_callback:
return []
if isinstance(self.after_agent_callback, list):
return self.after_agent_callback
return [self.after_agent_callback]
async def __handle_before_agent_callback(
self, ctx: InvocationContext
) -> Optional[Event]:
@@ -246,27 +284,27 @@ class BaseAgent(BaseModel):
"""
ret_event = None
if not isinstance(self.before_agent_callback, Callable):
if not self.canonical_before_agent_callbacks:
return ret_event
callback_context = CallbackContext(ctx)
before_agent_callback_content = self.before_agent_callback(
callback_context=callback_context
)
if inspect.isawaitable(before_agent_callback_content):
before_agent_callback_content = await before_agent_callback_content
if before_agent_callback_content:
ret_event = Event(
invocation_id=ctx.invocation_id,
author=self.name,
branch=ctx.branch,
content=before_agent_callback_content,
actions=callback_context._event_actions,
for callback in self.canonical_before_agent_callbacks:
before_agent_callback_content = callback(
callback_context=callback_context
)
ctx.end_invocation = True
return ret_event
if inspect.isawaitable(before_agent_callback_content):
before_agent_callback_content = await before_agent_callback_content
if before_agent_callback_content:
ret_event = Event(
invocation_id=ctx.invocation_id,
author=self.name,
branch=ctx.branch,
content=before_agent_callback_content,
actions=callback_context._event_actions,
)
ctx.end_invocation = True
return ret_event
if callback_context.state.has_delta():
ret_event = Event(
@@ -288,18 +326,26 @@ class BaseAgent(BaseModel):
"""
ret_event = None
if not isinstance(self.after_agent_callback, Callable):
if not self.canonical_after_agent_callbacks:
return ret_event
callback_context = CallbackContext(invocation_context)
after_agent_callback_content = self.after_agent_callback(
callback_context=callback_context
)
if inspect.isawaitable(after_agent_callback_content):
after_agent_callback_content = await after_agent_callback_content
for callback in self.canonical_after_agent_callbacks:
after_agent_callback_content = callback(callback_context=callback_context)
if inspect.isawaitable(after_agent_callback_content):
after_agent_callback_content = await after_agent_callback_content
if after_agent_callback_content:
ret_event = Event(
invocation_id=invocation_context.invocation_id,
author=self.name,
branch=invocation_context.branch,
content=after_agent_callback_content,
actions=callback_context._event_actions,
)
return ret_event
if after_agent_callback_content or callback_context.state.has_delta():
if callback_context.state.has_delta():
ret_event = Event(
invocation_id=invocation_context.invocation_id,
author=self.name,

View File

@@ -191,14 +191,19 @@ class BaseLlmFlow(ABC):
llm_request: LlmRequest,
) -> AsyncGenerator[Event, None]:
"""Receive data from model and process events using BaseLlmConnection."""
def get_author(llm_response):
"""Get the author of the event.
When the model returns transcription, the author is "user". Otherwise, the
author is the agent.
"""
if llm_response and llm_response.content and llm_response.content.role == "user":
return "user"
if (
llm_response
and llm_response.content
and llm_response.content.role == 'user'
):
return 'user'
else:
return invocation_context.agent.name