Support chaining for model callbacks

(before/after) model callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync.

PiperOrigin-RevId: 755565583
This commit is contained in:
Selcuk Gun
2025-05-06 16:15:33 -07:00
committed by Copybara-Service
parent 794a70edcd
commit e4317c9eb7
5 changed files with 731 additions and 23 deletions
+49 -5
View File
@@ -47,15 +47,26 @@ from .readonly_context import ReadonlyContext
logger = logging.getLogger(__name__)
BeforeModelCallback: TypeAlias = Callable[
_SingleBeforeModelCallback: TypeAlias = Callable[
[CallbackContext, LlmRequest],
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
]
AfterModelCallback: TypeAlias = Callable[
BeforeModelCallback: TypeAlias = Union[
_SingleBeforeModelCallback,
list[_SingleBeforeModelCallback],
]
_SingleAfterModelCallback: TypeAlias = Callable[
[CallbackContext, LlmResponse],
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
]
AfterModelCallback: TypeAlias = Union[
_SingleAfterModelCallback,
list[_SingleAfterModelCallback],
]
BeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext],
Union[Awaitable[Optional[dict]], Optional[dict]],
@@ -174,7 +185,11 @@ class LlmAgent(BaseAgent):
# Callbacks - Start
before_model_callback: Optional[BeforeModelCallback] = None
"""Called before calling the LLM.
"""Callback or list of callbacks to be called before calling the LLM.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args:
callback_context: CallbackContext,
llm_request: LlmRequest, The raw model request. Callback can mutate the
@@ -185,7 +200,10 @@ class LlmAgent(BaseAgent):
skipped and the provided content will be returned to user.
"""
after_model_callback: Optional[AfterModelCallback] = None
"""Called after calling LLM.
"""Callback or list of callbacks to be called after calling the LLM.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args:
callback_context: CallbackContext,
@@ -285,6 +303,32 @@ class LlmAgent(BaseAgent):
"""
return [_convert_tool_union_to_tool(tool) for tool in self.tools]
@property
def canonical_before_model_callbacks(
self,
) -> list[_SingleBeforeModelCallback]:
"""The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback.
This method is only for use by Agent Development Kit.
"""
if not self.before_model_callback:
return []
if isinstance(self.before_model_callback, list):
return self.before_model_callback
return [self.before_model_callback]
@property
def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]:
"""The resolved self.after_model_callback field as a list of _SingleAfterModelCallback.
This method is only for use by Agent Development Kit.
"""
if not self.after_model_callback:
return []
if isinstance(self.after_model_callback, list):
return self.after_model_callback
return [self.after_model_callback]
@property
def _llm_flow(self) -> BaseLlmFlow:
if (