Support chaining for tool callbacks

(before/after) tool callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync.

PiperOrigin-RevId: 756526507
This commit is contained in:
Selcuk Gun
2025-05-08 17:37:30 -07:00
committed by Copybara-Service
parent 0299020cc4
commit 2cbbf88135
5 changed files with 282 additions and 17 deletions
+49 -4
View File
@@ -67,15 +67,26 @@ AfterModelCallback: TypeAlias = Union[
list[_SingleAfterModelCallback],
]
BeforeToolCallback: TypeAlias = Callable[
_SingleBeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext],
Union[Awaitable[Optional[dict]], Optional[dict]],
]
AfterToolCallback: TypeAlias = Callable[
BeforeToolCallback: TypeAlias = Union[
_SingleBeforeToolCallback,
list[_SingleBeforeToolCallback],
]
_SingleAfterToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext, dict],
Union[Awaitable[Optional[dict]], Optional[dict]],
]
AfterToolCallback: TypeAlias = Union[
_SingleAfterToolCallback,
list[_SingleAfterToolCallback],
]
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
ToolUnion: TypeAlias = Union[Callable, BaseTool]
@@ -214,7 +225,10 @@ class LlmAgent(BaseAgent):
will be ignored and the provided content will be returned to user.
"""
before_tool_callback: Optional[BeforeToolCallback] = None
"""Called before the tool is called.
"""Callback or list of callbacks to be called before calling the tool.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args:
tool: The tool to be called.
@@ -226,7 +240,10 @@ class LlmAgent(BaseAgent):
the framework will skip calling the actual tool.
"""
after_tool_callback: Optional[AfterToolCallback] = None
"""Called after the tool is called.
"""Callback or list of callbacks to be called after calling the tool.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args:
tool: The tool to be called.
@@ -329,6 +346,34 @@ class LlmAgent(BaseAgent):
return self.after_model_callback
return [self.after_model_callback]
@property
def canonical_before_tool_callbacks(
self,
) -> list[BeforeToolCallback]:
"""The resolved self.before_tool_callback field as a list of BeforeToolCallback.
This method is only for use by Agent Development Kit.
"""
if not self.before_tool_callback:
return []
if isinstance(self.before_tool_callback, list):
return self.before_tool_callback
return [self.before_tool_callback]
@property
def canonical_after_tool_callbacks(
self,
) -> list[AfterToolCallback]:
"""The resolved self.after_tool_callback field as a list of AfterToolCallback.
This method is only for use by Agent Development Kit.
"""
if not self.after_tool_callback:
return []
if isinstance(self.after_tool_callback, list):
return self.after_tool_callback
return [self.after_tool_callback]
@property
def _llm_flow(self) -> BaseLlmFlow:
if (