From d45084f311fe4b6c77a2771665dab299593ce266 Mon Sep 17 00:00:00 2001 From: Selcuk Gun Date: Thu, 8 May 2025 10:08:51 -0700 Subject: [PATCH] Support chaining for agent callbacks (before/after) agent callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 756359693 --- contributing/samples/hello_world/agent.py | 39 +++- src/google/adk/agents/base_agent.py | 106 ++++++--- .../adk/flows/llm_flows/base_llm_flow.py | 9 +- .../unittests/agents/test_agent_callbacks.py | 209 ----------------- tests/unittests/agents/test_base_agent.py | 220 +++++++++++++++++- .../unittests/agents/test_model_callbacks.py | 210 ----------------- 6 files changed, 339 insertions(+), 454 deletions(-) delete mode 100644 tests/unittests/agents/test_agent_callbacks.py delete mode 100644 tests/unittests/agents/test_model_callbacks.py diff --git a/contributing/samples/hello_world/agent.py b/contributing/samples/hello_world/agent.py index aceff7f..e3f935c 100755 --- a/contributing/samples/hello_world/agent.py +++ b/contributing/samples/hello_world/agent.py @@ -86,6 +86,37 @@ async def after_model_callback(callback_context, llm_response): return None +def after_agent_cb1(callback_context): + print('@after_agent_cb1') + + +def after_agent_cb2(callback_context): + print('@after_agent_cb2') + return types.Content( + parts=[ + types.Part( + text='(stopped) after_agent_cb2', + ), + ], + ) + + +def after_agent_cb3(callback_context): + print('@after_agent_cb3') + + +def before_agent_cb1(callback_context): + print('@before_agent_cb1') + + +def before_agent_cb2(callback_context): + print('@before_agent_cb2') + + +def before_agent_cb3(callback_context): + print('@before_agent_cb3') + + root_agent = Agent( model='gemini-2.0-flash-exp', name='data_processing_agent', @@ -127,8 +158,12 @@ root_agent = Agent( ), ] ), - before_agent_callback=before_agent_callback, - after_agent_callback=after_agent_callback, + before_agent_callback=[ + before_agent_cb1, + before_agent_cb2, + before_agent_cb3, + ], + after_agent_callback=[after_agent_cb1, after_agent_cb2, after_agent_cb3], before_model_callback=before_model_callback, after_model_callback=after_model_callback, ) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index ccf7e2b..18a5de4 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -15,12 +15,14 @@ from __future__ import annotations import inspect -from typing import Any, Awaitable, Union +from typing import Any from typing import AsyncGenerator +from typing import Awaitable from typing import Callable from typing import final from typing import Optional from typing import TYPE_CHECKING +from typing import Union from google.genai import types from opentelemetry import trace @@ -29,6 +31,7 @@ from pydantic import ConfigDict from pydantic import Field from pydantic import field_validator from typing_extensions import override +from typing_extensions import TypeAlias from ..events.event import Event from .callback_context import CallbackContext @@ -38,14 +41,19 @@ if TYPE_CHECKING: tracer = trace.get_tracer('gcp.vertex.agent') -BeforeAgentCallback = Callable[ +_SingleAgentCallback: TypeAlias = Callable[ [CallbackContext], Union[Awaitable[Optional[types.Content]], Optional[types.Content]], ] -AfterAgentCallback = Callable[ - [CallbackContext], - Union[Awaitable[Optional[types.Content]], Optional[types.Content]], +BeforeAgentCallback: TypeAlias = Union[ + _SingleAgentCallback, + list[_SingleAgentCallback], +] + +AfterAgentCallback: TypeAlias = Union[ + _SingleAgentCallback, + list[_SingleAgentCallback], ] @@ -85,7 +93,10 @@ class BaseAgent(BaseModel): """The sub-agents of this agent.""" before_agent_callback: Optional[BeforeAgentCallback] = None - """Callback signature that is invoked before the agent run. + """Callback or list of callbacks to be invoked before the agent run. + + When a list of callbacks is provided, the callbacks will be called in the + order they are listed until a callback does not return None. Args: callback_context: MUST be named 'callback_context' (enforced). @@ -96,7 +107,10 @@ class BaseAgent(BaseModel): provided content will be returned to user. """ after_agent_callback: Optional[AfterAgentCallback] = None - """Callback signature that is invoked after the agent run. + """Callback or list of callbacks to be invoked after the agent run. + + When a list of callbacks is provided, the callbacks will be called in the + order they are listed until a callback does not return None. Args: callback_context: MUST be named 'callback_context' (enforced). @@ -236,6 +250,30 @@ class BaseAgent(BaseModel): invocation_context.branch = f'{parent_context.branch}.{self.name}' return invocation_context + @property + def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]: + """The resolved self.before_agent_callback field as a list of _SingleAgentCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.before_agent_callback: + return [] + if isinstance(self.before_agent_callback, list): + return self.before_agent_callback + return [self.before_agent_callback] + + @property + def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]: + """The resolved self.after_agent_callback field as a list of _SingleAgentCallback. + + This method is only for use by Agent Development Kit. + """ + if not self.after_agent_callback: + return [] + if isinstance(self.after_agent_callback, list): + return self.after_agent_callback + return [self.after_agent_callback] + async def __handle_before_agent_callback( self, ctx: InvocationContext ) -> Optional[Event]: @@ -246,27 +284,27 @@ class BaseAgent(BaseModel): """ ret_event = None - if not isinstance(self.before_agent_callback, Callable): + if not self.canonical_before_agent_callbacks: return ret_event callback_context = CallbackContext(ctx) - before_agent_callback_content = self.before_agent_callback( - callback_context=callback_context - ) - if inspect.isawaitable(before_agent_callback_content): - before_agent_callback_content = await before_agent_callback_content - - if before_agent_callback_content: - ret_event = Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - content=before_agent_callback_content, - actions=callback_context._event_actions, + for callback in self.canonical_before_agent_callbacks: + before_agent_callback_content = callback( + callback_context=callback_context ) - ctx.end_invocation = True - return ret_event + if inspect.isawaitable(before_agent_callback_content): + before_agent_callback_content = await before_agent_callback_content + if before_agent_callback_content: + ret_event = Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + content=before_agent_callback_content, + actions=callback_context._event_actions, + ) + ctx.end_invocation = True + return ret_event if callback_context.state.has_delta(): ret_event = Event( @@ -288,18 +326,26 @@ class BaseAgent(BaseModel): """ ret_event = None - if not isinstance(self.after_agent_callback, Callable): + if not self.canonical_after_agent_callbacks: return ret_event callback_context = CallbackContext(invocation_context) - after_agent_callback_content = self.after_agent_callback( - callback_context=callback_context - ) - if inspect.isawaitable(after_agent_callback_content): - after_agent_callback_content = await after_agent_callback_content + for callback in self.canonical_after_agent_callbacks: + after_agent_callback_content = callback(callback_context=callback_context) + if inspect.isawaitable(after_agent_callback_content): + after_agent_callback_content = await after_agent_callback_content + if after_agent_callback_content: + ret_event = Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + branch=invocation_context.branch, + content=after_agent_callback_content, + actions=callback_context._event_actions, + ) + return ret_event - if after_agent_callback_content or callback_context.state.has_delta(): + if callback_context.state.has_delta(): ret_event = Event( invocation_id=invocation_context.invocation_id, author=self.name, diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index e9daa1e..acf4d54 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -191,14 +191,19 @@ class BaseLlmFlow(ABC): llm_request: LlmRequest, ) -> AsyncGenerator[Event, None]: """Receive data from model and process events using BaseLlmConnection.""" + def get_author(llm_response): """Get the author of the event. When the model returns transcription, the author is "user". Otherwise, the author is the agent. """ - if llm_response and llm_response.content and llm_response.content.role == "user": - return "user" + if ( + llm_response + and llm_response.content + and llm_response.content.role == 'user' + ): + return 'user' else: return invocation_context.agent.name diff --git a/tests/unittests/agents/test_agent_callbacks.py b/tests/unittests/agents/test_agent_callbacks.py deleted file mode 100644 index c573557..0000000 --- a/tests/unittests/agents/test_agent_callbacks.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any -from typing import Optional - -from google.adk.agents.callback_context import CallbackContext -from google.adk.agents.llm_agent import Agent -from google.adk.models import LlmRequest -from google.adk.models import LlmResponse -from google.genai import types -from google.genai import types -from pydantic import BaseModel -import pytest - -from .. import utils - - -class MockAgentCallback(BaseModel): - mock_response: str - - def __call__( - self, - callback_context: CallbackContext, - ) -> types.Content: - return types.Content(parts=[types.Part(text=self.mock_response)]) - - -class MockAsyncAgentCallback(BaseModel): - mock_response: str - - async def __call__( - self, - callback_context: CallbackContext, - ) -> types.Content: - return types.Content(parts=[types.Part(text=self.mock_response)]) - - -def noop_callback(**kwargs) -> Optional[LlmResponse]: - pass - - -async def async_noop_callback(**kwargs) -> Optional[LlmResponse]: - pass - - -@pytest.mark.asyncio -async def test_before_agent_callback(): - responses = ['agent_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - before_agent_callback=MockAgentCallback( - mock_response='before_agent_callback' - ), - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'before_agent_callback'), - ] - - -@pytest.mark.asyncio -async def test_after_agent_callback(): - responses = ['agent_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - after_agent_callback=MockAgentCallback( - mock_response='after_agent_callback' - ), - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'agent_response'), - ('root_agent', 'after_agent_callback'), - ] - - -@pytest.mark.asyncio -async def test_before_agent_callback_noop(): - responses = ['agent_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - before_agent_callback=noop_callback, - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'agent_response'), - ] - - -@pytest.mark.asyncio -async def test_after_agent_callback_noop(): - responses = ['agent_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - before_agent_callback=noop_callback, - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'agent_response'), - ] - - -@pytest.mark.asyncio -async def test_async_before_agent_callback(): - responses = ['agent_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - before_agent_callback=MockAsyncAgentCallback( - mock_response='async_before_agent_callback' - ), - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'async_before_agent_callback'), - ] - - -@pytest.mark.asyncio -async def test_async_after_agent_callback(): - responses = ['agent_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - after_agent_callback=MockAsyncAgentCallback( - mock_response='async_after_agent_callback' - ), - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'agent_response'), - ('root_agent', 'async_after_agent_callback'), - ] - - -@pytest.mark.asyncio -async def test_async_before_agent_callback_noop(): - responses = ['agent_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - before_agent_callback=async_noop_callback, - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'agent_response'), - ] - - -@pytest.mark.asyncio -async def test_async_after_agent_callback_noop(): - responses = ['agent_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - before_agent_callback=async_noop_callback, - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'agent_response'), - ] diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index 9733586..e162440 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -14,10 +14,13 @@ """Testings for the BaseAgent.""" +from enum import Enum +from functools import partial from typing import AsyncGenerator +from typing import List from typing import Optional from typing import Union - +from unittest import mock from google.adk.agents.base_agent import BaseAgent from google.adk.agents.callback_context import CallbackContext from google.adk.agents.invocation_context import InvocationContext @@ -27,6 +30,7 @@ from google.genai import types import pytest import pytest_mock from typing_extensions import override +from .. import utils def _before_agent_callback_noop(callback_context: CallbackContext) -> None: @@ -266,6 +270,220 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent( assert events[0].content.parts[0].text == 'agent run is bypassed.' +class CallbackType(Enum): + SYNC = 1 + ASYNC = 2 + + +async def mock_async_agent_cb_side_effect( + callback_context: CallbackContext, + ret_value=None, +): + if ret_value: + return types.Content(parts=[types.Part(text=ret_value)]) + return None + + +def mock_sync_agent_cb_side_effect( + callback_context: CallbackContext, + ret_value=None, +): + if ret_value: + return types.Content(parts=[types.Part(text=ret_value)]) + return None + + +BEFORE_AGENT_CALLBACK_PARAMS = [ + pytest.param( + [ + (None, CallbackType.SYNC), + ('callback_2_response', CallbackType.ASYNC), + ('callback_3_response', CallbackType.SYNC), + (None, CallbackType.ASYNC), + ], + ['callback_2_response'], + [1, 1, 0, 0], + id='middle_async_callback_returns', + ), + pytest.param( + [ + (None, CallbackType.SYNC), + (None, CallbackType.ASYNC), + (None, CallbackType.SYNC), + (None, CallbackType.ASYNC), + ], + ['Hello, world!'], + [1, 1, 1, 1], + id='all_callbacks_return_none', + ), + pytest.param( + [ + ('callback_1_response', CallbackType.SYNC), + ('callback_2_response', CallbackType.ASYNC), + ], + ['callback_1_response'], + [1, 0], + id='first_sync_callback_returns', + ), +] + +AFTER_AGENT_CALLBACK_PARAMS = [ + pytest.param( + [ + (None, CallbackType.SYNC), + ('callback_2_response', CallbackType.ASYNC), + ('callback_3_response', CallbackType.SYNC), + (None, CallbackType.ASYNC), + ], + ['Hello, world!', 'callback_2_response'], + [1, 1, 0, 0], + id='middle_async_callback_returns', + ), + pytest.param( + [ + (None, CallbackType.SYNC), + (None, CallbackType.ASYNC), + (None, CallbackType.SYNC), + (None, CallbackType.ASYNC), + ], + ['Hello, world!'], + [1, 1, 1, 1], + id='all_callbacks_return_none', + ), + pytest.param( + [ + ('callback_1_response', CallbackType.SYNC), + ('callback_2_response', CallbackType.ASYNC), + ], + ['Hello, world!', 'callback_1_response'], + [1, 0], + id='first_sync_callback_returns', + ), +] + + +@pytest.mark.parametrize( + 'callbacks, expected_responses, expected_calls', + BEFORE_AGENT_CALLBACK_PARAMS, +) +@pytest.mark.asyncio +async def test_before_agent_callbacks_chain( + callbacks: List[tuple[str, int]], + expected_responses: List[str], + expected_calls: List[int], + request: pytest.FixtureRequest, +): + mock_cbs = [] + for response, callback_type in callbacks: + + if callback_type == CallbackType.ASYNC: + mock_cb = mock.AsyncMock( + side_effect=partial( + mock_async_agent_cb_side_effect, ret_value=response + ) + ) + else: + mock_cb = mock.Mock( + side_effect=partial( + mock_sync_agent_cb_side_effect, ret_value=response + ) + ) + mock_cbs.append(mock_cb) + + agent = _TestingAgent( + name=f'{request.function.__name__}_test_agent', + before_agent_callback=[mock_cb for mock_cb in mock_cbs], + ) + parent_ctx = _create_parent_invocation_context( + request.function.__name__, agent + ) + result = [e async for e in agent.run_async(parent_ctx)] + assert utils.simplify_events(result) == [ + (f'{request.function.__name__}_test_agent', response) + for response in expected_responses + ] + + # Assert that the callbacks were called the expected number of times + for i, mock_cb in enumerate(mock_cbs): + expected_calls_count = expected_calls[i] + if expected_calls_count == 1: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_awaited_once() + else: + mock_cb.assert_called_once() + elif expected_calls_count == 0: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_not_awaited() + else: + mock_cb.assert_not_called() + else: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_awaited(expected_calls_count) + else: + mock_cb.assert_called(expected_calls_count) + + +@pytest.mark.parametrize( + 'callbacks, expected_responses, expected_calls', + AFTER_AGENT_CALLBACK_PARAMS, +) +@pytest.mark.asyncio +async def test_after_agent_callbacks_chain( + callbacks: List[tuple[str, int]], + expected_responses: List[str], + expected_calls: List[int], + request: pytest.FixtureRequest, +): + mock_cbs = [] + for response, callback_type in callbacks: + + if callback_type == CallbackType.ASYNC: + mock_cb = mock.AsyncMock( + side_effect=partial( + mock_async_agent_cb_side_effect, ret_value=response + ) + ) + else: + mock_cb = mock.Mock( + side_effect=partial( + mock_sync_agent_cb_side_effect, ret_value=response + ) + ) + mock_cbs.append(mock_cb) + + agent = _TestingAgent( + name=f'{request.function.__name__}_test_agent', + after_agent_callback=[mock_cb for mock_cb in mock_cbs], + ) + parent_ctx = _create_parent_invocation_context( + request.function.__name__, agent + ) + result = [e async for e in agent.run_async(parent_ctx)] + assert utils.simplify_events(result) == [ + (f'{request.function.__name__}_test_agent', response) + for response in expected_responses + ] + + # Assert that the callbacks were called the expected number of times + for i, mock_cb in enumerate(mock_cbs): + expected_calls_count = expected_calls[i] + if expected_calls_count == 1: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_awaited_once() + else: + mock_cb.assert_called_once() + elif expected_calls_count == 0: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_not_awaited() + else: + mock_cb.assert_not_called() + else: + if isinstance(mock_cb, mock.AsyncMock): + mock_cb.assert_awaited(expected_calls_count) + else: + mock_cb.assert_called(expected_calls_count) + + @pytest.mark.asyncio async def test_run_async_after_agent_callback_noop( request: pytest.FixtureRequest, diff --git a/tests/unittests/agents/test_model_callbacks.py b/tests/unittests/agents/test_model_callbacks.py deleted file mode 100644 index 99a606e..0000000 --- a/tests/unittests/agents/test_model_callbacks.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any -from typing import Optional - -from google.adk.agents.callback_context import CallbackContext -from google.adk.agents.llm_agent import Agent -from google.adk.models import LlmRequest -from google.adk.models import LlmResponse -from google.genai import types -from pydantic import BaseModel -import pytest - -from .. import utils - - -class MockBeforeModelCallback(BaseModel): - mock_response: str - - def __call__( - self, - callback_context: CallbackContext, - llm_request: LlmRequest, - ) -> LlmResponse: - return LlmResponse( - content=utils.ModelContent( - [types.Part.from_text(text=self.mock_response)] - ) - ) - - -class MockAfterModelCallback(BaseModel): - mock_response: str - - def __call__( - self, - callback_context: CallbackContext, - llm_response: LlmResponse, - ) -> LlmResponse: - return LlmResponse( - content=utils.ModelContent( - [types.Part.from_text(text=self.mock_response)] - ) - ) - - -class MockAsyncBeforeModelCallback(BaseModel): - mock_response: str - - async def __call__( - self, - callback_context: CallbackContext, - llm_request: LlmRequest, - ) -> LlmResponse: - return LlmResponse( - content=utils.ModelContent( - [types.Part.from_text(text=self.mock_response)] - ) - ) - - -class MockAsyncAfterModelCallback(BaseModel): - mock_response: str - - async def __call__( - self, - callback_context: CallbackContext, - llm_response: LlmResponse, - ) -> LlmResponse: - return LlmResponse( - content=utils.ModelContent( - [types.Part.from_text(text=self.mock_response)] - ) - ) - - -def noop_callback(**kwargs) -> Optional[LlmResponse]: - pass - - -async def async_noop_callback(**kwargs) -> Optional[LlmResponse]: - pass - - -@pytest.mark.asyncio -async def test_before_model_callback(): - responses = ['model_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - before_model_callback=MockBeforeModelCallback( - mock_response='before_model_callback' - ), - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'before_model_callback'), - ] - - -@pytest.mark.asyncio -async def test_before_model_callback_noop(): - responses = ['model_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - before_model_callback=noop_callback, - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'model_response'), - ] - - -@pytest.mark.asyncio -async def test_after_model_callback(): - responses = ['model_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - after_model_callback=MockAfterModelCallback( - mock_response='after_model_callback' - ), - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'after_model_callback'), - ] - - -@pytest.mark.asyncio -async def test_async_before_model_callback(): - responses = ['model_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - before_model_callback=MockAsyncBeforeModelCallback( - mock_response='async_before_model_callback' - ), - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'async_before_model_callback'), - ] - - -@pytest.mark.asyncio -async def test_async_before_model_callback_noop(): - responses = ['model_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - before_model_callback=async_noop_callback, - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'model_response'), - ] - - -@pytest.mark.asyncio -async def test_async_after_model_callback(): - responses = ['model_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - after_model_callback=MockAsyncAfterModelCallback( - mock_response='async_after_model_callback' - ), - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'async_after_model_callback'), - ]