Support chaining for tool callbacks

(before/after) tool callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync.

PiperOrigin-RevId: 756526507
This commit is contained in:
Selcuk Gun
2025-05-08 17:37:30 -07:00
committed by Copybara-Service
parent 0299020cc4
commit 2cbbf88135
5 changed files with 282 additions and 17 deletions

View File

@@ -67,15 +67,26 @@ AfterModelCallback: TypeAlias = Union[
list[_SingleAfterModelCallback],
]
BeforeToolCallback: TypeAlias = Callable[
_SingleBeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext],
Union[Awaitable[Optional[dict]], Optional[dict]],
]
AfterToolCallback: TypeAlias = Callable[
BeforeToolCallback: TypeAlias = Union[
_SingleBeforeToolCallback,
list[_SingleBeforeToolCallback],
]
_SingleAfterToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext, dict],
Union[Awaitable[Optional[dict]], Optional[dict]],
]
AfterToolCallback: TypeAlias = Union[
_SingleAfterToolCallback,
list[_SingleAfterToolCallback],
]
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
ToolUnion: TypeAlias = Union[Callable, BaseTool]
@@ -214,7 +225,10 @@ class LlmAgent(BaseAgent):
will be ignored and the provided content will be returned to user.
"""
before_tool_callback: Optional[BeforeToolCallback] = None
"""Called before the tool is called.
"""Callback or list of callbacks to be called before calling the tool.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args:
tool: The tool to be called.
@@ -226,7 +240,10 @@ class LlmAgent(BaseAgent):
the framework will skip calling the actual tool.
"""
after_tool_callback: Optional[AfterToolCallback] = None
"""Called after the tool is called.
"""Callback or list of callbacks to be called after calling the tool.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args:
tool: The tool to be called.
@@ -329,6 +346,34 @@ class LlmAgent(BaseAgent):
return self.after_model_callback
return [self.after_model_callback]
@property
def canonical_before_tool_callbacks(
self,
) -> list[BeforeToolCallback]:
"""The resolved self.before_tool_callback field as a list of BeforeToolCallback.
This method is only for use by Agent Development Kit.
"""
if not self.before_tool_callback:
return []
if isinstance(self.before_tool_callback, list):
return self.before_tool_callback
return [self.before_tool_callback]
@property
def canonical_after_tool_callbacks(
self,
) -> list[AfterToolCallback]:
"""The resolved self.after_tool_callback field as a list of AfterToolCallback.
This method is only for use by Agent Development Kit.
"""
if not self.after_tool_callback:
return []
if isinstance(self.after_tool_callback, list):
return self.after_tool_callback
return [self.after_tool_callback]
@property
def _llm_flow(self) -> BaseLlmFlow:
if (

View File

@@ -153,22 +153,22 @@ async def handle_function_calls_async(
function_args = function_call.args or {}
function_response: Optional[dict] = None
# before_tool_callback (sync or async)
if agent.before_tool_callback:
function_response = agent.before_tool_callback(
for callback in agent.canonical_before_tool_callbacks:
function_response = callback(
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(function_response):
function_response = await function_response
if function_response:
break
if not function_response:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
# after_tool_callback (sync or async)
if agent.after_tool_callback:
altered_function_response = agent.after_tool_callback(
for callback in agent.canonical_after_tool_callbacks:
altered_function_response = callback(
tool=tool,
args=function_args,
tool_context=tool_context,
@@ -178,6 +178,7 @@ async def handle_function_calls_async(
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
break
if tool.is_long_running:
# Allow long running function to return None to not provide function response.