From e4317c9eb7b85981b49d8179151ac135cf87f286 Mon Sep 17 00:00:00 2001 From: Selcuk Gun Date: Tue, 6 May 2025 16:15:33 -0700 Subject: [PATCH] Support chaining for model callbacks (before/after) model callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 755565583 --- src/google/adk/agents/llm_agent.py | 54 +++- .../adk/flows/llm_flows/base_llm_flow.py | 39 +-- .../unittests/agents/test_agent_callbacks.py | 209 +++++++++++++++ .../agents/test_model_callback_chain.py | 242 ++++++++++++++++++ .../unittests/agents/test_model_callbacks.py | 210 +++++++++++++++ 5 files changed, 731 insertions(+), 23 deletions(-) create mode 100644 tests/unittests/agents/test_agent_callbacks.py create mode 100644 tests/unittests/agents/test_model_callback_chain.py create mode 100644 tests/unittests/agents/test_model_callbacks.py diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 7bde529..5c59adb 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -47,15 +47,26 @@ from .readonly_context import ReadonlyContext logger = logging.getLogger(__name__) - -BeforeModelCallback: TypeAlias = Callable[ +_SingleBeforeModelCallback: TypeAlias = Callable[ [CallbackContext, LlmRequest], Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]], ] -AfterModelCallback: TypeAlias = Callable[ + +BeforeModelCallback: TypeAlias = Union[ + _SingleBeforeModelCallback, + list[_SingleBeforeModelCallback], +] + +_SingleAfterModelCallback: TypeAlias = Callable[ [CallbackContext, LlmResponse], Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]], ] + +AfterModelCallback: TypeAlias = Union[ + _SingleAfterModelCallback, + list[_SingleAfterModelCallback], +] + BeforeToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext], Union[Awaitable[Optional[dict]], Optional[dict]], @@ -174,7 +185,11 @@ class LlmAgent(BaseAgent): # Callbacks - Start before_model_callback: Optional[BeforeModelCallback] = None - """Called before calling the LLM. + """Callback or list of callbacks to be called before calling the LLM. + + When a list of callbacks is provided, the callbacks will be called in the + order they are listed until a callback does not return None. + Args: callback_context: CallbackContext, llm_request: LlmRequest, The raw model request. Callback can mutate the @@ -185,7 +200,10 @@ class LlmAgent(BaseAgent): skipped and the provided content will be returned to user. """ after_model_callback: Optional[AfterModelCallback] = None - """Called after calling LLM. + """Callback or list of callbacks to be called after calling the LLM. + + When a list of callbacks is provided, the callbacks will be called in the + order they are listed until a callback does not return None. Args: callback_context: CallbackContext, @@ -285,6 +303,32 @@ class LlmAgent(BaseAgent): """ return [_convert_tool_union_to_tool(tool) for tool in self.tools] + @property + def canonical_before_model_callbacks( + self, + ) -> list[_SingleBeforeModelCallback]: + """The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.before_model_callback: + return [] + if isinstance(self.before_model_callback, list): + return self.before_model_callback + return [self.before_model_callback] + + @property + def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]: + """The resolved self.after_model_callback field as a list of _SingleAfterModelCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.after_model_callback: + return [] + if isinstance(self.after_model_callback, list): + return self.after_model_callback + return [self.after_model_callback] + @property def _llm_flow(self) -> BaseLlmFlow: if ( diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index d1105e3..3a108b4 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -193,8 +193,9 @@ class BaseLlmFlow(ABC): """Receive data from model and process events using BaseLlmConnection.""" def get_author(llm_response): """Get the author of the event. - - When the model returns transcription, the author is "user". Otherwise, the author is the agent. + + When the model returns transcription, the author is "user". Otherwise, the + author is the agent. """ if llm_response and llm_response.content and llm_response.content.role == "user": return "user" @@ -509,20 +510,21 @@ class BaseLlmFlow(ABC): if not isinstance(agent, LlmAgent): return - if not agent.before_model_callback: + if not agent.canonical_before_model_callbacks: return callback_context = CallbackContext( invocation_context, event_actions=model_response_event.actions ) - before_model_callback_content = agent.before_model_callback( - callback_context=callback_context, llm_request=llm_request - ) - if inspect.isawaitable(before_model_callback_content): - before_model_callback_content = await before_model_callback_content - - return before_model_callback_content + for callback in agent.canonical_before_model_callbacks: + before_model_callback_content = callback( + callback_context=callback_context, llm_request=llm_request + ) + if inspect.isawaitable(before_model_callback_content): + before_model_callback_content = await before_model_callback_content + if before_model_callback_content: + return before_model_callback_content async def _handle_after_model_callback( self, @@ -536,20 +538,21 @@ class BaseLlmFlow(ABC): if not isinstance(agent, LlmAgent): return - if not agent.after_model_callback: + if not agent.canonical_after_model_callbacks: return callback_context = CallbackContext( invocation_context, event_actions=model_response_event.actions ) - after_model_callback_content = agent.after_model_callback( - callback_context=callback_context, llm_response=llm_response - ) - if inspect.isawaitable(after_model_callback_content): - after_model_callback_content = await after_model_callback_content - - return after_model_callback_content + for callback in agent.canonical_after_model_callbacks: + after_model_callback_content = callback( + callback_context=callback_context, llm_response=llm_response + ) + if inspect.isawaitable(after_model_callback_content): + after_model_callback_content = await after_model_callback_content + if after_model_callback_content: + return after_model_callback_content def _finalize_model_response_event( self, diff --git a/tests/unittests/agents/test_agent_callbacks.py b/tests/unittests/agents/test_agent_callbacks.py new file mode 100644 index 0000000..c573557 --- /dev/null +++ b/tests/unittests/agents/test_agent_callbacks.py @@ -0,0 +1,209 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing import Optional + +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.llm_agent import Agent +from google.adk.models import LlmRequest +from google.adk.models import LlmResponse +from google.genai import types +from google.genai import types +from pydantic import BaseModel +import pytest + +from .. import utils + + +class MockAgentCallback(BaseModel): + mock_response: str + + def __call__( + self, + callback_context: CallbackContext, + ) -> types.Content: + return types.Content(parts=[types.Part(text=self.mock_response)]) + + +class MockAsyncAgentCallback(BaseModel): + mock_response: str + + async def __call__( + self, + callback_context: CallbackContext, + ) -> types.Content: + return types.Content(parts=[types.Part(text=self.mock_response)]) + + +def noop_callback(**kwargs) -> Optional[LlmResponse]: + pass + + +async def async_noop_callback(**kwargs) -> Optional[LlmResponse]: + pass + + +@pytest.mark.asyncio +async def test_before_agent_callback(): + responses = ['agent_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_agent_callback=MockAgentCallback( + mock_response='before_agent_callback' + ), + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'before_agent_callback'), + ] + + +@pytest.mark.asyncio +async def test_after_agent_callback(): + responses = ['agent_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + after_agent_callback=MockAgentCallback( + mock_response='after_agent_callback' + ), + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'agent_response'), + ('root_agent', 'after_agent_callback'), + ] + + +@pytest.mark.asyncio +async def test_before_agent_callback_noop(): + responses = ['agent_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_agent_callback=noop_callback, + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'agent_response'), + ] + + +@pytest.mark.asyncio +async def test_after_agent_callback_noop(): + responses = ['agent_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_agent_callback=noop_callback, + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'agent_response'), + ] + + +@pytest.mark.asyncio +async def test_async_before_agent_callback(): + responses = ['agent_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_agent_callback=MockAsyncAgentCallback( + mock_response='async_before_agent_callback' + ), + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'async_before_agent_callback'), + ] + + +@pytest.mark.asyncio +async def test_async_after_agent_callback(): + responses = ['agent_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + after_agent_callback=MockAsyncAgentCallback( + mock_response='async_after_agent_callback' + ), + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'agent_response'), + ('root_agent', 'async_after_agent_callback'), + ] + + +@pytest.mark.asyncio +async def test_async_before_agent_callback_noop(): + responses = ['agent_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_agent_callback=async_noop_callback, + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'agent_response'), + ] + + +@pytest.mark.asyncio +async def test_async_after_agent_callback_noop(): + responses = ['agent_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_agent_callback=async_noop_callback, + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'agent_response'), + ] diff --git a/tests/unittests/agents/test_model_callback_chain.py b/tests/unittests/agents/test_model_callback_chain.py new file mode 100644 index 0000000..3457e1d --- /dev/null +++ b/tests/unittests/agents/test_model_callback_chain.py @@ -0,0 +1,242 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from functools import partial +from typing import Any +from typing import List +from typing import Optional +from unittest import mock + +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.llm_agent import Agent +from google.adk.models import LlmRequest +from google.adk.models import LlmResponse +from google.genai import types +from pydantic import BaseModel +import pytest + +from .. import utils + + +class CallbackType(Enum): + SYNC = 1 + ASYNC = 2 + + +async def mock_async_before_cb_side_effect( + callback_context: CallbackContext, + llm_request: LlmRequest, + ret_value=None, +): + if ret_value: + return LlmResponse( + content=utils.ModelContent([types.Part.from_text(text=ret_value)]) + ) + return None + + +def mock_sync_before_cb_side_effect( + callback_context: CallbackContext, + llm_request: LlmRequest, + ret_value=None, +): + if ret_value: + return LlmResponse( + content=utils.ModelContent([types.Part.from_text(text=ret_value)]) + ) + return None + + +async def mock_async_after_cb_side_effect( + callback_context: CallbackContext, + llm_response: LlmResponse, + ret_value=None, +): + if ret_value: + return LlmResponse( + content=utils.ModelContent([types.Part.from_text(text=ret_value)]) + ) + return None + + +def mock_sync_after_cb_side_effect( + callback_context: CallbackContext, + llm_response: LlmResponse, + ret_value=None, +): + if ret_value: + return LlmResponse( + content=utils.ModelContent([types.Part.from_text(text=ret_value)]) + ) + return None + + +CALLBACK_PARAMS = [ + pytest.param( + [ + (None, CallbackType.SYNC), + ("callback_2_response", CallbackType.ASYNC), + ("callback_3_response", CallbackType.SYNC), + (None, CallbackType.ASYNC), + ], + "callback_2_response", + [1, 1, 0, 0], + id="middle_async_callback_returns", + ), + pytest.param( + [ + (None, CallbackType.SYNC), + (None, CallbackType.ASYNC), + (None, CallbackType.SYNC), + (None, CallbackType.ASYNC), + ], + "model_response", + [1, 1, 1, 1], + id="all_callbacks_return_none", + ), + pytest.param( + [ + ("callback_1_response", CallbackType.SYNC), + ("callback_2_response", CallbackType.ASYNC), + ], + "callback_1_response", + [1, 0], + id="first_sync_callback_returns", + ), +] + + +@pytest.mark.parametrize( + "callbacks, expected_response, expected_calls", + CALLBACK_PARAMS, +) +@pytest.mark.asyncio +async def test_before_model_callbacks_chain( + callbacks: List[tuple[str, int]], + expected_response: str, + expected_calls: List[int], +): + responses = ["model_response"] + mock_model = utils.MockModel.create(responses=responses) + + mock_cbs = [] + for response, callback_type in callbacks: + + if callback_type == CallbackType.ASYNC: + mock_cb = mock.AsyncMock( + side_effect=partial( + mock_async_before_cb_side_effect, ret_value=response + ) + ) + else: + mock_cb = mock.Mock( + side_effect=partial( + mock_sync_before_cb_side_effect, ret_value=response + ) + ) + mock_cbs.append(mock_cb) + # Create agent with multiple callbacks + agent = Agent( + name="root_agent", + model=mock_model, + before_model_callback=[mock_cb for mock_cb in mock_cbs], + ) + + runner = utils.TestInMemoryRunner(agent) + result = await runner.run_async_with_new_session("test") + assert utils.simplify_events(result) == [ + ("root_agent", expected_response), + ] + + # Assert that the callbacks were called the expected number of times + for i, mock_cb in enumerate(mock_cbs): + expected_calls_count = expected_calls[i] + if expected_calls_count == 1: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_awaited_once() + else: + mock_cb.assert_called_once() + elif expected_calls_count == 0: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_not_awaited() + else: + mock_cb.assert_not_called() + else: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_awaited(expected_calls_count) + else: + mock_cb.assert_called(expected_calls_count) + + +@pytest.mark.parametrize( + "callbacks, expected_response, expected_calls", + CALLBACK_PARAMS, +) +@pytest.mark.asyncio +async def test_after_model_callbacks_chain( + callbacks: List[tuple[str, int]], + expected_response: str, + expected_calls: List[int], +): + responses = ["model_response"] + mock_model = utils.MockModel.create(responses=responses) + + mock_cbs = [] + for response, callback_type in callbacks: + + if callback_type == CallbackType.ASYNC: + mock_cb = mock.AsyncMock( + side_effect=partial( + mock_async_after_cb_side_effect, ret_value=response + ) + ) + else: + mock_cb = mock.Mock( + side_effect=partial( + mock_sync_after_cb_side_effect, ret_value=response + ) + ) + mock_cbs.append(mock_cb) + # Create agent with multiple callbacks + agent = Agent( + name="root_agent", + model=mock_model, + after_model_callback=[mock_cb for mock_cb in mock_cbs], + ) + + runner = utils.TestInMemoryRunner(agent) + result = await runner.run_async_with_new_session("test") + assert utils.simplify_events(result) == [ + ("root_agent", expected_response), + ] + + # Assert that the callbacks were called the expected number of times + for i, mock_cb in enumerate(mock_cbs): + expected_calls_count = expected_calls[i] + if expected_calls_count == 1: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_awaited_once() + else: + mock_cb.assert_called_once() + elif expected_calls_count == 0: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_not_awaited() + else: + mock_cb.assert_not_called() + else: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_awaited(expected_calls_count) + else: + mock_cb.assert_called(expected_calls_count) diff --git a/tests/unittests/agents/test_model_callbacks.py b/tests/unittests/agents/test_model_callbacks.py new file mode 100644 index 0000000..99a606e --- /dev/null +++ b/tests/unittests/agents/test_model_callbacks.py @@ -0,0 +1,210 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing import Optional + +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.llm_agent import Agent +from google.adk.models import LlmRequest +from google.adk.models import LlmResponse +from google.genai import types +from pydantic import BaseModel +import pytest + +from .. import utils + + +class MockBeforeModelCallback(BaseModel): + mock_response: str + + def __call__( + self, + callback_context: CallbackContext, + llm_request: LlmRequest, + ) -> LlmResponse: + return LlmResponse( + content=utils.ModelContent( + [types.Part.from_text(text=self.mock_response)] + ) + ) + + +class MockAfterModelCallback(BaseModel): + mock_response: str + + def __call__( + self, + callback_context: CallbackContext, + llm_response: LlmResponse, + ) -> LlmResponse: + return LlmResponse( + content=utils.ModelContent( + [types.Part.from_text(text=self.mock_response)] + ) + ) + + +class MockAsyncBeforeModelCallback(BaseModel): + mock_response: str + + async def __call__( + self, + callback_context: CallbackContext, + llm_request: LlmRequest, + ) -> LlmResponse: + return LlmResponse( + content=utils.ModelContent( + [types.Part.from_text(text=self.mock_response)] + ) + ) + + +class MockAsyncAfterModelCallback(BaseModel): + mock_response: str + + async def __call__( + self, + callback_context: CallbackContext, + llm_response: LlmResponse, + ) -> LlmResponse: + return LlmResponse( + content=utils.ModelContent( + [types.Part.from_text(text=self.mock_response)] + ) + ) + + +def noop_callback(**kwargs) -> Optional[LlmResponse]: + pass + + +async def async_noop_callback(**kwargs) -> Optional[LlmResponse]: + pass + + +@pytest.mark.asyncio +async def test_before_model_callback(): + responses = ['model_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_model_callback=MockBeforeModelCallback( + mock_response='before_model_callback' + ), + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'before_model_callback'), + ] + + +@pytest.mark.asyncio +async def test_before_model_callback_noop(): + responses = ['model_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_model_callback=noop_callback, + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'model_response'), + ] + + +@pytest.mark.asyncio +async def test_after_model_callback(): + responses = ['model_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + after_model_callback=MockAfterModelCallback( + mock_response='after_model_callback' + ), + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'after_model_callback'), + ] + + +@pytest.mark.asyncio +async def test_async_before_model_callback(): + responses = ['model_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_model_callback=MockAsyncBeforeModelCallback( + mock_response='async_before_model_callback' + ), + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'async_before_model_callback'), + ] + + +@pytest.mark.asyncio +async def test_async_before_model_callback_noop(): + responses = ['model_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_model_callback=async_noop_callback, + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'model_response'), + ] + + +@pytest.mark.asyncio +async def test_async_after_model_callback(): + responses = ['model_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + after_model_callback=MockAsyncAfterModelCallback( + mock_response='async_after_model_callback' + ), + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'async_after_model_callback'), + ]