From 2cbbf881353835ba1c321de865b0f53d1c4540e5 Mon Sep 17 00:00:00 2001 From: Selcuk Gun Date: Thu, 8 May 2025 17:37:30 -0700 Subject: [PATCH] Support chaining for tool callbacks (before/after) tool callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 756526507 --- contributing/samples/hello_world/agent.py | 27 +++ .../samples/hello_world/asyncio_run.py | 4 +- src/google/adk/agents/llm_agent.py | 53 ++++- src/google/adk/flows/llm_flows/functions.py | 13 +- .../llm_flows/test_async_tool_callbacks.py | 202 +++++++++++++++++- 5 files changed, 282 insertions(+), 17 deletions(-) diff --git a/contributing/samples/hello_world/agent.py b/contributing/samples/hello_world/agent.py index e3f935c..0a45aba 100755 --- a/contributing/samples/hello_world/agent.py +++ b/contributing/samples/hello_world/agent.py @@ -117,6 +117,31 @@ def before_agent_cb3(callback_context): print('@before_agent_cb3') +def before_tool_cb1(tool, args, tool_context): + print('@before_tool_cb1') + + +def before_tool_cb2(tool, args, tool_context): + print('@before_tool_cb2') + + +def before_tool_cb3(tool, args, tool_context): + print('@before_tool_cb3') + + +def after_tool_cb1(tool, args, tool_context, tool_response): + print('@after_tool_cb1') + + +def after_tool_cb2(tool, args, tool_context, tool_response): + print('@after_tool_cb2') + return {'test': 'after_tool_cb2', 'response': tool_response} + + +def after_tool_cb3(tool, args, tool_context, tool_response): + print('@after_tool_cb3') + + root_agent = Agent( model='gemini-2.0-flash-exp', name='data_processing_agent', @@ -166,4 +191,6 @@ root_agent = Agent( after_agent_callback=[after_agent_cb1, after_agent_cb2, after_agent_cb3], before_model_callback=before_model_callback, after_model_callback=after_model_callback, + before_tool_callback=[before_tool_cb1, before_tool_cb2, before_tool_cb3], + after_tool_callback=[after_tool_cb1, after_tool_cb2, after_tool_cb3], ) diff --git a/contributing/samples/hello_world/asyncio_run.py b/contributing/samples/hello_world/asyncio_run.py index 1b58ef2..53768f5 100755 --- a/contributing/samples/hello_world/asyncio_run.py +++ b/contributing/samples/hello_world/asyncio_run.py @@ -83,7 +83,7 @@ async def main(): print('------------------------------------') await run_prompt(session_11, 'Hi') await run_prompt(session_11, 'Roll a die with 100 sides') - await run_prompt(session_11, 'Roll a die again.') + await run_prompt(session_11, 'Roll a die again with 100 sides.') await run_prompt(session_11, 'What numbers did I got?') await run_prompt_bytes(session_11, 'Hi bytes') print( @@ -130,7 +130,7 @@ def main_sync(): print('------------------------------------') run_prompt(session_11, 'Hi') run_prompt(session_11, 'Roll a die with 100 sides.') - run_prompt(session_11, 'Roll a die again.') + run_prompt(session_11, 'Roll a die again with 100 sides.') run_prompt(session_11, 'What numbers did I got?') end_time = time.time() print('------------------------------------') diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 5c59adb..4d4ceae 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -67,15 +67,26 @@ AfterModelCallback: TypeAlias = Union[ list[_SingleAfterModelCallback], ] -BeforeToolCallback: TypeAlias = Callable[ +_SingleBeforeToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext], Union[Awaitable[Optional[dict]], Optional[dict]], ] -AfterToolCallback: TypeAlias = Callable[ + +BeforeToolCallback: TypeAlias = Union[ + _SingleBeforeToolCallback, + list[_SingleBeforeToolCallback], +] + +_SingleAfterToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext, dict], Union[Awaitable[Optional[dict]], Optional[dict]], ] +AfterToolCallback: TypeAlias = Union[ + _SingleAfterToolCallback, + list[_SingleAfterToolCallback], +] + InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str] ToolUnion: TypeAlias = Union[Callable, BaseTool] @@ -214,7 +225,10 @@ class LlmAgent(BaseAgent): will be ignored and the provided content will be returned to user. """ before_tool_callback: Optional[BeforeToolCallback] = None - """Called before the tool is called. + """Callback or list of callbacks to be called before calling the tool. + + When a list of callbacks is provided, the callbacks will be called in the + order they are listed until a callback does not return None. Args: tool: The tool to be called. @@ -226,7 +240,10 @@ class LlmAgent(BaseAgent): the framework will skip calling the actual tool. """ after_tool_callback: Optional[AfterToolCallback] = None - """Called after the tool is called. + """Callback or list of callbacks to be called after calling the tool. + + When a list of callbacks is provided, the callbacks will be called in the + order they are listed until a callback does not return None. Args: tool: The tool to be called. @@ -329,6 +346,34 @@ class LlmAgent(BaseAgent): return self.after_model_callback return [self.after_model_callback] + @property + def canonical_before_tool_callbacks( + self, + ) -> list[BeforeToolCallback]: + """The resolved self.before_tool_callback field as a list of BeforeToolCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.before_tool_callback: + return [] + if isinstance(self.before_tool_callback, list): + return self.before_tool_callback + return [self.before_tool_callback] + + @property + def canonical_after_tool_callbacks( + self, + ) -> list[AfterToolCallback]: + """The resolved self.after_tool_callback field as a list of AfterToolCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.after_tool_callback: + return [] + if isinstance(self.after_tool_callback, list): + return self.after_tool_callback + return [self.after_tool_callback] + @property def _llm_flow(self) -> BaseLlmFlow: if ( diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 182e427..25a2ab0 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -153,22 +153,22 @@ async def handle_function_calls_async( function_args = function_call.args or {} function_response: Optional[dict] = None - # before_tool_callback (sync or async) - if agent.before_tool_callback: - function_response = agent.before_tool_callback( + for callback in agent.canonical_before_tool_callbacks: + function_response = callback( tool=tool, args=function_args, tool_context=tool_context ) if inspect.isawaitable(function_response): function_response = await function_response + if function_response: + break if not function_response: function_response = await __call_tool_async( tool, args=function_args, tool_context=tool_context ) - # after_tool_callback (sync or async) - if agent.after_tool_callback: - altered_function_response = agent.after_tool_callback( + for callback in agent.canonical_after_tool_callbacks: + altered_function_response = callback( tool=tool, args=function_args, tool_context=tool_context, @@ -178,6 +178,7 @@ async def handle_function_calls_async( altered_function_response = await altered_function_response if altered_function_response is not None: function_response = altered_function_response + break if tool.is_long_running: # Allow long running function to return None to not provide function response. diff --git a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py index 120755b..8ab66da 100644 --- a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py @@ -12,20 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional - -import pytest +from enum import Enum +from functools import partial +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from unittest import mock from google.adk.agents import Agent +from google.adk.agents.callback_context import CallbackContext +from google.adk.events.event import Event +from google.adk.flows.llm_flows.functions import handle_function_calls_async from google.adk.tools.function_tool import FunctionTool from google.adk.tools.tool_context import ToolContext -from google.adk.flows.llm_flows.functions import handle_function_calls_async -from google.adk.events.event import Event from google.genai import types +import pytest from ... import utils +class CallbackType(Enum): + SYNC = 1 + ASYNC = 2 + + class AsyncBeforeToolCallback: def __init__(self, mock_response: Dict[str, Any]): @@ -107,3 +118,184 @@ async def test_async_after_tool_callback(): assert result_event is not None part = result_event.content.parts[0] assert part.function_response.response == mock_resp + + +def mock_async_before_cb_side_effect( + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + ret_value: Optional[Dict[str, Any]] = None, +): + if ret_value: + return ret_value + return None + + +def mock_sync_before_cb_side_effect( + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + ret_value: Optional[Dict[str, Any]] = None, +): + if ret_value: + return ret_value + return None + + +async def mock_async_after_cb_side_effect( + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + tool_response: Dict[str, Any], + ret_value: Optional[Dict[str, Any]] = None, +): + if ret_value: + return ret_value + return None + + +def mock_sync_after_cb_side_effect( + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + tool_response: Dict[str, Any], + ret_value: Optional[Dict[str, Any]] = None, +): + if ret_value: + return ret_value + return None + + +CALLBACK_PARAMS = [ + pytest.param( + [ + (None, CallbackType.SYNC), + ({"test": "callback_2_response"}, CallbackType.ASYNC), + ({"test": "callback_3_response"}, CallbackType.SYNC), + (None, CallbackType.ASYNC), + ], + {"test": "callback_2_response"}, + [1, 1, 0, 0], + id="middle_async_callback_returns", + ), + pytest.param( + [ + (None, CallbackType.SYNC), + (None, CallbackType.ASYNC), + (None, CallbackType.SYNC), + (None, CallbackType.ASYNC), + ], + {"initial": "response"}, + [1, 1, 1, 1], + id="all_callbacks_return_none", + ), + pytest.param( + [ + ({"test": "callback_1_response"}, CallbackType.SYNC), + ({"test": "callback_2_response"}, CallbackType.ASYNC), + ], + {"test": "callback_1_response"}, + [1, 0], + id="first_sync_callback_returns", + ), +] + + +@pytest.mark.parametrize( + "callbacks, expected_response, expected_calls", + CALLBACK_PARAMS, +) +@pytest.mark.asyncio +async def test_before_tool_callbacks_chain( + callbacks: List[tuple[Optional[Dict[str, Any]], int]], + expected_response: Dict[str, Any], + expected_calls: List[int], +): + mock_before_cbs = [] + for response, callback_type in callbacks: + if callback_type == CallbackType.ASYNC: + mock_cb = mock.AsyncMock( + side_effect=partial( + mock_async_before_cb_side_effect, ret_value=response + ) + ) + else: + mock_cb = mock.Mock( + side_effect=partial( + mock_sync_before_cb_side_effect, ret_value=response + ) + ) + mock_before_cbs.append(mock_cb) + result_event = await invoke_tool_with_callbacks(before_cb=mock_before_cbs) + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == expected_response + + # Assert that the callbacks were called the expected number of times + for i, mock_cb in enumerate(mock_before_cbs): + expected_calls_count = expected_calls[i] + if expected_calls_count == 1: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_awaited_once() + else: + mock_cb.assert_called_once() + elif expected_calls_count == 0: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_not_awaited() + else: + mock_cb.assert_not_called() + else: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_awaited(expected_calls_count) + else: + mock_cb.assert_called(expected_calls_count) + + +@pytest.mark.parametrize( + "callbacks, expected_response, expected_calls", + CALLBACK_PARAMS, +) +@pytest.mark.asyncio +async def test_after_tool_callbacks_chain( + callbacks: List[tuple[Optional[Dict[str, Any]], int]], + expected_response: Dict[str, Any], + expected_calls: List[int], +): + mock_after_cbs = [] + for response, callback_type in callbacks: + if callback_type == CallbackType.ASYNC: + mock_cb = mock.AsyncMock( + side_effect=partial( + mock_async_after_cb_side_effect, ret_value=response + ) + ) + else: + mock_cb = mock.Mock( + side_effect=partial( + mock_sync_after_cb_side_effect, ret_value=response + ) + ) + mock_after_cbs.append(mock_cb) + result_event = await invoke_tool_with_callbacks(after_cb=mock_after_cbs) + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == expected_response + + # Assert that the callbacks were called the expected number of times + for i, mock_cb in enumerate(mock_after_cbs): + expected_calls_count = expected_calls[i] + if expected_calls_count == 1: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_awaited_once() + else: + mock_cb.assert_called_once() + elif expected_calls_count == 0: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_not_awaited() + else: + mock_cb.assert_not_called() + else: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_awaited(expected_calls_count) + else: + mock_cb.assert_called(expected_calls_count)