Support chaining for agent callbacks

(before/after) agent callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync.

PiperOrigin-RevId: 756359693
This commit is contained in:
Selcuk Gun
2025-05-08 10:08:51 -07:00
committed by Copybara-Service
parent a61d20e3df
commit d45084f311
6 changed files with 339 additions and 454 deletions

View File

@@ -1,209 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Optional
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.genai import types
from google.genai import types
from pydantic import BaseModel
import pytest
from .. import utils
class MockAgentCallback(BaseModel):
mock_response: str
def __call__(
self,
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text=self.mock_response)])
class MockAsyncAgentCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text=self.mock_response)])
def noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
@pytest.mark.asyncio
async def test_before_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=MockAgentCallback(
mock_response='before_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'before_agent_callback'),
]
@pytest.mark.asyncio
async def test_after_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_agent_callback=MockAgentCallback(
mock_response='after_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
('root_agent', 'after_agent_callback'),
]
@pytest.mark.asyncio
async def test_before_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]
@pytest.mark.asyncio
async def test_after_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]
@pytest.mark.asyncio
async def test_async_before_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=MockAsyncAgentCallback(
mock_response='async_before_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_before_agent_callback'),
]
@pytest.mark.asyncio
async def test_async_after_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_agent_callback=MockAsyncAgentCallback(
mock_response='async_after_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
('root_agent', 'async_after_agent_callback'),
]
@pytest.mark.asyncio
async def test_async_before_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]
@pytest.mark.asyncio
async def test_async_after_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]

View File

