Support chaining for agent callbacks

(before/after) agent callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync.

PiperOrigin-RevId: 756359693
This commit is contained in:
Selcuk Gun 2025-05-08 10:08:51 -07:00 committed by Copybara-Service
parent a61d20e3df
commit d45084f311
6 changed files with 339 additions and 454 deletions

View File

@ -86,6 +86,37 @@ async def after_model_callback(callback_context, llm_response):
return None return None
def after_agent_cb1(callback_context):
print('@after_agent_cb1')
def after_agent_cb2(callback_context):
print('@after_agent_cb2')
return types.Content(
parts=[
types.Part(
text='(stopped) after_agent_cb2',
),
],
)
def after_agent_cb3(callback_context):
print('@after_agent_cb3')
def before_agent_cb1(callback_context):
print('@before_agent_cb1')
def before_agent_cb2(callback_context):
print('@before_agent_cb2')
def before_agent_cb3(callback_context):
print('@before_agent_cb3')
root_agent = Agent( root_agent = Agent(
model='gemini-2.0-flash-exp', model='gemini-2.0-flash-exp',
name='data_processing_agent', name='data_processing_agent',
@ -127,8 +158,12 @@ root_agent = Agent(
), ),
] ]
), ),
before_agent_callback=before_agent_callback, before_agent_callback=[
after_agent_callback=after_agent_callback, before_agent_cb1,
before_agent_cb2,
before_agent_cb3,
],
after_agent_callback=[after_agent_cb1, after_agent_cb2, after_agent_cb3],
before_model_callback=before_model_callback, before_model_callback=before_model_callback,
after_model_callback=after_model_callback, after_model_callback=after_model_callback,
) )

View File

@ -15,12 +15,14 @@
from __future__ import annotations from __future__ import annotations
import inspect import inspect
from typing import Any, Awaitable, Union from typing import Any
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Awaitable
from typing import Callable from typing import Callable
from typing import final from typing import final
from typing import Optional from typing import Optional
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
from google.genai import types from google.genai import types
from opentelemetry import trace from opentelemetry import trace
@ -29,6 +31,7 @@ from pydantic import ConfigDict
from pydantic import Field from pydantic import Field
from pydantic import field_validator from pydantic import field_validator
from typing_extensions import override from typing_extensions import override
from typing_extensions import TypeAlias
from ..events.event import Event from ..events.event import Event
from .callback_context import CallbackContext from .callback_context import CallbackContext
@ -38,14 +41,19 @@ if TYPE_CHECKING:
tracer = trace.get_tracer('gcp.vertex.agent') tracer = trace.get_tracer('gcp.vertex.agent')
BeforeAgentCallback = Callable[ _SingleAgentCallback: TypeAlias = Callable[
[CallbackContext], [CallbackContext],
Union[Awaitable[Optional[types.Content]], Optional[types.Content]], Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
] ]
AfterAgentCallback = Callable[ BeforeAgentCallback: TypeAlias = Union[
[CallbackContext], _SingleAgentCallback,
Union[Awaitable[Optional[types.Content]], Optional[types.Content]], list[_SingleAgentCallback],
]
AfterAgentCallback: TypeAlias = Union[
_SingleAgentCallback,
list[_SingleAgentCallback],
] ]
@ -85,7 +93,10 @@ class BaseAgent(BaseModel):
"""The sub-agents of this agent.""" """The sub-agents of this agent."""
before_agent_callback: Optional[BeforeAgentCallback] = None before_agent_callback: Optional[BeforeAgentCallback] = None
"""Callback signature that is invoked before the agent run. """Callback or list of callbacks to be invoked before the agent run.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args: Args:
callback_context: MUST be named 'callback_context' (enforced). callback_context: MUST be named 'callback_context' (enforced).
@ -96,7 +107,10 @@ class BaseAgent(BaseModel):
provided content will be returned to user. provided content will be returned to user.
""" """
after_agent_callback: Optional[AfterAgentCallback] = None after_agent_callback: Optional[AfterAgentCallback] = None
"""Callback signature that is invoked after the agent run. """Callback or list of callbacks to be invoked after the agent run.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args: Args:
callback_context: MUST be named 'callback_context' (enforced). callback_context: MUST be named 'callback_context' (enforced).
@ -236,6 +250,30 @@ class BaseAgent(BaseModel):
invocation_context.branch = f'{parent_context.branch}.{self.name}' invocation_context.branch = f'{parent_context.branch}.{self.name}'
return invocation_context return invocation_context
@property
def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]:
"""The resolved self.before_agent_callback field as a list of _SingleAgentCallback.
This method is only for use by Agent Development Kit.
"""
if not self.before_agent_callback:
return []
if isinstance(self.before_agent_callback, list):
return self.before_agent_callback
return [self.before_agent_callback]
@property
def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]:
"""The resolved self.after_agent_callback field as a list of _SingleAgentCallback.
This method is only for use by Agent Development Kit.
"""
if not self.after_agent_callback:
return []
if isinstance(self.after_agent_callback, list):
return self.after_agent_callback
return [self.after_agent_callback]
async def __handle_before_agent_callback( async def __handle_before_agent_callback(
self, ctx: InvocationContext self, ctx: InvocationContext
) -> Optional[Event]: ) -> Optional[Event]:
@ -246,17 +284,17 @@ class BaseAgent(BaseModel):
""" """
ret_event = None ret_event = None
if not isinstance(self.before_agent_callback, Callable): if not self.canonical_before_agent_callbacks:
return ret_event return ret_event
callback_context = CallbackContext(ctx) callback_context = CallbackContext(ctx)
before_agent_callback_content = self.before_agent_callback(
for callback in self.canonical_before_agent_callbacks:
before_agent_callback_content = callback(
callback_context=callback_context callback_context=callback_context
) )
if inspect.isawaitable(before_agent_callback_content): if inspect.isawaitable(before_agent_callback_content):
before_agent_callback_content = await before_agent_callback_content before_agent_callback_content = await before_agent_callback_content
if before_agent_callback_content: if before_agent_callback_content:
ret_event = Event( ret_event = Event(
invocation_id=ctx.invocation_id, invocation_id=ctx.invocation_id,
@ -288,18 +326,26 @@ class BaseAgent(BaseModel):
""" """
ret_event = None ret_event = None
if not isinstance(self.after_agent_callback, Callable): if not self.canonical_after_agent_callbacks:
return ret_event return ret_event
callback_context = CallbackContext(invocation_context) callback_context = CallbackContext(invocation_context)
after_agent_callback_content = self.after_agent_callback(
callback_context=callback_context
)
for callback in self.canonical_after_agent_callbacks:
after_agent_callback_content = callback(callback_context=callback_context)
if inspect.isawaitable(after_agent_callback_content): if inspect.isawaitable(after_agent_callback_content):
after_agent_callback_content = await after_agent_callback_content after_agent_callback_content = await after_agent_callback_content
if after_agent_callback_content:
ret_event = Event(
invocation_id=invocation_context.invocation_id,
author=self.name,
branch=invocation_context.branch,
content=after_agent_callback_content,
actions=callback_context._event_actions,
)
return ret_event
if after_agent_callback_content or callback_context.state.has_delta(): if callback_context.state.has_delta():
ret_event = Event( ret_event = Event(
invocation_id=invocation_context.invocation_id, invocation_id=invocation_context.invocation_id,
author=self.name, author=self.name,

View File

@ -191,14 +191,19 @@ class BaseLlmFlow(ABC):
llm_request: LlmRequest, llm_request: LlmRequest,
) -> AsyncGenerator[Event, None]: ) -> AsyncGenerator[Event, None]:
"""Receive data from model and process events using BaseLlmConnection.""" """Receive data from model and process events using BaseLlmConnection."""
def get_author(llm_response): def get_author(llm_response):
"""Get the author of the event. """Get the author of the event.
When the model returns transcription, the author is "user". Otherwise, the When the model returns transcription, the author is "user". Otherwise, the
author is the agent. author is the agent.
""" """
if llm_response and llm_response.content and llm_response.content.role == "user": if (
return "user" llm_response
and llm_response.content
and llm_response.content.role == 'user'
):
return 'user'
else: else:
return invocation_context.agent.name return invocation_context.agent.name

View File

@ -1,209 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Optional
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.genai import types
from google.genai import types
from pydantic import BaseModel
import pytest
from .. import utils
class MockAgentCallback(BaseModel):
mock_response: str
def __call__(
self,
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text=self.mock_response)])
class MockAsyncAgentCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text=self.mock_response)])
def noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
@pytest.mark.asyncio
async def test_before_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=MockAgentCallback(
mock_response='before_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'before_agent_callback'),
]
@pytest.mark.asyncio
async def test_after_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_agent_callback=MockAgentCallback(
mock_response='after_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
('root_agent', 'after_agent_callback'),
]
@pytest.mark.asyncio
async def test_before_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]
@pytest.mark.asyncio
async def test_after_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]
@pytest.mark.asyncio
async def test_async_before_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=MockAsyncAgentCallback(
mock_response='async_before_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_before_agent_callback'),
]
@pytest.mark.asyncio
async def test_async_after_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_agent_callback=MockAsyncAgentCallback(
mock_response='async_after_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
('root_agent', 'async_after_agent_callback'),
]
@pytest.mark.asyncio
async def test_async_before_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]
@pytest.mark.asyncio
async def test_async_after_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]

View File

@ -14,10 +14,13 @@
"""Testings for the BaseAgent.""" """Testings for the BaseAgent."""
from enum import Enum
from functools import partial
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
from unittest import mock
from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.callback_context import CallbackContext from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.invocation_context import InvocationContext
@ -27,6 +30,7 @@ from google.genai import types
import pytest import pytest
import pytest_mock import pytest_mock
from typing_extensions import override from typing_extensions import override
from .. import utils
def _before_agent_callback_noop(callback_context: CallbackContext) -> None: def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
@ -266,6 +270,220 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent(
assert events[0].content.parts[0].text == 'agent run is bypassed.' assert events[0].content.parts[0].text == 'agent run is bypassed.'
class CallbackType(Enum):
SYNC = 1
ASYNC = 2
async def mock_async_agent_cb_side_effect(
callback_context: CallbackContext,
ret_value=None,
):
if ret_value:
return types.Content(parts=[types.Part(text=ret_value)])
return None
def mock_sync_agent_cb_side_effect(
callback_context: CallbackContext,
ret_value=None,
):
if ret_value:
return types.Content(parts=[types.Part(text=ret_value)])
return None
BEFORE_AGENT_CALLBACK_PARAMS = [
pytest.param(
[
(None, CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
('callback_3_response', CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['callback_2_response'],
[1, 1, 0, 0],
id='middle_async_callback_returns',
),
pytest.param(
[
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['Hello, world!'],
[1, 1, 1, 1],
id='all_callbacks_return_none',
),
pytest.param(
[
('callback_1_response', CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
],
['callback_1_response'],
[1, 0],
id='first_sync_callback_returns',
),
]
AFTER_AGENT_CALLBACK_PARAMS = [
pytest.param(
[
(None, CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
('callback_3_response', CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['Hello, world!', 'callback_2_response'],
[1, 1, 0, 0],
id='middle_async_callback_returns',
),
pytest.param(
[
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['Hello, world!'],
[1, 1, 1, 1],
id='all_callbacks_return_none',
),
pytest.param(
[
('callback_1_response', CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
],
['Hello, world!', 'callback_1_response'],
[1, 0],
id='first_sync_callback_returns',
),
]
@pytest.mark.parametrize(
'callbacks, expected_responses, expected_calls',
BEFORE_AGENT_CALLBACK_PARAMS,
)
@pytest.mark.asyncio
async def test_before_agent_callbacks_chain(
callbacks: List[tuple[str, int]],
expected_responses: List[str],
expected_calls: List[int],
request: pytest.FixtureRequest,
):
mock_cbs = []
for response, callback_type in callbacks:
if callback_type == CallbackType.ASYNC:
mock_cb = mock.AsyncMock(
side_effect=partial(
mock_async_agent_cb_side_effect, ret_value=response
)
)
else:
mock_cb = mock.Mock(
side_effect=partial(
mock_sync_agent_cb_side_effect, ret_value=response
)
)
mock_cbs.append(mock_cb)
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=[mock_cb for mock_cb in mock_cbs],
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
assert utils.simplify_events(result) == [
(f'{request.function.__name__}_test_agent', response)
for response in expected_responses
]
# Assert that the callbacks were called the expected number of times
for i, mock_cb in enumerate(mock_cbs):
expected_calls_count = expected_calls[i]
if expected_calls_count == 1:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited_once()
else:
mock_cb.assert_called_once()
elif expected_calls_count == 0:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_not_awaited()
else:
mock_cb.assert_not_called()
else:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited(expected_calls_count)
else:
mock_cb.assert_called(expected_calls_count)
@pytest.mark.parametrize(
'callbacks, expected_responses, expected_calls',
AFTER_AGENT_CALLBACK_PARAMS,
)
@pytest.mark.asyncio
async def test_after_agent_callbacks_chain(
callbacks: List[tuple[str, int]],
expected_responses: List[str],
expected_calls: List[int],
request: pytest.FixtureRequest,
):
mock_cbs = []
for response, callback_type in callbacks:
if callback_type == CallbackType.ASYNC:
mock_cb = mock.AsyncMock(
side_effect=partial(
mock_async_agent_cb_side_effect, ret_value=response
)
)
else:
mock_cb = mock.Mock(
side_effect=partial(
mock_sync_agent_cb_side_effect, ret_value=response
)
)
mock_cbs.append(mock_cb)
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=[mock_cb for mock_cb in mock_cbs],
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
assert utils.simplify_events(result) == [
(f'{request.function.__name__}_test_agent', response)
for response in expected_responses
]
# Assert that the callbacks were called the expected number of times
for i, mock_cb in enumerate(mock_cbs):
expected_calls_count = expected_calls[i]
if expected_calls_count == 1:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited_once()
else:
mock_cb.assert_called_once()
elif expected_calls_count == 0:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_not_awaited()
else:
mock_cb.assert_not_called()
else:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited(expected_calls_count)
else:
mock_cb.assert_called(expected_calls_count)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_after_agent_callback_noop( async def test_run_async_after_agent_callback_noop(
request: pytest.FixtureRequest, request: pytest.FixtureRequest,

View File

@ -1,210 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Optional
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.genai import types
from pydantic import BaseModel
import pytest
from .. import utils
class MockBeforeModelCallback(BaseModel):
mock_response: str
def __call__(
self,
callback_context: CallbackContext,
llm_request: LlmRequest,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
class MockAfterModelCallback(BaseModel):
mock_response: str
def __call__(
self,
callback_context: CallbackContext,
llm_response: LlmResponse,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
class MockAsyncBeforeModelCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
llm_request: LlmRequest,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
class MockAsyncAfterModelCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
llm_response: LlmResponse,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
def noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
@pytest.mark.asyncio
async def test_before_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=MockBeforeModelCallback(
mock_response='before_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'before_model_callback'),
]
@pytest.mark.asyncio
async def test_before_model_callback_noop():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'model_response'),
]
@pytest.mark.asyncio
async def test_after_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_model_callback=MockAfterModelCallback(
mock_response='after_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'after_model_callback'),
]
@pytest.mark.asyncio
async def test_async_before_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=MockAsyncBeforeModelCallback(
mock_response='async_before_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_before_model_callback'),
]
@pytest.mark.asyncio
async def test_async_before_model_callback_noop():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'model_response'),
]
@pytest.mark.asyncio
async def test_async_after_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_model_callback=MockAsyncAfterModelCallback(
mock_response='async_after_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_after_model_callback'),
]