mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 19:32:21 -06:00
Support chaining for tool callbacks
(before/after) tool callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 756526507
This commit is contained in:
committed by
Copybara-Service
parent
0299020cc4
commit
2cbbf88135
@@ -67,15 +67,26 @@ AfterModelCallback: TypeAlias = Union[
|
||||
list[_SingleAfterModelCallback],
|
||||
]
|
||||
|
||||
BeforeToolCallback: TypeAlias = Callable[
|
||||
_SingleBeforeToolCallback: TypeAlias = Callable[
|
||||
[BaseTool, dict[str, Any], ToolContext],
|
||||
Union[Awaitable[Optional[dict]], Optional[dict]],
|
||||
]
|
||||
AfterToolCallback: TypeAlias = Callable[
|
||||
|
||||
BeforeToolCallback: TypeAlias = Union[
|
||||
_SingleBeforeToolCallback,
|
||||
list[_SingleBeforeToolCallback],
|
||||
]
|
||||
|
||||
_SingleAfterToolCallback: TypeAlias = Callable[
|
||||
[BaseTool, dict[str, Any], ToolContext, dict],
|
||||
Union[Awaitable[Optional[dict]], Optional[dict]],
|
||||
]
|
||||
|
||||
AfterToolCallback: TypeAlias = Union[
|
||||
_SingleAfterToolCallback,
|
||||
list[_SingleAfterToolCallback],
|
||||
]
|
||||
|
||||
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
|
||||
|
||||
ToolUnion: TypeAlias = Union[Callable, BaseTool]
|
||||
@@ -214,7 +225,10 @@ class LlmAgent(BaseAgent):
|
||||
will be ignored and the provided content will be returned to user.
|
||||
"""
|
||||
before_tool_callback: Optional[BeforeToolCallback] = None
|
||||
"""Called before the tool is called.
|
||||
"""Callback or list of callbacks to be called before calling the tool.
|
||||
|
||||
When a list of callbacks is provided, the callbacks will be called in the
|
||||
order they are listed until a callback does not return None.
|
||||
|
||||
Args:
|
||||
tool: The tool to be called.
|
||||
@@ -226,7 +240,10 @@ class LlmAgent(BaseAgent):
|
||||
the framework will skip calling the actual tool.
|
||||
"""
|
||||
after_tool_callback: Optional[AfterToolCallback] = None
|
||||
"""Called after the tool is called.
|
||||
"""Callback or list of callbacks to be called after calling the tool.
|
||||
|
||||
When a list of callbacks is provided, the callbacks will be called in the
|
||||
order they are listed until a callback does not return None.
|
||||
|
||||
Args:
|
||||
tool: The tool to be called.
|
||||
@@ -329,6 +346,34 @@ class LlmAgent(BaseAgent):
|
||||
return self.after_model_callback
|
||||
return [self.after_model_callback]
|
||||
|
||||
@property
|
||||
def canonical_before_tool_callbacks(
|
||||
self,
|
||||
) -> list[BeforeToolCallback]:
|
||||
"""The resolved self.before_tool_callback field as a list of BeforeToolCallback.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if not self.before_tool_callback:
|
||||
return []
|
||||
if isinstance(self.before_tool_callback, list):
|
||||
return self.before_tool_callback
|
||||
return [self.before_tool_callback]
|
||||
|
||||
@property
|
||||
def canonical_after_tool_callbacks(
|
||||
self,
|
||||
) -> list[AfterToolCallback]:
|
||||
"""The resolved self.after_tool_callback field as a list of AfterToolCallback.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if not self.after_tool_callback:
|
||||
return []
|
||||
if isinstance(self.after_tool_callback, list):
|
||||
return self.after_tool_callback
|
||||
return [self.after_tool_callback]
|
||||
|
||||
@property
|
||||
def _llm_flow(self) -> BaseLlmFlow:
|
||||
if (
|
||||
|
||||
@@ -153,22 +153,22 @@ async def handle_function_calls_async(
|
||||
function_args = function_call.args or {}
|
||||
function_response: Optional[dict] = None
|
||||
|
||||
# before_tool_callback (sync or async)
|
||||
if agent.before_tool_callback:
|
||||
function_response = agent.before_tool_callback(
|
||||
for callback in agent.canonical_before_tool_callbacks:
|
||||
function_response = callback(
|
||||
tool=tool, args=function_args, tool_context=tool_context
|
||||
)
|
||||
if inspect.isawaitable(function_response):
|
||||
function_response = await function_response
|
||||
if function_response:
|
||||
break
|
||||
|
||||
if not function_response:
|
||||
function_response = await __call_tool_async(
|
||||
tool, args=function_args, tool_context=tool_context
|
||||
)
|
||||
|
||||
# after_tool_callback (sync or async)
|
||||
if agent.after_tool_callback:
|
||||
altered_function_response = agent.after_tool_callback(
|
||||
for callback in agent.canonical_after_tool_callbacks:
|
||||
altered_function_response = callback(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
tool_context=tool_context,
|
||||
@@ -178,6 +178,7 @@ async def handle_function_calls_async(
|
||||
altered_function_response = await altered_function_response
|
||||
if altered_function_response is not None:
|
||||
function_response = altered_function_response
|
||||
break
|
||||
|
||||
if tool.is_long_running:
|
||||
# Allow long running function to return None to not provide function response.
|
||||
|
||||
Reference in New Issue
Block a user