From fcbf57466e38d53bf8c3c68c6471c7d449e6803b Mon Sep 17 00:00:00 2001 From: Alankrit Verma Date: Tue, 29 Apr 2025 09:02:09 -0400 Subject: [PATCH] refactor: update callback type signatures to support sync and async responses --- src/google/adk/agents/llm_agent.py | 4 +- src/google/adk/flows/llm_flows/functions.py | 39 +++--- .../llm_flows/test_async_tool_callbacks.py | 122 +++++++++--------- 3 files changed, 82 insertions(+), 83 deletions(-) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index e6c7941..67e2d31 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -57,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[ ] BeforeToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext], - Awaitable[Optional[dict]], + Union[Awaitable[Optional[dict]], Optional[dict]], ] AfterToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext, dict], - Awaitable[Optional[dict]], + Union[Awaitable[Optional[dict]], Optional[dict]], ] InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str] diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 6d42c54..0e728b5 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -151,36 +151,33 @@ async def handle_function_calls_async( # do not use "args" as the variable name, because it is a reserved keyword # in python debugger. function_args = function_call.args or {} - function_response = None - # # Calls the tool if before_tool_callback does not exist or returns None. - # if agent.before_tool_callback: - # function_response = agent.before_tool_callback( - # tool=tool, args=function_args, tool_context=tool_context - # ) - # Short-circuit via before_tool_callback (sync *or* async) + function_response: Optional[dict] = None + + # before_tool_callback (sync or async) if agent.before_tool_callback: - _maybe = agent.before_tool_callback( + function_response = agent.before_tool_callback( tool=tool, args=function_args, tool_context=tool_context ) - if inspect.isawaitable(_maybe): - _maybe = await _maybe - function_response = _maybe + if inspect.isawaitable(function_response): + function_response = await function_response + if not function_response: function_response = await __call_tool_async( tool, args=function_args, tool_context=tool_context ) - # Calls after_tool_callback if it exists. + + # after_tool_callback (sync or async) if agent.after_tool_callback: - _maybe2 = agent.after_tool_callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, + altered_function_response = agent.after_tool_callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, ) - if inspect.isawaitable(_maybe2): - _maybe2 = await _maybe2 - if _maybe2 is not None: - function_response = _maybe2 + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response is not None: + function_response = altered_function_response if tool.is_long_running: # Allow long running function to return None to not provide function response. diff --git a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py index ccccef8..120755b 100644 --- a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - + from typing import Any, Dict, Optional import pytest @@ -27,81 +27,83 @@ from ... import utils class AsyncBeforeToolCallback: - def __init__(self, mock_response: Dict[str, Any]): - self.mock_response = mock_response - async def __call__( - self, - tool: FunctionTool, - args: Dict[str, Any], - tool_context: ToolContext, - ) -> Optional[Dict[str, Any]]: - return self.mock_response + def __init__(self, mock_response: Dict[str, Any]): + self.mock_response = mock_response + + async def __call__( + self, + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + ) -> Optional[Dict[str, Any]]: + return self.mock_response class AsyncAfterToolCallback: - def __init__(self, mock_response: Dict[str, Any]): - self.mock_response = mock_response - async def __call__( - self, - tool: FunctionTool, - args: Dict[str, Any], - tool_context: ToolContext, - tool_response: Dict[str, Any], - ) -> Optional[Dict[str, Any]]: - return self.mock_response + def __init__(self, mock_response: Dict[str, Any]): + self.mock_response = mock_response + + async def __call__( + self, + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + tool_response: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + return self.mock_response async def invoke_tool_with_callbacks( before_cb=None, after_cb=None ) -> Optional[Event]: - def simple_fn(**kwargs) -> Dict[str, Any]: - return {"initial": "response"} + def simple_fn(**kwargs) -> Dict[str, Any]: + return {"initial": "response"} - tool = FunctionTool(simple_fn) - model = utils.MockModel.create(responses=[]) - agent = Agent( - name="agent", - model=model, - tools=[tool], - before_tool_callback=before_cb, - after_tool_callback=after_cb, - ) - invocation_context = utils.create_invocation_context( - agent=agent, user_content="" - ) - # Build function call event - function_call = types.FunctionCall(name=tool.name, args={}) - content = types.Content(parts=[types.Part(function_call=function_call)]) - event = Event( - invocation_id=invocation_context.invocation_id, - author=agent.name, - content=content, - ) - tools_dict = {tool.name: tool} - return await handle_function_calls_async( - invocation_context, - event, - tools_dict, - ) + tool = FunctionTool(simple_fn) + model = utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[tool], + before_tool_callback=before_cb, + after_tool_callback=after_cb, + ) + invocation_context = utils.create_invocation_context( + agent=agent, user_content="" + ) + # Build function call event + function_call = types.FunctionCall(name=tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + return await handle_function_calls_async( + invocation_context, + event, + tools_dict, + ) @pytest.mark.asyncio async def test_async_before_tool_callback(): - mock_resp = {"test": "before_tool_callback"} - before_cb = AsyncBeforeToolCallback(mock_resp) - result_event = await invoke_tool_with_callbacks(before_cb=before_cb) - assert result_event is not None - part = result_event.content.parts[0] - assert part.function_response.response == mock_resp + mock_resp = {"test": "before_tool_callback"} + before_cb = AsyncBeforeToolCallback(mock_resp) + result_event = await invoke_tool_with_callbacks(before_cb=before_cb) + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_resp @pytest.mark.asyncio async def test_async_after_tool_callback(): - mock_resp = {"test": "after_tool_callback"} - after_cb = AsyncAfterToolCallback(mock_resp) - result_event = await invoke_tool_with_callbacks(after_cb=after_cb) - assert result_event is not None - part = result_event.content.parts[0] - assert part.function_response.response == mock_resp + mock_resp = {"test": "after_tool_callback"} + after_cb = AsyncAfterToolCallback(mock_resp) + result_event = await invoke_tool_with_callbacks(after_cb=after_cb) + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_resp