diff --git a/pyproject.toml b/pyproject.toml index a9e9f2b..436db8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,3 +147,4 @@ line_length = 200 [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" +asyncio_mode = "auto" diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 7c5fcfb..25f9db3 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -152,27 +152,35 @@ async def handle_function_calls_async( # in python debugger. function_args = function_call.args or {} function_response = None - # Calls the tool if before_tool_callback does not exist or returns None. + # # Calls the tool if before_tool_callback does not exist or returns None. + # if agent.before_tool_callback: + # function_response = agent.before_tool_callback( + # tool=tool, args=function_args, tool_context=tool_context + # ) + # Short-circuit via before_tool_callback (sync *or* async) if agent.before_tool_callback: - function_response = agent.before_tool_callback( + _maybe = agent.before_tool_callback( tool=tool, args=function_args, tool_context=tool_context ) - + if inspect.isawaitable(_maybe): + _maybe = await _maybe + function_response = _maybe if not function_response: function_response = await __call_tool_async( tool, args=function_args, tool_context=tool_context ) - # Calls after_tool_callback if it exists. if agent.after_tool_callback: - new_response = agent.after_tool_callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, + _maybe2 = agent.after_tool_callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, ) - if new_response: - function_response = new_response + if inspect.isawaitable(_maybe2): + _maybe2 = await _maybe2 + if _maybe2 is not None: + function_response = _maybe2 if tool.is_long_running: # Allow long running function to return None to not provide function response. @@ -223,11 +231,18 @@ async def handle_function_calls_live( # in python debugger. function_args = function_call.args or {} function_response = None - # Calls the tool if before_tool_callback does not exist or returns None. + # # Calls the tool if before_tool_callback does not exist or returns None. + # if agent.before_tool_callback: + # function_response = agent.before_tool_callback( + # tool, function_args, tool_context + # ) if agent.before_tool_callback: - function_response = agent.before_tool_callback( - tool, function_args, tool_context - ) + _maybe = agent.before_tool_callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(_maybe): + _maybe = await _maybe + function_response = _maybe if not function_response: function_response = await _process_function_live_helper( @@ -235,15 +250,26 @@ async def handle_function_calls_live( ) # Calls after_tool_callback if it exists. + # if agent.after_tool_callback: + # new_response = agent.after_tool_callback( + # tool, + # function_args, + # tool_context, + # function_response, + # ) + # if new_response: + # function_response = new_response if agent.after_tool_callback: - new_response = agent.after_tool_callback( - tool, - function_args, - tool_context, - function_response, - ) - if new_response: - function_response = new_response + _maybe2 = agent.after_tool_callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(_maybe2): + _maybe2 = await _maybe2 + if _maybe2 is not None: + function_response = _maybe2 if tool.is_long_running: # Allow async function to return None to not provide function response. diff --git a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py new file mode 100644 index 0000000..ccccef8 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import pytest + +from google.adk.agents import Agent +from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.tool_context import ToolContext +from google.adk.flows.llm_flows.functions import handle_function_calls_async +from google.adk.events.event import Event +from google.genai import types + +from ... import utils + + +class AsyncBeforeToolCallback: + def __init__(self, mock_response: Dict[str, Any]): + self.mock_response = mock_response + + async def __call__( + self, + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + ) -> Optional[Dict[str, Any]]: + return self.mock_response + + +class AsyncAfterToolCallback: + def __init__(self, mock_response: Dict[str, Any]): + self.mock_response = mock_response + + async def __call__( + self, + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + tool_response: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + return self.mock_response + + +async def invoke_tool_with_callbacks( + before_cb=None, after_cb=None +) -> Optional[Event]: + def simple_fn(**kwargs) -> Dict[str, Any]: + return {"initial": "response"} + + tool = FunctionTool(simple_fn) + model = utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[tool], + before_tool_callback=before_cb, + after_tool_callback=after_cb, + ) + invocation_context = utils.create_invocation_context( + agent=agent, user_content="" + ) + # Build function call event + function_call = types.FunctionCall(name=tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + return await handle_function_calls_async( + invocation_context, + event, + tools_dict, + ) + + +@pytest.mark.asyncio +async def test_async_before_tool_callback(): + mock_resp = {"test": "before_tool_callback"} + before_cb = AsyncBeforeToolCallback(mock_resp) + result_event = await invoke_tool_with_callbacks(before_cb=before_cb) + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_resp + + +@pytest.mark.asyncio +async def test_async_after_tool_callback(): + mock_resp = {"test": "after_tool_callback"} + after_cb = AsyncAfterToolCallback(mock_resp) + result_event = await invoke_tool_with_callbacks(after_cb=after_cb) + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_resp