Support chaining for model callbacks

Before/after model callbacks are invoked in order through the provided chain until one callback returns a non-None value. Callbacks can be either async or sync.

PiperOrigin-RevId: 755565583
This commit is contained in:
Selcuk Gun
2025-05-06 16:15:33 -07:00
committed by Copybara-Service
parent 794a70edcd
commit e4317c9eb7
5 changed files with 731 additions and 23 deletions

View File

@@ -47,15 +47,26 @@ from .readonly_context import ReadonlyContext
logger = logging.getLogger(__name__)
BeforeModelCallback: TypeAlias = Callable[
_SingleBeforeModelCallback: TypeAlias = Callable[
[CallbackContext, LlmRequest],
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
]
AfterModelCallback: TypeAlias = Callable[
BeforeModelCallback: TypeAlias = Union[
_SingleBeforeModelCallback,
list[_SingleBeforeModelCallback],
]
_SingleAfterModelCallback: TypeAlias = Callable[
[CallbackContext, LlmResponse],
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
]
AfterModelCallback: TypeAlias = Union[
_SingleAfterModelCallback,
list[_SingleAfterModelCallback],
]
BeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext],
Union[Awaitable[Optional[dict]], Optional[dict]],
@@ -174,7 +185,11 @@ class LlmAgent(BaseAgent):
# Callbacks - Start
before_model_callback: Optional[BeforeModelCallback] = None
"""Called before calling the LLM.
"""Callback or list of callbacks to be called before calling the LLM.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args:
callback_context: CallbackContext,
llm_request: LlmRequest, The raw model request. Callback can mutate the
@@ -185,7 +200,10 @@ class LlmAgent(BaseAgent):
skipped and the provided content will be returned to user.
"""
after_model_callback: Optional[AfterModelCallback] = None
"""Called after calling LLM.
"""Callback or list of callbacks to be called after calling the LLM.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args:
callback_context: CallbackContext,
@@ -285,6 +303,32 @@ class LlmAgent(BaseAgent):
"""
return [_convert_tool_union_to_tool(tool) for tool in self.tools]
@property
def canonical_before_model_callbacks(
    self,
) -> list[_SingleBeforeModelCallback]:
    """The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback.

    Normalizes the user-facing field, which may be unset, a single callable,
    or a list of callables, into a uniform list.

    This method is only for use by Agent Development Kit.
    """
    callbacks = self.before_model_callback
    if not callbacks:
        return []
    # A bare callable is wrapped in a one-element list; a list passes through.
    return callbacks if isinstance(callbacks, list) else [callbacks]
@property
def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]:
    """The resolved self.after_model_callback field as a list of _SingleAfterModelCallback.

    Normalizes the user-facing field, which may be unset, a single callable,
    or a list of callables, into a uniform list.

    This method is only for use by Agent Development Kit.
    """
    callbacks = self.after_model_callback
    if not callbacks:
        return []
    # A bare callable is wrapped in a one-element list; a list passes through.
    return callbacks if isinstance(callbacks, list) else [callbacks]
@property
def _llm_flow(self) -> BaseLlmFlow:
if (

View File

@@ -193,8 +193,9 @@ class BaseLlmFlow(ABC):
"""Receive data from model and process events using BaseLlmConnection."""
def get_author(llm_response):
"""Get the author of the event.
When the model returns transcription, the author is "user". Otherwise, the author is the agent.
When the model returns transcription, the author is "user". Otherwise, the
author is the agent.
"""
if llm_response and llm_response.content and llm_response.content.role == "user":
return "user"
@@ -509,20 +510,21 @@ class BaseLlmFlow(ABC):
if not isinstance(agent, LlmAgent):
return
if not agent.before_model_callback:
if not agent.canonical_before_model_callbacks:
return
callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions
)
before_model_callback_content = agent.before_model_callback(
callback_context=callback_context, llm_request=llm_request
)
if inspect.isawaitable(before_model_callback_content):
before_model_callback_content = await before_model_callback_content
return before_model_callback_content
for callback in agent.canonical_before_model_callbacks:
before_model_callback_content = callback(
callback_context=callback_context, llm_request=llm_request
)
if inspect.isawaitable(before_model_callback_content):
before_model_callback_content = await before_model_callback_content
if before_model_callback_content:
return before_model_callback_content
async def _handle_after_model_callback(
self,
@@ -536,20 +538,21 @@ class BaseLlmFlow(ABC):
if not isinstance(agent, LlmAgent):
return
if not agent.after_model_callback:
if not agent.canonical_after_model_callbacks:
return
callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions
)
after_model_callback_content = agent.after_model_callback(
callback_context=callback_context, llm_response=llm_response
)
if inspect.isawaitable(after_model_callback_content):
after_model_callback_content = await after_model_callback_content
return after_model_callback_content
for callback in agent.canonical_after_model_callbacks:
after_model_callback_content = callback(
callback_context=callback_context, llm_response=llm_response
)
if inspect.isawaitable(after_model_callback_content):
after_model_callback_content = await after_model_callback_content
if after_model_callback_content:
return after_model_callback_content
def _finalize_model_response_event(
self,