Support async agent and model callbacks

PiperOrigin-RevId: 755542756
This commit is contained in:
Selcuk Gun
2025-05-06 15:13:39 -07:00
committed by Copybara-Service
parent f96cdc675c
commit 794a70edcd
25 changed files with 371 additions and 117 deletions
+20 -8
View File
@@ -14,7 +14,8 @@
from __future__ import annotations
from typing import Any
import inspect
from typing import Any, Awaitable, Union
from typing import AsyncGenerator
from typing import Callable
from typing import final
@@ -37,10 +38,15 @@ if TYPE_CHECKING:
tracer = trace.get_tracer('gcp.vertex.agent')
BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
BeforeAgentCallback = Callable[
[CallbackContext],
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
]
AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
AfterAgentCallback = Callable[
[CallbackContext],
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
]
class BaseAgent(BaseModel):
@@ -119,7 +125,7 @@ class BaseAgent(BaseModel):
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
ctx = self._create_invocation_context(parent_context)
if event := self.__handle_before_agent_callback(ctx):
if event := await self.__handle_before_agent_callback(ctx):
yield event
if ctx.end_invocation:
return
@@ -130,7 +136,7 @@ class BaseAgent(BaseModel):
if ctx.end_invocation:
return
if event := self.__handle_after_agent_callback(ctx):
if event := await self.__handle_after_agent_callback(ctx):
yield event
@final
@@ -230,7 +236,7 @@ class BaseAgent(BaseModel):
invocation_context.branch = f'{parent_context.branch}.{self.name}'
return invocation_context
def __handle_before_agent_callback(
async def __handle_before_agent_callback(
self, ctx: InvocationContext
) -> Optional[Event]:
"""Runs the before_agent_callback if it exists.
@@ -248,6 +254,9 @@ class BaseAgent(BaseModel):
callback_context=callback_context
)
if inspect.isawaitable(before_agent_callback_content):
before_agent_callback_content = await before_agent_callback_content
if before_agent_callback_content:
ret_event = Event(
invocation_id=ctx.invocation_id,
@@ -269,7 +278,7 @@ class BaseAgent(BaseModel):
return ret_event
def __handle_after_agent_callback(
async def __handle_after_agent_callback(
self, invocation_context: InvocationContext
) -> Optional[Event]:
"""Runs the after_agent_callback if it exists.
@@ -287,6 +296,9 @@ class BaseAgent(BaseModel):
callback_context=callback_context
)
if inspect.isawaitable(after_agent_callback_content):
after_agent_callback_content = await after_agent_callback_content
if after_agent_callback_content or callback_context.state.has_delta():
ret_event = Event(
invocation_id=invocation_context.invocation_id,
+3 -2
View File
@@ -49,11 +49,12 @@ logger = logging.getLogger(__name__)
BeforeModelCallback: TypeAlias = Callable[
[CallbackContext, LlmRequest], Optional[LlmResponse]
[CallbackContext, LlmRequest],
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
]
AfterModelCallback: TypeAlias = Callable[
[CallbackContext, LlmResponse],
Optional[LlmResponse],
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
]
BeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext],