Support chaining for agent callbacks

(before/after) agent callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync.

PiperOrigin-RevId: 756359693
This commit is contained in:
Selcuk Gun
2025-05-08 10:08:51 -07:00
committed by Copybara-Service
parent a61d20e3df
commit d45084f311
6 changed files with 339 additions and 454 deletions

View File

@@ -14,10 +14,13 @@
"""Testings for the BaseAgent."""
from enum import Enum
from functools import partial
from typing import AsyncGenerator
from typing import List
from typing import Optional
from typing import Union
from unittest import mock
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
@@ -27,6 +30,7 @@ from google.genai import types
import pytest
import pytest_mock
from typing_extensions import override
from .. import utils
def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
@@ -266,6 +270,220 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent(
assert events[0].content.parts[0].text == 'agent run is bypassed.'
class CallbackType(Enum):
SYNC = 1
ASYNC = 2
async def mock_async_agent_cb_side_effect(
callback_context: CallbackContext,
ret_value=None,
):
if ret_value:
return types.Content(parts=[types.Part(text=ret_value)])
return None
def mock_sync_agent_cb_side_effect(
callback_context: CallbackContext,
ret_value=None,
):
if ret_value:
return types.Content(parts=[types.Part(text=ret_value)])
return None
BEFORE_AGENT_CALLBACK_PARAMS = [
pytest.param(
[
(None, CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
('callback_3_response', CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['callback_2_response'],
[1, 1, 0, 0],
id='middle_async_callback_returns',
),
pytest.param(
[
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['Hello, world!'],
[1, 1, 1, 1],
id='all_callbacks_return_none',
),
pytest.param(
[
('callback_1_response', CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
],
['callback_1_response'],
[1, 0],
id='first_sync_callback_returns',
),
]
AFTER_AGENT_CALLBACK_PARAMS = [
pytest.param(
[
(None, CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
('callback_3_response', CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['Hello, world!', 'callback_2_response'],
[1, 1, 0, 0],
id='middle_async_callback_returns',
),
pytest.param(
[
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['Hello, world!'],
[1, 1, 1, 1],
id='all_callbacks_return_none',
),
pytest.param(
[
('callback_1_response', CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
],
['Hello, world!', 'callback_1_response'],
[1, 0],
id='first_sync_callback_returns',
),
]
@pytest.mark.parametrize(
'callbacks, expected_responses, expected_calls',
BEFORE_AGENT_CALLBACK_PARAMS,
)
@pytest.mark.asyncio
async def test_before_agent_callbacks_chain(
callbacks: List[tuple[str, int]],
expected_responses: List[str],
expected_calls: List[int],
request: pytest.FixtureRequest,
):
mock_cbs = []
for response, callback_type in callbacks:
if callback_type == CallbackType.ASYNC:
mock_cb = mock.AsyncMock(
side_effect=partial(
mock_async_agent_cb_side_effect, ret_value=response
)
)
else:
mock_cb = mock.Mock(
side_effect=partial(
mock_sync_agent_cb_side_effect, ret_value=response
)
)
mock_cbs.append(mock_cb)
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=[mock_cb for mock_cb in mock_cbs],
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
assert utils.simplify_events(result) == [
(f'{request.function.__name__}_test_agent', response)
for response in expected_responses
]
# Assert that the callbacks were called the expected number of times
for i, mock_cb in enumerate(mock_cbs):
expected_calls_count = expected_calls[i]
if expected_calls_count == 1:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited_once()
else:
mock_cb.assert_called_once()
elif expected_calls_count == 0:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_not_awaited()
else:
mock_cb.assert_not_called()
else:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited(expected_calls_count)
else:
mock_cb.assert_called(expected_calls_count)
@pytest.mark.parametrize(
'callbacks, expected_responses, expected_calls',
AFTER_AGENT_CALLBACK_PARAMS,
)
@pytest.mark.asyncio
async def test_after_agent_callbacks_chain(
callbacks: List[tuple[str, int]],
expected_responses: List[str],
expected_calls: List[int],
request: pytest.FixtureRequest,
):
mock_cbs = []
for response, callback_type in callbacks:
if callback_type == CallbackType.ASYNC:
mock_cb = mock.AsyncMock(
side_effect=partial(
mock_async_agent_cb_side_effect, ret_value=response
)
)
else:
mock_cb = mock.Mock(
side_effect=partial(
mock_sync_agent_cb_side_effect, ret_value=response
)
)
mock_cbs.append(mock_cb)
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=[mock_cb for mock_cb in mock_cbs],
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
assert utils.simplify_events(result) == [
(f'{request.function.__name__}_test_agent', response)
for response in expected_responses
]
# Assert that the callbacks were called the expected number of times
for i, mock_cb in enumerate(mock_cbs):
expected_calls_count = expected_calls[i]
if expected_calls_count == 1:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited_once()
else:
mock_cb.assert_called_once()
elif expected_calls_count == 0:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_not_awaited()
else:
mock_cb.assert_not_called()
else:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited(expected_calls_count)
else:
mock_cb.assert_called(expected_calls_count)
@pytest.mark.asyncio
async def test_run_async_after_agent_callback_noop(
request: pytest.FixtureRequest,