diff --git a/pyproject.toml b/pyproject.toml index a9e9f2b..436db8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,3 +147,4 @@ line_length = 200 [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" +asyncio_mode = "auto" diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index a140997..67e2d31 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -15,12 +15,7 @@ from __future__ import annotations import logging -from typing import Any -from typing import AsyncGenerator -from typing import Callable -from typing import Literal -from typing import Optional -from typing import Union +from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union from google.genai import types from pydantic import BaseModel @@ -62,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[ ] BeforeToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext], - Optional[dict], + Union[Awaitable[Optional[dict]], Optional[dict]], ] AfterToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext, dict], - Optional[dict], + Union[Awaitable[Optional[dict]], Optional[dict]], ] InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str] diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 7c5fcfb..0e728b5 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -151,28 +151,33 @@ async def handle_function_calls_async( # do not use "args" as the variable name, because it is a reserved keyword # in python debugger. function_args = function_call.args or {} - function_response = None - # Calls the tool if before_tool_callback does not exist or returns None. + function_response: Optional[dict] = None + + # before_tool_callback (sync or async) if agent.before_tool_callback: function_response = agent.before_tool_callback( tool=tool, args=function_args, tool_context=tool_context ) + if inspect.isawaitable(function_response): + function_response = await function_response if not function_response: function_response = await __call_tool_async( tool, args=function_args, tool_context=tool_context ) - # Calls after_tool_callback if it exists. + # after_tool_callback (sync or async) if agent.after_tool_callback: - new_response = agent.after_tool_callback( + altered_function_response = agent.after_tool_callback( tool=tool, args=function_args, tool_context=tool_context, tool_response=function_response, ) - if new_response: - function_response = new_response + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response is not None: + function_response = altered_function_response if tool.is_long_running: # Allow long running function to return None to not provide function response. @@ -223,11 +228,17 @@ async def handle_function_calls_live( # in python debugger. function_args = function_call.args or {} function_response = None - # Calls the tool if before_tool_callback does not exist or returns None. + # # Calls the tool if before_tool_callback does not exist or returns None. + # if agent.before_tool_callback: + # function_response = agent.before_tool_callback( + # tool, function_args, tool_context + # ) if agent.before_tool_callback: function_response = agent.before_tool_callback( - tool, function_args, tool_context + tool=tool, args=function_args, tool_context=tool_context ) + if inspect.isawaitable(function_response): + function_response = await function_response if not function_response: function_response = await _process_function_live_helper( @@ -235,15 +246,26 @@ async def handle_function_calls_live( ) # Calls after_tool_callback if it exists. + # if agent.after_tool_callback: + # new_response = agent.after_tool_callback( + # tool, + # function_args, + # tool_context, + # function_response, + # ) + # if new_response: + # function_response = new_response if agent.after_tool_callback: - new_response = agent.after_tool_callback( - tool, - function_args, - tool_context, - function_response, + altered_function_response = agent.after_tool_callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, ) - if new_response: - function_response = new_response + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response is not None: + function_response = altered_function_response if tool.is_long_running: # Allow async function to return None to not provide function response. diff --git a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py new file mode 100644 index 0000000..120755b --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py @@ -0,0 +1,109 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import pytest + +from google.adk.agents import Agent +from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.tool_context import ToolContext +from google.adk.flows.llm_flows.functions import handle_function_calls_async +from google.adk.events.event import Event +from google.genai import types + +from ... import utils + + +class AsyncBeforeToolCallback: + + def __init__(self, mock_response: Dict[str, Any]): + self.mock_response = mock_response + + async def __call__( + self, + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + ) -> Optional[Dict[str, Any]]: + return self.mock_response + + +class AsyncAfterToolCallback: + + def __init__(self, mock_response: Dict[str, Any]): + self.mock_response = mock_response + + async def __call__( + self, + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + tool_response: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + return self.mock_response + + +async def invoke_tool_with_callbacks( + before_cb=None, after_cb=None +) -> Optional[Event]: + def simple_fn(**kwargs) -> Dict[str, Any]: + return {"initial": "response"} + + tool = FunctionTool(simple_fn) + model = utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[tool], + before_tool_callback=before_cb, + after_tool_callback=after_cb, + ) + invocation_context = utils.create_invocation_context( + agent=agent, user_content="" + ) + # Build function call event + function_call = types.FunctionCall(name=tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + return await handle_function_calls_async( + invocation_context, + event, + tools_dict, + ) + + +@pytest.mark.asyncio +async def test_async_before_tool_callback(): + mock_resp = {"test": "before_tool_callback"} + before_cb = AsyncBeforeToolCallback(mock_resp) + result_event = await invoke_tool_with_callbacks(before_cb=before_cb) + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_resp + + +@pytest.mark.asyncio +async def test_async_after_tool_callback(): + mock_resp = {"test": "after_tool_callback"} + after_cb = AsyncAfterToolCallback(mock_resp) + result_event = await invoke_tool_with_callbacks(after_cb=after_cb) + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_resp