From 21736067f95cf45d74ecd8475908c5111592c2e7 Mon Sep 17 00:00:00 2001 From: Alankrit Verma Date: Mon, 28 Apr 2025 14:36:25 -0400 Subject: [PATCH 1/3] feat(llm_flows): support async before/after tool callbacks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, callbacks were treated as purely synchronous, so passing an async coroutine caused “was never awaited” errors and Pydantic serialization failures. Now we detect awaitable return values from before_tool_callback and after_tool_callback, and `await` them if necessary. Fixes: #380 --- pyproject.toml | 1 + src/google/adk/flows/llm_flows/functions.py | 72 ++++++++---- .../llm_flows/test_async_tool_callbacks.py | 107 ++++++++++++++++++ 3 files changed, 157 insertions(+), 23 deletions(-) create mode 100644 tests/unittests/flows/llm_flows/test_async_tool_callbacks.py diff --git a/pyproject.toml b/pyproject.toml index a9e9f2b..436db8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,3 +147,4 @@ line_length = 200 [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" +asyncio_mode = "auto" diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 7c5fcfb..25f9db3 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -152,27 +152,35 @@ async def handle_function_calls_async( # in python debugger. function_args = function_call.args or {} function_response = None - # Calls the tool if before_tool_callback does not exist or returns None. + # # Calls the tool if before_tool_callback does not exist or returns None. + # if agent.before_tool_callback: + # function_response = agent.before_tool_callback( + # tool=tool, args=function_args, tool_context=tool_context + # ) + # Short-circuit via before_tool_callback (sync *or* async) if agent.before_tool_callback: - function_response = agent.before_tool_callback( + _maybe = agent.before_tool_callback( tool=tool, args=function_args, tool_context=tool_context ) - + if inspect.isawaitable(_maybe): + _maybe = await _maybe + function_response = _maybe if not function_response: function_response = await __call_tool_async( tool, args=function_args, tool_context=tool_context ) - # Calls after_tool_callback if it exists. if agent.after_tool_callback: - new_response = agent.after_tool_callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, + _maybe2 = agent.after_tool_callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, ) - if new_response: - function_response = new_response + if inspect.isawaitable(_maybe2): + _maybe2 = await _maybe2 + if _maybe2 is not None: + function_response = _maybe2 if tool.is_long_running: # Allow long running function to return None to not provide function response. @@ -223,11 +231,18 @@ async def handle_function_calls_live( # in python debugger. function_args = function_call.args or {} function_response = None - # Calls the tool if before_tool_callback does not exist or returns None. + # # Calls the tool if before_tool_callback does not exist or returns None. + # if agent.before_tool_callback: + # function_response = agent.before_tool_callback( + # tool, function_args, tool_context + # ) if agent.before_tool_callback: - function_response = agent.before_tool_callback( - tool, function_args, tool_context - ) + _maybe = agent.before_tool_callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(_maybe): + _maybe = await _maybe + function_response = _maybe if not function_response: function_response = await _process_function_live_helper( @@ -235,15 +250,26 @@ async def handle_function_calls_live( ) # Calls after_tool_callback if it exists. + # if agent.after_tool_callback: + # new_response = agent.after_tool_callback( + # tool, + # function_args, + # tool_context, + # function_response, + # ) + # if new_response: + # function_response = new_response if agent.after_tool_callback: - new_response = agent.after_tool_callback( - tool, - function_args, - tool_context, - function_response, - ) - if new_response: - function_response = new_response + _maybe2 = agent.after_tool_callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(_maybe2): + _maybe2 = await _maybe2 + if _maybe2 is not None: + function_response = _maybe2 if tool.is_long_running: # Allow async function to return None to not provide function response. diff --git a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py new file mode 100644 index 0000000..ccccef8 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import pytest + +from google.adk.agents import Agent +from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.tool_context import ToolContext +from google.adk.flows.llm_flows.functions import handle_function_calls_async +from google.adk.events.event import Event +from google.genai import types + +from ... import utils + + +class AsyncBeforeToolCallback: + def __init__(self, mock_response: Dict[str, Any]): + self.mock_response = mock_response + + async def __call__( + self, + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + ) -> Optional[Dict[str, Any]]: + return self.mock_response + + +class AsyncAfterToolCallback: + def __init__(self, mock_response: Dict[str, Any]): + self.mock_response = mock_response + + async def __call__( + self, + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + tool_response: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + return self.mock_response + + +async def invoke_tool_with_callbacks( + before_cb=None, after_cb=None +) -> Optional[Event]: + def simple_fn(**kwargs) -> Dict[str, Any]: + return {"initial": "response"} + + tool = FunctionTool(simple_fn) + model = utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[tool], + before_tool_callback=before_cb, + after_tool_callback=after_cb, + ) + invocation_context = utils.create_invocation_context( + agent=agent, user_content="" + ) + # Build function call event + function_call = types.FunctionCall(name=tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + return await handle_function_calls_async( + invocation_context, + event, + tools_dict, + ) + + +@pytest.mark.asyncio +async def test_async_before_tool_callback(): + mock_resp = {"test": "before_tool_callback"} + before_cb = AsyncBeforeToolCallback(mock_resp) + result_event = await invoke_tool_with_callbacks(before_cb=before_cb) + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_resp + + +@pytest.mark.asyncio +async def test_async_after_tool_callback(): + mock_resp = {"test": "after_tool_callback"} + after_cb = AsyncAfterToolCallback(mock_resp) + result_event = await invoke_tool_with_callbacks(after_cb=after_cb) + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_resp From 08ac9a117e889af6611d64999b7b827ac81a2fcd Mon Sep 17 00:00:00 2001 From: Alankrit Verma Date: Mon, 28 Apr 2025 23:16:43 -0400 Subject: [PATCH 2/3] Refactor function callback handling and update type signatures - Simplify variable names in `functions.py`: always use `function_response` and `altered_function_response` - Update LlmAgent callback type aliases to support async: - Import `Awaitable` - Change `BeforeToolCallback` and `AfterToolCallback` signatures to return `Awaitable[Optional[dict]]` - Ensure `after_tool_callback` uses `await` when necessary --- src/google/adk/agents/llm_agent.py | 11 ++------ src/google/adk/flows/llm_flows/functions.py | 31 ++++++++++----------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index a140997..e6c7941 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -15,12 +15,7 @@ from __future__ import annotations import logging -from typing import Any -from typing import AsyncGenerator -from typing import Callable -from typing import Literal -from typing import Optional -from typing import Union +from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union from google.genai import types from pydantic import BaseModel @@ -62,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[ ] BeforeToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext], - Optional[dict], + Awaitable[Optional[dict]], ] AfterToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext, dict], - Optional[dict], + Awaitable[Optional[dict]], ] InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str] diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 25f9db3..6d42c54 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -237,12 +237,11 @@ async def handle_function_calls_live( # tool, function_args, tool_context # ) if agent.before_tool_callback: - _maybe = agent.before_tool_callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(_maybe): - _maybe = await _maybe - function_response = _maybe + function_response = agent.before_tool_callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response if not function_response: function_response = await _process_function_live_helper( @@ -260,16 +259,16 @@ async def handle_function_calls_live( # if new_response: # function_response = new_response if agent.after_tool_callback: - _maybe2 = agent.after_tool_callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(_maybe2): - _maybe2 = await _maybe2 - if _maybe2 is not None: - function_response = _maybe2 + altered_function_response = agent.after_tool_callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response is not None: + function_response = altered_function_response if tool.is_long_running: # Allow async function to return None to not provide function response. From fcbf57466e38d53bf8c3c68c6471c7d449e6803b Mon Sep 17 00:00:00 2001 From: Alankrit Verma Date: Tue, 29 Apr 2025 09:02:09 -0400 Subject: [PATCH 3/3] refactor: update callback type signatures to support sync and async responses --- src/google/adk/agents/llm_agent.py | 4 +- src/google/adk/flows/llm_flows/functions.py | 39 +++--- .../llm_flows/test_async_tool_callbacks.py | 122 +++++++++--------- 3 files changed, 82 insertions(+), 83 deletions(-) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index e6c7941..67e2d31 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -57,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[ ] BeforeToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext], - Awaitable[Optional[dict]], + Union[Awaitable[Optional[dict]], Optional[dict]], ] AfterToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext, dict], - Awaitable[Optional[dict]], + Union[Awaitable[Optional[dict]], Optional[dict]], ] InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str] diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 6d42c54..0e728b5 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -151,36 +151,33 @@ async def handle_function_calls_async( # do not use "args" as the variable name, because it is a reserved keyword # in python debugger. function_args = function_call.args or {} - function_response = None - # # Calls the tool if before_tool_callback does not exist or returns None. - # if agent.before_tool_callback: - # function_response = agent.before_tool_callback( - # tool=tool, args=function_args, tool_context=tool_context - # ) - # Short-circuit via before_tool_callback (sync *or* async) + function_response: Optional[dict] = None + + # before_tool_callback (sync or async) if agent.before_tool_callback: - _maybe = agent.before_tool_callback( + function_response = agent.before_tool_callback( tool=tool, args=function_args, tool_context=tool_context ) - if inspect.isawaitable(_maybe): - _maybe = await _maybe - function_response = _maybe + if inspect.isawaitable(function_response): + function_response = await function_response + if not function_response: function_response = await __call_tool_async( tool, args=function_args, tool_context=tool_context ) - # Calls after_tool_callback if it exists. + + # after_tool_callback (sync or async) if agent.after_tool_callback: - _maybe2 = agent.after_tool_callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, + altered_function_response = agent.after_tool_callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, ) - if inspect.isawaitable(_maybe2): - _maybe2 = await _maybe2 - if _maybe2 is not None: - function_response = _maybe2 + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response is not None: + function_response = altered_function_response if tool.is_long_running: # Allow long running function to return None to not provide function response. diff --git a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py index ccccef8..120755b 100644 --- a/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_async_tool_callbacks.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - + from typing import Any, Dict, Optional import pytest @@ -27,81 +27,83 @@ from ... import utils class AsyncBeforeToolCallback: - def __init__(self, mock_response: Dict[str, Any]): - self.mock_response = mock_response - async def __call__( - self, - tool: FunctionTool, - args: Dict[str, Any], - tool_context: ToolContext, - ) -> Optional[Dict[str, Any]]: - return self.mock_response + def __init__(self, mock_response: Dict[str, Any]): + self.mock_response = mock_response + + async def __call__( + self, + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + ) -> Optional[Dict[str, Any]]: + return self.mock_response class AsyncAfterToolCallback: - def __init__(self, mock_response: Dict[str, Any]): - self.mock_response = mock_response - async def __call__( - self, - tool: FunctionTool, - args: Dict[str, Any], - tool_context: ToolContext, - tool_response: Dict[str, Any], - ) -> Optional[Dict[str, Any]]: - return self.mock_response + def __init__(self, mock_response: Dict[str, Any]): + self.mock_response = mock_response + + async def __call__( + self, + tool: FunctionTool, + args: Dict[str, Any], + tool_context: ToolContext, + tool_response: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + return self.mock_response async def invoke_tool_with_callbacks( before_cb=None, after_cb=None ) -> Optional[Event]: - def simple_fn(**kwargs) -> Dict[str, Any]: - return {"initial": "response"} + def simple_fn(**kwargs) -> Dict[str, Any]: + return {"initial": "response"} - tool = FunctionTool(simple_fn) - model = utils.MockModel.create(responses=[]) - agent = Agent( - name="agent", - model=model, - tools=[tool], - before_tool_callback=before_cb, - after_tool_callback=after_cb, - ) - invocation_context = utils.create_invocation_context( - agent=agent, user_content="" - ) - # Build function call event - function_call = types.FunctionCall(name=tool.name, args={}) - content = types.Content(parts=[types.Part(function_call=function_call)]) - event = Event( - invocation_id=invocation_context.invocation_id, - author=agent.name, - content=content, - ) - tools_dict = {tool.name: tool} - return await handle_function_calls_async( - invocation_context, - event, - tools_dict, - ) + tool = FunctionTool(simple_fn) + model = utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[tool], + before_tool_callback=before_cb, + after_tool_callback=after_cb, + ) + invocation_context = utils.create_invocation_context( + agent=agent, user_content="" + ) + # Build function call event + function_call = types.FunctionCall(name=tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + return await handle_function_calls_async( + invocation_context, + event, + tools_dict, + ) @pytest.mark.asyncio async def test_async_before_tool_callback(): - mock_resp = {"test": "before_tool_callback"} - before_cb = AsyncBeforeToolCallback(mock_resp) - result_event = await invoke_tool_with_callbacks(before_cb=before_cb) - assert result_event is not None - part = result_event.content.parts[0] - assert part.function_response.response == mock_resp + mock_resp = {"test": "before_tool_callback"} + before_cb = AsyncBeforeToolCallback(mock_resp) + result_event = await invoke_tool_with_callbacks(before_cb=before_cb) + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_resp @pytest.mark.asyncio async def test_async_after_tool_callback(): - mock_resp = {"test": "after_tool_callback"} - after_cb = AsyncAfterToolCallback(mock_resp) - result_event = await invoke_tool_with_callbacks(after_cb=after_cb) - assert result_event is not None - part = result_event.content.parts[0] - assert part.function_response.response == mock_resp + mock_resp = {"test": "after_tool_callback"} + after_cb = AsyncAfterToolCallback(mock_resp) + result_event = await invoke_tool_with_callbacks(after_cb=after_cb) + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_resp