Mirror of https://github.com/EvolutionAPI/adk-python.git, synced 2025-12-18 19:32:21 -06:00
Support async agent and model callbacks
PiperOrigin-RevId: 755542756
Committed by: Copybara-Service
Parent: f96cdc675c
Commit: 794a70edcd
@@ -14,7 +14,8 @@
 from __future__ import annotations

-from typing import Any
+import inspect
+from typing import Any, Awaitable, Union
 from typing import AsyncGenerator
 from typing import Callable
 from typing import final
@@ -37,10 +38,15 @@ if TYPE_CHECKING:

 tracer = trace.get_tracer('gcp.vertex.agent')

-BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
+BeforeAgentCallback = Callable[
+    [CallbackContext],
+    Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
+]

-AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
+AfterAgentCallback = Callable[
+    [CallbackContext],
+    Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
+]


 class BaseAgent(BaseModel):
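The widened aliases mean `before_agent_callback` and `after_agent_callback` may now be either plain functions or coroutine functions returning `Optional[types.Content]`. A minimal sketch of both forms; the callback names, the state keys, and the `check_policy` helper are illustrative assumptions, not part of this commit:

```python
from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.genai import types


def sync_guard(callback_context: CallbackContext) -> Optional[types.Content]:
  # Returning Content skips the agent and is emitted as its reply;
  # returning None lets the agent run normally.
  if callback_context.state.get("blocked"):
    return types.Content(parts=[types.Part(text="Request was blocked.")])
  return None


async def check_policy(user_id: Optional[str]) -> bool:
  # Hypothetical async helper; stands in for any awaited I/O.
  return user_id is not None


async def async_guard(callback_context: CallbackContext) -> Optional[types.Content]:
  # Same contract as the sync form, but the body may await.
  if not await check_policy(callback_context.state.get("user_id")):
    return types.Content(parts=[types.Part(text="Request was blocked.")])
  return None
```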
@@ -119,7 +125,7 @@ class BaseAgent(BaseModel):
     with tracer.start_as_current_span(f'agent_run [{self.name}]'):
       ctx = self._create_invocation_context(parent_context)

-      if event := self.__handle_before_agent_callback(ctx):
+      if event := await self.__handle_before_agent_callback(ctx):
         yield event
       if ctx.end_invocation:
         return
@@ -130,7 +136,7 @@ class BaseAgent(BaseModel):
       if ctx.end_invocation:
         return

-      if event := self.__handle_after_agent_callback(ctx):
+      if event := await self.__handle_after_agent_callback(ctx):
         yield event

   @final
@@ -230,7 +236,7 @@ class BaseAgent(BaseModel):
       invocation_context.branch = f'{parent_context.branch}.{self.name}'
     return invocation_context

-  def __handle_before_agent_callback(
+  async def __handle_before_agent_callback(
       self, ctx: InvocationContext
   ) -> Optional[Event]:
     """Runs the before_agent_callback if it exists.
@@ -248,6 +254,9 @@ class BaseAgent(BaseModel):
         callback_context=callback_context
     )

+    if inspect.isawaitable(before_agent_callback_content):
+      before_agent_callback_content = await before_agent_callback_content
+
     if before_agent_callback_content:
       ret_event = Event(
           invocation_id=ctx.invocation_id,
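The `inspect.isawaitable` guard is what lets a single code path accept both callback forms: call the callback, then await the result only if it is a coroutine. A standalone sketch of that pattern, with illustrative names rather than framework code:

```python
import asyncio
import inspect
from typing import Awaitable, Callable, Optional, Union

# A callback that may be sync or async, mirroring the new alias shape.
MaybeAsyncCallback = Callable[[str], Union[Awaitable[Optional[str]], Optional[str]]]


async def run_callback(callback: MaybeAsyncCallback, arg: str) -> Optional[str]:
  result = callback(arg)
  # An async def callback returns a coroutine; await it to get the value.
  if inspect.isawaitable(result):
    result = await result
  return result


def sync_cb(arg: str) -> Optional[str]:
  return arg.upper()


async def async_cb(arg: str) -> Optional[str]:
  await asyncio.sleep(0)  # stands in for real async work
  return arg.lower()


async def main() -> None:
  print(await run_callback(sync_cb, "Hi"))   # HI
  print(await run_callback(async_cb, "Hi"))  # hi


asyncio.run(main())
```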
@@ -269,7 +278,7 @@ class BaseAgent(BaseModel):

     return ret_event

-  def __handle_after_agent_callback(
+  async def __handle_after_agent_callback(
       self, invocation_context: InvocationContext
   ) -> Optional[Event]:
     """Runs the after_agent_callback if it exists.
@@ -287,6 +296,9 @@ class BaseAgent(BaseModel):
         callback_context=callback_context
     )

+    if inspect.isawaitable(after_agent_callback_content):
+      after_agent_callback_content = await after_agent_callback_content
+
     if after_agent_callback_content or callback_context.state.has_delta():
       ret_event = Event(
           invocation_id=invocation_context.invocation_id,
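Note the `callback_context.state.has_delta()` branch: an after-agent callback that only mutates state, sync or async, still produces an event even though it returns None. A hypothetical async example; the state key is an assumption for illustration:

```python
from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.genai import types


async def record_run(callback_context: CallbackContext) -> Optional[types.Content]:
  # Writing to state creates a delta, so an event is still emitted even
  # though no replacement Content is returned.
  callback_context.state["last_agent"] = callback_context.agent_name
  return None
```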
@@ -49,11 +49,12 @@ logger = logging.getLogger(__name__)


 BeforeModelCallback: TypeAlias = Callable[
-    [CallbackContext, LlmRequest], Optional[LlmResponse]
+    [CallbackContext, LlmRequest],
+    Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
 ]
 AfterModelCallback: TypeAlias = Callable[
     [CallbackContext, LlmResponse],
-    Optional[LlmResponse],
+    Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
 ]
 BeforeToolCallback: TypeAlias = Callable[
     [BaseTool, dict[str, Any], ToolContext],
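With the model-callback aliases widened the same way, a `before_model_callback` can now await I/O before deciding whether to short-circuit the model call. A sketch assuming a cached reply stored in session state; the cache key is illustrative:

```python
from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.genai import types


async def cached_reply(
    callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
  # Returning an LlmResponse skips the model call; returning None proceeds.
  cached = callback_context.state.get("cached_reply")
  if cached:
    return LlmResponse(
        content=types.Content(role="model", parts=[types.Part(text=cached)])
    )
  return None
```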
@@ -16,6 +16,7 @@ from __future__ import annotations

 from abc import ABC
 import asyncio
+import inspect
 import logging
 from typing import AsyncGenerator
 from typing import cast
@@ -199,7 +200,7 @@ class BaseLlmFlow(ABC):
         return "user"
       else:
         return invocation_context.agent.name

     assert invocation_context.live_request_queue
     try:
       while True:
@@ -447,7 +448,7 @@ class BaseLlmFlow(ABC):
       model_response_event: Event,
   ) -> AsyncGenerator[LlmResponse, None]:
     # Runs before_model_callback if it exists.
-    if response := self._handle_before_model_callback(
+    if response := await self._handle_before_model_callback(
         invocation_context, llm_request, model_response_event
     ):
       yield response
@@ -460,7 +461,7 @@ class BaseLlmFlow(ABC):
     invocation_context.live_request_queue = LiveRequestQueue()
     async for llm_response in self.run_live(invocation_context):
       # Runs after_model_callback if it exists.
-      if altered_llm_response := self._handle_after_model_callback(
+      if altered_llm_response := await self._handle_after_model_callback(
           invocation_context, llm_response, model_response_event
       ):
         llm_response = altered_llm_response
@@ -489,14 +490,14 @@ class BaseLlmFlow(ABC):
           llm_response,
       )
       # Runs after_model_callback if it exists.
-      if altered_llm_response := self._handle_after_model_callback(
+      if altered_llm_response := await self._handle_after_model_callback(
           invocation_context, llm_response, model_response_event
       ):
         llm_response = altered_llm_response

       yield llm_response

-  def _handle_before_model_callback(
+  async def _handle_before_model_callback(
       self,
       invocation_context: InvocationContext,
       llm_request: LlmRequest,
@@ -514,11 +515,16 @@ class BaseLlmFlow(ABC):
     callback_context = CallbackContext(
         invocation_context, event_actions=model_response_event.actions
     )
-    return agent.before_model_callback(
+    before_model_callback_content = agent.before_model_callback(
         callback_context=callback_context, llm_request=llm_request
     )

-  def _handle_after_model_callback(
+    if inspect.isawaitable(before_model_callback_content):
+      before_model_callback_content = await before_model_callback_content
+
+    return before_model_callback_content
+
+  async def _handle_after_model_callback(
       self,
       invocation_context: InvocationContext,
       llm_response: LlmResponse,
@@ -536,10 +542,15 @@ class BaseLlmFlow(ABC):
     callback_context = CallbackContext(
         invocation_context, event_actions=model_response_event.actions
     )
-    return agent.after_model_callback(
+    after_model_callback_content = agent.after_model_callback(
         callback_context=callback_context, llm_response=llm_response
     )

+    if inspect.isawaitable(after_model_callback_content):
+      after_model_callback_content = await after_model_callback_content
+
+    return after_model_callback_content
+
   def _finalize_model_response_event(
       self,
       llm_request: LlmRequest,
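Taken together, the same agent parameters now accept a mix of sync and async callbacks. A usage sketch assuming an `LlmAgent`; the model name and callback bodies are illustrative:

```python
from typing import Optional

from google.adk.agents import LlmAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse


def log_request(
    callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
  # Plain callbacks keep working as before.
  print(f"calling model for agent {callback_context.agent_name}")
  return None


async def log_response(
    callback_context: CallbackContext, llm_response: LlmResponse
) -> Optional[LlmResponse]:
  # Coroutine callbacks are now accepted by the same parameter.
  print("model responded")
  return None


agent = LlmAgent(
    name="demo_agent",
    model="gemini-2.0-flash",
    instruction="Answer briefly.",
    before_model_callback=log_request,   # sync
    after_model_callback=log_response,   # async, awaited by the flow
)
```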