@@ -14,10 +14,13 @@
"""Testings for the BaseAgent."""
from enum import Enum
from functools import partial
from typing import AsyncGenerator
from typing import List
from typing import Optional
from typing import Union
from unittest import mock
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
@@ -27,6 +30,7 @@ from google.genai import types
import pytest
import pytest_mock
from typing_extensions import override
from .. import utils
def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
@@ -266,6 +270,220 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent(
assert events[0].content.parts[0].text == 'agent run is bypassed.'
class CallbackType(Enum):
SYNC = 1
ASYNC = 2
async def mock_async_agent_cb_side_effect(
callback_context: CallbackContext,
ret_value=None,
):
if ret_value:
return types.Content(parts=[types.Part(text=ret_value)])
return None
def mock_sync_agent_cb_side_effect(
callback_context: CallbackContext,
ret_value=None,
):
if ret_value:
return types.Content(parts=[types.Part(text=ret_value)])
return None
BEFORE_AGENT_CALLBACK_PARAMS = [
pytest.param(
[
(None, CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
('callback_3_response', CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['callback_2_response'],
[1, 1, 0, 0],
id='middle_async_callback_returns',
),
pytest.param(
[
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['Hello, world!'],
[1, 1, 1, 1],
id='all_callbacks_return_none',
),
pytest.param(
[
('callback_1_response', CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
],
['callback_1_response'],
[1, 0],
id='first_sync_callback_returns',
),
]
AFTER_AGENT_CALLBACK_PARAMS = [
pytest.param(
[
(None, CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
('callback_3_response', CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['Hello, world!', 'callback_2_response'],
[1, 1, 0, 0],
id='middle_async_callback_returns',
),
pytest.param(
[
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['Hello, world!'],
[1, 1, 1, 1],
id='all_callbacks_return_none',
),
pytest.param(
[
('callback_1_response', CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
],
['Hello, world!', 'callback_1_response'],
[1, 0],
id='first_sync_callback_returns',
),
]
@pytest.mark.parametrize(
'callbacks, expected_responses, expected_calls',
BEFORE_AGENT_CALLBACK_PARAMS,
)
@pytest.mark.asyncio
async def test_before_agent_callbacks_chain(
callbacks: List[tuple[str, int]],
expected_responses: List[str],
expected_calls: List[int],
request: pytest.FixtureRequest,
):
mock_cbs = []
for response, callback_type in callbacks:
if callback_type == CallbackType.ASYNC:
mock_cb = mock.AsyncMock(
side_effect=partial(
mock_async_agent_cb_side_effect, ret_value=response
)
)
else:
mock_cb = mock.Mock(
side_effect=partial(
mock_sync_agent_cb_side_effect, ret_value=response
)
)
mock_cbs.append(mock_cb)
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=[mock_cb for mock_cb in mock_cbs],
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
assert utils.simplify_events(result) == [
(f'{request.function.__name__}_test_agent', response)
for response in expected_responses
]
# Assert that the callbacks were called the expected number of times
for i, mock_cb in enumerate(mock_cbs):
expected_calls_count = expected_calls[i]
if expected_calls_count == 1:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited_once()
else:
mock_cb.assert_called_once()
elif expected_calls_count == 0:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_not_awaited()
else:
mock_cb.assert_not_called()
else:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited(expected_calls_count)
else:
mock_cb.assert_called(expected_calls_count)
@pytest.mark.parametrize(
'callbacks, expected_responses, expected_calls',
AFTER_AGENT_CALLBACK_PARAMS,
)
@pytest.mark.asyncio
async def test_after_agent_callbacks_chain(
callbacks: List[tuple[str, int]],
expected_responses: List[str],
expected_calls: List[int],
request: pytest.FixtureRequest,
):
mock_cbs = []
for response, callback_type in callbacks:
if callback_type == CallbackType.ASYNC:
mock_cb = mock.AsyncMock(
side_effect=partial(
mock_async_agent_cb_side_effect, ret_value=response
)
)
else:
mock_cb = mock.Mock(
side_effect=partial(
mock_sync_agent_cb_side_effect, ret_value=response
)
)
mock_cbs.append(mock_cb)
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=[mock_cb for mock_cb in mock_cbs],
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
assert utils.simplify_events(result) == [
(f'{request.function.__name__}_test_agent', response)
for response in expected_responses
]
# Assert that the callbacks were called the expected number of times
for i, mock_cb in enumerate(mock_cbs):
expected_calls_count = expected_calls[i]
if expected_calls_count == 1:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited_once()
else:
mock_cb.assert_called_once()
elif expected_calls_count == 0:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_not_awaited()
else:
mock_cb.assert_not_called()
else:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited(expected_calls_count)
else:
mock_cb.assert_called(expected_calls_count)
@pytest.mark.asyncio
async def test_run_async_after_agent_callback_noop(
request: pytest.FixtureRequest,

View File

@@ -1,210 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Optional
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.genai import types
from pydantic import BaseModel
import pytest
from .. import utils
class MockBeforeModelCallback(BaseModel):
mock_response: str
def __call__(
self,
callback_context: CallbackContext,
llm_request: LlmRequest,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
class MockAfterModelCallback(BaseModel):
mock_response: str
def __call__(
self,
callback_context: CallbackContext,
llm_response: LlmResponse,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
class MockAsyncBeforeModelCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
llm_request: LlmRequest,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
class MockAsyncAfterModelCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
llm_response: LlmResponse,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
def noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
@pytest.mark.asyncio
async def test_before_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=MockBeforeModelCallback(
mock_response='before_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'before_model_callback'),
]
@pytest.mark.asyncio
async def test_before_model_callback_noop():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'model_response'),
]
@pytest.mark.asyncio
async def test_after_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_model_callback=MockAfterModelCallback(
mock_response='after_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'after_model_callback'),
]
@pytest.mark.asyncio
async def test_async_before_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=MockAsyncBeforeModelCallback(
mock_response='async_before_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_before_model_callback'),
]
@pytest.mark.asyncio
async def test_async_before_model_callback_noop():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'model_response'),
]
@pytest.mark.asyncio
async def test_async_after_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_model_callback=MockAsyncAfterModelCallback(
mock_response='async_after_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_after_model_callback'),
]