mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 19:32:21 -06:00
Support chaining for model callbacks
(before/after) model callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 755565583
This commit is contained in:
committed by
Copybara-Service
parent
794a70edcd
commit
e4317c9eb7
@@ -47,15 +47,26 @@ from .readonly_context import ReadonlyContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BeforeModelCallback: TypeAlias = Callable[
|
||||
_SingleBeforeModelCallback: TypeAlias = Callable[
|
||||
[CallbackContext, LlmRequest],
|
||||
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
|
||||
]
|
||||
AfterModelCallback: TypeAlias = Callable[
|
||||
|
||||
BeforeModelCallback: TypeAlias = Union[
|
||||
_SingleBeforeModelCallback,
|
||||
list[_SingleBeforeModelCallback],
|
||||
]
|
||||
|
||||
_SingleAfterModelCallback: TypeAlias = Callable[
|
||||
[CallbackContext, LlmResponse],
|
||||
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
|
||||
]
|
||||
|
||||
AfterModelCallback: TypeAlias = Union[
|
||||
_SingleAfterModelCallback,
|
||||
list[_SingleAfterModelCallback],
|
||||
]
|
||||
|
||||
BeforeToolCallback: TypeAlias = Callable[
|
||||
[BaseTool, dict[str, Any], ToolContext],
|
||||
Union[Awaitable[Optional[dict]], Optional[dict]],
|
||||
@@ -174,7 +185,11 @@ class LlmAgent(BaseAgent):
|
||||
|
||||
# Callbacks - Start
|
||||
before_model_callback: Optional[BeforeModelCallback] = None
|
||||
"""Called before calling the LLM.
|
||||
"""Callback or list of callbacks to be called before calling the LLM.
|
||||
|
||||
When a list of callbacks is provided, the callbacks will be called in the
|
||||
order they are listed until a callback does not return None.
|
||||
|
||||
Args:
|
||||
callback_context: CallbackContext,
|
||||
llm_request: LlmRequest, The raw model request. Callback can mutate the
|
||||
@@ -185,7 +200,10 @@ class LlmAgent(BaseAgent):
|
||||
skipped and the provided content will be returned to user.
|
||||
"""
|
||||
after_model_callback: Optional[AfterModelCallback] = None
|
||||
"""Called after calling LLM.
|
||||
"""Callback or list of callbacks to be called after calling the LLM.
|
||||
|
||||
When a list of callbacks is provided, the callbacks will be called in the
|
||||
order they are listed until a callback does not return None.
|
||||
|
||||
Args:
|
||||
callback_context: CallbackContext,
|
||||
@@ -285,6 +303,32 @@ class LlmAgent(BaseAgent):
|
||||
"""
|
||||
return [_convert_tool_union_to_tool(tool) for tool in self.tools]
|
||||
|
||||
@property
|
||||
def canonical_before_model_callbacks(
|
||||
self,
|
||||
) -> list[_SingleBeforeModelCallback]:
|
||||
"""The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if not self.before_model_callback:
|
||||
return []
|
||||
if isinstance(self.before_model_callback, list):
|
||||
return self.before_model_callback
|
||||
return [self.before_model_callback]
|
||||
|
||||
@property
|
||||
def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]:
|
||||
"""The resolved self.after_model_callback field as a list of _SingleAfterModelCallback.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if not self.after_model_callback:
|
||||
return []
|
||||
if isinstance(self.after_model_callback, list):
|
||||
return self.after_model_callback
|
||||
return [self.after_model_callback]
|
||||
|
||||
@property
|
||||
def _llm_flow(self) -> BaseLlmFlow:
|
||||
if (
|
||||
|
||||
@@ -193,8 +193,9 @@ class BaseLlmFlow(ABC):
|
||||
"""Receive data from model and process events using BaseLlmConnection."""
|
||||
def get_author(llm_response):
|
||||
"""Get the author of the event.
|
||||
|
||||
When the model returns transcription, the author is "user". Otherwise, the author is the agent.
|
||||
|
||||
When the model returns transcription, the author is "user". Otherwise, the
|
||||
author is the agent.
|
||||
"""
|
||||
if llm_response and llm_response.content and llm_response.content.role == "user":
|
||||
return "user"
|
||||
@@ -509,20 +510,21 @@ class BaseLlmFlow(ABC):
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
if not agent.before_model_callback:
|
||||
if not agent.canonical_before_model_callbacks:
|
||||
return
|
||||
|
||||
callback_context = CallbackContext(
|
||||
invocation_context, event_actions=model_response_event.actions
|
||||
)
|
||||
before_model_callback_content = agent.before_model_callback(
|
||||
callback_context=callback_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
if inspect.isawaitable(before_model_callback_content):
|
||||
before_model_callback_content = await before_model_callback_content
|
||||
|
||||
return before_model_callback_content
|
||||
for callback in agent.canonical_before_model_callbacks:
|
||||
before_model_callback_content = callback(
|
||||
callback_context=callback_context, llm_request=llm_request
|
||||
)
|
||||
if inspect.isawaitable(before_model_callback_content):
|
||||
before_model_callback_content = await before_model_callback_content
|
||||
if before_model_callback_content:
|
||||
return before_model_callback_content
|
||||
|
||||
async def _handle_after_model_callback(
|
||||
self,
|
||||
@@ -536,20 +538,21 @@ class BaseLlmFlow(ABC):
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
if not agent.after_model_callback:
|
||||
if not agent.canonical_after_model_callbacks:
|
||||
return
|
||||
|
||||
callback_context = CallbackContext(
|
||||
invocation_context, event_actions=model_response_event.actions
|
||||
)
|
||||
after_model_callback_content = agent.after_model_callback(
|
||||
callback_context=callback_context, llm_response=llm_response
|
||||
)
|
||||
|
||||
if inspect.isawaitable(after_model_callback_content):
|
||||
after_model_callback_content = await after_model_callback_content
|
||||
|
||||
return after_model_callback_content
|
||||
for callback in agent.canonical_after_model_callbacks:
|
||||
after_model_callback_content = callback(
|
||||
callback_context=callback_context, llm_response=llm_response
|
||||
)
|
||||
if inspect.isawaitable(after_model_callback_content):
|
||||
after_model_callback_content = await after_model_callback_content
|
||||
if after_model_callback_content:
|
||||
return after_model_callback_content
|
||||
|
||||
def _finalize_model_response_event(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user