adk-python/tests/unittests/agents/test_base_agent.py
Google Team Member ec8bc7387c fix: ParallelAgent should only append to its immediate sub-agent, not transitive descendants
Restores automatic conversation history sharing for sequential/loop sub-agents.

PiperOrigin-RevId: 766742380
2025-06-03 16:55:51 -07:00

760 lines
22 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testings for the BaseAgent."""
from enum import Enum
from functools import partial
from typing import AsyncGenerator
from typing import List
from typing import Optional
from typing import Union
from unittest import mock
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.genai import types
import pytest
import pytest_mock
from typing_extensions import override
from .. import testing_utils
def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
  """Sync before-agent callback that returns None and does nothing else."""
async def _async_before_agent_callback_noop(
    callback_context: CallbackContext,
) -> None:
  """Async before-agent callback that returns None and does nothing else."""
def _before_agent_callback_bypass_agent(
    callback_context: CallbackContext,
) -> types.Content:
  """Sync before-agent callback that returns fixed content.

  The bypass tests expect this content to be emitted instead of running the
  agent's own implementation.
  """
  bypass_part = types.Part(text='agent run is bypassed.')
  return types.Content(parts=[bypass_part])
async def _async_before_agent_callback_bypass_agent(
    callback_context: CallbackContext,
) -> types.Content:
  """Async before-agent callback that returns fixed content.

  The bypass tests expect this content to be emitted instead of running the
  agent's own implementation.
  """
  bypass_part = types.Part(text='agent run is bypassed.')
  return types.Content(parts=[bypass_part])
def _after_agent_callback_noop(callback_context: CallbackContext) -> None:
  """Sync after-agent callback that returns None and does nothing else."""
async def _async_after_agent_callback_noop(
    callback_context: CallbackContext,
) -> None:
  """Async after-agent callback that returns None and does nothing else."""
def _after_agent_callback_append_agent_reply(
    callback_context: CallbackContext,
) -> types.Content:
  """Sync after-agent callback returning content for one extra reply event."""
  reply_part = types.Part(text='Agent reply from after agent callback.')
  return types.Content(parts=[reply_part])
async def _async_after_agent_callback_append_agent_reply(
    callback_context: CallbackContext,
) -> types.Content:
  """Async after-agent callback returning content for one extra reply event."""
  reply_part = types.Part(text='Agent reply from after agent callback.')
  return types.Content(parts=[reply_part])
class _IncompleteAgent(BaseAgent):
  """Agent that overrides neither `_run_async_impl` nor `_run_live_impl`.

  The incomplete-agent tests expect both run modes to raise
  NotImplementedError for it.
  """
class _TestingAgent(BaseAgent):
  """Concrete agent that yields exactly one fixed text event per run mode.

  Each event uses the agent name as author and copies the invocation id and
  branch from the incoming context.
  """

  @override
  async def _run_async_impl(
      self, ctx: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    greeting = types.Content(parts=[types.Part(text='Hello, world!')])
    yield Event(
        invocation_id=ctx.invocation_id,
        author=self.name,
        branch=ctx.branch,
        content=greeting,
    )

  @override
  async def _run_live_impl(
      self, ctx: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    greeting = types.Content(parts=[types.Part(text='Hello, live!')])
    yield Event(
        invocation_id=ctx.invocation_id,
        author=self.name,
        branch=ctx.branch,
        content=greeting,
    )
async def _create_parent_invocation_context(
    test_name: str, agent: BaseAgent, branch: Optional[str] = None
) -> InvocationContext:
  """Builds an InvocationContext backed by a fresh in-memory session.

  Args:
    test_name: Used as the prefix of the invocation id.
    agent: The agent stored on the context.
    branch: Optional branch stored on the context.

  Returns:
    A new InvocationContext bound to a newly created session.
  """
  service = InMemorySessionService()
  new_session = await service.create_session(
      app_name='test_app', user_id='test_user'
  )
  invocation_id = f'{test_name}_invocation_id'
  return InvocationContext(
      invocation_id=invocation_id,
      agent=agent,
      branch=branch,
      session=new_session,
      session_service=service,
  )
def test_invalid_agent_name():
  """An agent name that is not a valid identifier raises ValueError."""
  bad_name = 'not an identifier'
  with pytest.raises(ValueError):
    _TestingAgent(name=bad_name)
@pytest.mark.asyncio
async def test_run_async(request: pytest.FixtureRequest):
  """run_async yields the single event produced by `_run_async_impl`."""
  test_name = request.function.__name__
  agent = _TestingAgent(name=f'{test_name}_test_agent')
  parent_ctx = await _create_parent_invocation_context(test_name, agent)

  events = [event async for event in agent.run_async(parent_ctx)]

  assert len(events) == 1
  only_event = events[0]
  assert only_event.author == agent.name
  assert only_event.content.parts[0].text == 'Hello, world!'
@pytest.mark.asyncio
async def test_run_async_with_branch(request: pytest.FixtureRequest):
  """run_async propagates the parent context's branch onto emitted events."""
  test_name = request.function.__name__
  agent = _TestingAgent(name=f'{test_name}_test_agent')
  parent_ctx = await _create_parent_invocation_context(
      test_name, agent, branch='parent_branch'
  )

  events = [event async for event in agent.run_async(parent_ctx)]

  assert len(events) == 1
  only_event = events[0]
  assert only_event.author == agent.name
  assert only_event.content.parts[0].text == 'Hello, world!'
  assert only_event.branch == 'parent_branch'
@pytest.mark.asyncio
async def test_run_async_before_agent_callback_noop(
    request: pytest.FixtureRequest,
    mocker: pytest_mock.MockerFixture,
) -> None:
  """A sync before-agent callback that returns None lets the agent run.

  The callback must be invoked once with a `callback_context` keyword
  argument, `_run_async_impl` must still run, and the agent's own event must
  be the only one emitted.
  """
  # Arrange
  agent = _TestingAgent(
      name=f'{request.function.__name__}_test_agent',
      before_agent_callback=_before_agent_callback_noop,
  )
  parent_ctx = await _create_parent_invocation_context(
      request.function.__name__, agent
  )
  spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
  spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')

  # Act
  events = [e async for e in agent.run_async(parent_ctx)]

  # Assert
  spy_before_agent_callback.assert_called_once()
  _, kwargs = spy_before_agent_callback.call_args
  assert 'callback_context' in kwargs
  assert isinstance(kwargs['callback_context'], CallbackContext)
  spy_run_async_impl.assert_called_once()
  assert len(events) == 1
  assert events[0].content.parts[0].text == 'Hello, world!'
@pytest.mark.asyncio
async def test_run_async_with_async_before_agent_callback_noop(
    request: pytest.FixtureRequest,
    mocker: pytest_mock.MockerFixture,
) -> None:
  """An async before-agent callback that returns None lets the agent run.

  The callback must be invoked once with a `callback_context` keyword
  argument, `_run_async_impl` must still run, and the agent's own event must
  be the only one emitted.
  """
  # Arrange
  agent = _TestingAgent(
      name=f'{request.function.__name__}_test_agent',
      before_agent_callback=_async_before_agent_callback_noop,
  )
  parent_ctx = await _create_parent_invocation_context(
      request.function.__name__, agent
  )
  spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
  spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')

  # Act
  events = [e async for e in agent.run_async(parent_ctx)]

  # Assert
  spy_before_agent_callback.assert_called_once()
  _, kwargs = spy_before_agent_callback.call_args
  assert 'callback_context' in kwargs
  assert isinstance(kwargs['callback_context'], CallbackContext)
  spy_run_async_impl.assert_called_once()
  assert len(events) == 1
  assert events[0].content.parts[0].text == 'Hello, world!'
@pytest.mark.asyncio
async def test_run_async_before_agent_callback_bypass_agent(
    request: pytest.FixtureRequest,
    mocker: pytest_mock.MockerFixture,
):
  """Content from a sync before-agent callback replaces the agent's run."""
  # Arrange
  test_name = request.function.__name__
  agent = _TestingAgent(
      name=f'{test_name}_test_agent',
      before_agent_callback=_before_agent_callback_bypass_agent,
  )
  parent_ctx = await _create_parent_invocation_context(test_name, agent)
  run_impl_spy = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
  callback_spy = mocker.spy(agent, 'before_agent_callback')

  # Act
  events = [event async for event in agent.run_async(parent_ctx)]

  # Assert
  callback_spy.assert_called_once()
  run_impl_spy.assert_not_called()
  assert len(events) == 1
  assert events[0].content.parts[0].text == 'agent run is bypassed.'
@pytest.mark.asyncio
async def test_run_async_with_async_before_agent_callback_bypass_agent(
    request: pytest.FixtureRequest,
    mocker: pytest_mock.MockerFixture,
):
  """Content from an async before-agent callback replaces the agent's run."""
  # Arrange
  test_name = request.function.__name__
  agent = _TestingAgent(
      name=f'{test_name}_test_agent',
      before_agent_callback=_async_before_agent_callback_bypass_agent,
  )
  parent_ctx = await _create_parent_invocation_context(test_name, agent)
  run_impl_spy = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
  callback_spy = mocker.spy(agent, 'before_agent_callback')

  # Act
  events = [event async for event in agent.run_async(parent_ctx)]

  # Assert
  callback_spy.assert_called_once()
  run_impl_spy.assert_not_called()
  assert len(events) == 1
  assert events[0].content.parts[0].text == 'agent run is bypassed.'
class CallbackType(Enum):
  """Whether a mocked agent callback in the chain tests is sync or async."""

  SYNC = 1  # Wrapped in a plain mock.Mock.
  ASYNC = 2  # Wrapped in a mock.AsyncMock.
async def mock_async_agent_cb_side_effect(
    callback_context: CallbackContext,
    ret_value=None,
):
  """Async side effect for mocked agent callbacks.

  Returns content whose single part carries `ret_value` as text when
  `ret_value` is truthy; otherwise returns None.
  """
  return (
      types.Content(parts=[types.Part(text=ret_value)]) if ret_value else None
  )
def mock_sync_agent_cb_side_effect(
    callback_context: CallbackContext,
    ret_value=None,
):
  """Sync side effect for mocked agent callbacks.

  Returns content whose single part carries `ret_value` as text when
  `ret_value` is truthy; otherwise returns None.
  """
  return (
      types.Content(parts=[types.Part(text=ret_value)]) if ret_value else None
  )
# Cases for test_before_agent_callbacks_chain, each as
# (callbacks, expected_responses, expected_calls):
# - callbacks: (response, CallbackType) pairs in registration order; a None
#   response makes that mocked callback return None, otherwise it returns
#   content carrying that text.
# - expected_responses: texts of the events the agent run should yield.
# - expected_calls: expected call (sync) / await (async) count per callback.
# Per these expectations, the first before-callback that returns content is
# the only event yielded, and the callbacks after it are never invoked.
BEFORE_AGENT_CALLBACK_PARAMS = [
    pytest.param(
        [
            (None, CallbackType.SYNC),
            ('callback_2_response', CallbackType.ASYNC),
            ('callback_3_response', CallbackType.SYNC),
            (None, CallbackType.ASYNC),
        ],
        ['callback_2_response'],
        [1, 1, 0, 0],
        id='middle_async_callback_returns',
    ),
    pytest.param(
        [
            (None, CallbackType.SYNC),
            (None, CallbackType.ASYNC),
            (None, CallbackType.SYNC),
            (None, CallbackType.ASYNC),
        ],
        # No callback returns content, so the agent's own event is yielded.
        ['Hello, world!'],
        [1, 1, 1, 1],
        id='all_callbacks_return_none',
    ),
    pytest.param(
        [
            ('callback_1_response', CallbackType.SYNC),
            ('callback_2_response', CallbackType.ASYNC),
        ],
        ['callback_1_response'],
        [1, 0],
        id='first_sync_callback_returns',
    ),
]
# Cases for test_after_agent_callbacks_chain, with the same
# (callbacks, expected_responses, expected_calls) layout as the before-agent
# cases. Per these expectations, the agent's own 'Hello, world!' event is
# always yielded first. The first after-callback that returns content adds one
# more event, and the callbacks after it are never invoked.
AFTER_AGENT_CALLBACK_PARAMS = [
    pytest.param(
        [
            (None, CallbackType.SYNC),
            ('callback_2_response', CallbackType.ASYNC),
            ('callback_3_response', CallbackType.SYNC),
            (None, CallbackType.ASYNC),
        ],
        ['Hello, world!', 'callback_2_response'],
        [1, 1, 0, 0],
        id='middle_async_callback_returns',
    ),
    pytest.param(
        [
            (None, CallbackType.SYNC),
            (None, CallbackType.ASYNC),
            (None, CallbackType.SYNC),
            (None, CallbackType.ASYNC),
        ],
        ['Hello, world!'],
        [1, 1, 1, 1],
        id='all_callbacks_return_none',
    ),
    pytest.param(
        [
            ('callback_1_response', CallbackType.SYNC),
            ('callback_2_response', CallbackType.ASYNC),
        ],
        ['Hello, world!', 'callback_1_response'],
        [1, 0],
        id='first_sync_callback_returns',
    ),
]
@pytest.mark.parametrize(
    'callbacks, expected_responses, expected_calls',
    BEFORE_AGENT_CALLBACK_PARAMS,
)
@pytest.mark.asyncio
async def test_before_agent_callbacks_chain(
    callbacks: List[tuple[Optional[str], CallbackType]],
    expected_responses: List[str],
    expected_calls: List[int],
    request: pytest.FixtureRequest,
):
  """A list of before-agent callbacks is evaluated in order.

  Args:
    callbacks: (response, callback type) pairs. A None response makes that
      mocked callback return None; otherwise it returns content with that text.
    expected_responses: Texts of the events the agent run should yield.
    expected_calls: Expected call (sync) / await (async) count for each
      callback, in the same order as `callbacks`.
    request: pytest request, used to derive a unique agent name.
  """
  mock_cbs = []
  for response, callback_type in callbacks:
    if callback_type == CallbackType.ASYNC:
      mock_cb = mock.AsyncMock(
          side_effect=partial(
              mock_async_agent_cb_side_effect, ret_value=response
          )
      )
    else:
      mock_cb = mock.Mock(
          side_effect=partial(
              mock_sync_agent_cb_side_effect, ret_value=response
          )
      )
    mock_cbs.append(mock_cb)

  agent = _TestingAgent(
      name=f'{request.function.__name__}_test_agent',
      before_agent_callback=list(mock_cbs),
  )
  parent_ctx = await _create_parent_invocation_context(
      request.function.__name__, agent
  )
  result = [e async for e in agent.run_async(parent_ctx)]
  assert testing_utils.simplify_events(result) == [
      (f'{request.function.__name__}_test_agent', response)
      for response in expected_responses
  ]

  # Assert that the callbacks were called the expected number of times.
  # `assert_awaited()` / `assert_called()` take no count argument (passing one
  # raises TypeError), so compare the mocks' counters directly. AsyncMock is a
  # subclass of Mock, so it must be checked first.
  assert len(expected_calls) == len(mock_cbs)
  for index, (mock_cb, expected_count) in enumerate(
      zip(mock_cbs, expected_calls)
  ):
    if isinstance(mock_cb, mock.AsyncMock):
      actual_count = mock_cb.await_count
    else:
      actual_count = mock_cb.call_count
    assert actual_count == expected_count, (
        f'callback #{index}: expected {expected_count} call(s), got'
        f' {actual_count}'
    )
@pytest.mark.parametrize(
    'callbacks, expected_responses, expected_calls',
    AFTER_AGENT_CALLBACK_PARAMS,
)
@pytest.mark.asyncio
async def test_after_agent_callbacks_chain(
    callbacks: List[tuple[Optional[str], CallbackType]],
    expected_responses: List[str],
    expected_calls: List[int],
    request: pytest.FixtureRequest,
):
  """A list of after-agent callbacks is evaluated in order after the agent.

  Args:
    callbacks: (response, callback type) pairs. A None response makes that
      mocked callback return None; otherwise it returns content with that text.
    expected_responses: Texts of the events the agent run should yield.
    expected_calls: Expected call (sync) / await (async) count for each
      callback, in the same order as `callbacks`.
    request: pytest request, used to derive a unique agent name.
  """
  mock_cbs = []
  for response, callback_type in callbacks:
    if callback_type == CallbackType.ASYNC:
      mock_cb = mock.AsyncMock(
          side_effect=partial(
              mock_async_agent_cb_side_effect, ret_value=response
          )
      )
    else:
      mock_cb = mock.Mock(
          side_effect=partial(
              mock_sync_agent_cb_side_effect, ret_value=response
          )
      )
    mock_cbs.append(mock_cb)

  agent = _TestingAgent(
      name=f'{request.function.__name__}_test_agent',
      after_agent_callback=list(mock_cbs),
  )
  parent_ctx = await _create_parent_invocation_context(
      request.function.__name__, agent
  )
  result = [e async for e in agent.run_async(parent_ctx)]
  assert testing_utils.simplify_events(result) == [
      (f'{request.function.__name__}_test_agent', response)
      for response in expected_responses
  ]

  # Assert that the callbacks were called the expected number of times.
  # `assert_awaited()` / `assert_called()` take no count argument (passing one
  # raises TypeError), so compare the mocks' counters directly. AsyncMock is a
  # subclass of Mock, so it must be checked first.
  assert len(expected_calls) == len(mock_cbs)
  for index, (mock_cb, expected_count) in enumerate(
      zip(mock_cbs, expected_calls)
  ):
    if isinstance(mock_cb, mock.AsyncMock):
      actual_count = mock_cb.await_count
    else:
      actual_count = mock_cb.call_count
    assert actual_count == expected_count, (
        f'callback #{index}: expected {expected_count} call(s), got'
        f' {actual_count}'
    )
@pytest.mark.asyncio
async def test_run_async_after_agent_callback_noop(
    request: pytest.FixtureRequest,
    mocker: pytest_mock.MockerFixture,
):
  """A sync after-agent callback returning None adds no extra event."""
  # Arrange
  test_name = request.function.__name__
  agent = _TestingAgent(
      name=f'{test_name}_test_agent',
      after_agent_callback=_after_agent_callback_noop,
  )
  parent_ctx = await _create_parent_invocation_context(test_name, agent)
  callback_spy = mocker.spy(agent, 'after_agent_callback')

  # Act
  events = [event async for event in agent.run_async(parent_ctx)]

  # Assert
  callback_spy.assert_called_once()
  _, call_kwargs = callback_spy.call_args
  assert 'callback_context' in call_kwargs
  assert isinstance(call_kwargs['callback_context'], CallbackContext)
  assert len(events) == 1
@pytest.mark.asyncio
async def test_run_async_with_async_after_agent_callback_noop(
    request: pytest.FixtureRequest,
    mocker: pytest_mock.MockerFixture,
):
  """An async after-agent callback returning None adds no extra event."""
  # Arrange
  test_name = request.function.__name__
  agent = _TestingAgent(
      name=f'{test_name}_test_agent',
      after_agent_callback=_async_after_agent_callback_noop,
  )
  parent_ctx = await _create_parent_invocation_context(test_name, agent)
  callback_spy = mocker.spy(agent, 'after_agent_callback')

  # Act
  events = [event async for event in agent.run_async(parent_ctx)]

  # Assert
  callback_spy.assert_called_once()
  _, call_kwargs = callback_spy.call_args
  assert 'callback_context' in call_kwargs
  assert isinstance(call_kwargs['callback_context'], CallbackContext)
  assert len(events) == 1
@pytest.mark.asyncio
async def test_run_async_after_agent_callback_append_reply(
    request: pytest.FixtureRequest,
):
  """Content from a sync after-agent callback is yielded as a second event."""
  # Arrange
  test_name = request.function.__name__
  agent = _TestingAgent(
      name=f'{test_name}_test_agent',
      after_agent_callback=_after_agent_callback_append_agent_reply,
  )
  parent_ctx = await _create_parent_invocation_context(test_name, agent)
  expected_text = 'Agent reply from after agent callback.'

  # Act
  events = [event async for event in agent.run_async(parent_ctx)]

  # Assert
  assert len(events) == 2
  appended_event = events[1]
  assert appended_event.author == agent.name
  assert appended_event.content.parts[0].text == expected_text
@pytest.mark.asyncio
async def test_run_async_with_async_after_agent_callback_append_reply(
    request: pytest.FixtureRequest,
):
  """Content from an async after-agent callback is yielded as a second event."""
  # Arrange
  test_name = request.function.__name__
  agent = _TestingAgent(
      name=f'{test_name}_test_agent',
      after_agent_callback=_async_after_agent_callback_append_agent_reply,
  )
  parent_ctx = await _create_parent_invocation_context(test_name, agent)
  expected_text = 'Agent reply from after agent callback.'

  # Act
  events = [event async for event in agent.run_async(parent_ctx)]

  # Assert
  assert len(events) == 2
  appended_event = events[1]
  assert appended_event.author == agent.name
  assert appended_event.content.parts[0].text == expected_text
@pytest.mark.asyncio
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
  """run_async raises NotImplementedError if `_run_async_impl` is missing."""
  test_name = request.function.__name__
  agent = _IncompleteAgent(name=f'{test_name}_test_agent')
  parent_ctx = await _create_parent_invocation_context(test_name, agent)
  with pytest.raises(NotImplementedError):
    _ = [event async for event in agent.run_async(parent_ctx)]
@pytest.mark.asyncio
async def test_run_live(request: pytest.FixtureRequest):
  """run_live yields the single event produced by `_run_live_impl`."""
  test_name = request.function.__name__
  agent = _TestingAgent(name=f'{test_name}_test_agent')
  parent_ctx = await _create_parent_invocation_context(test_name, agent)

  events = [event async for event in agent.run_live(parent_ctx)]

  assert len(events) == 1
  only_event = events[0]
  assert only_event.author == agent.name
  assert only_event.content.parts[0].text == 'Hello, live!'
@pytest.mark.asyncio
async def test_run_live_with_branch(request: pytest.FixtureRequest):
  """run_live propagates the parent context's branch onto emitted events."""
  test_name = request.function.__name__
  agent = _TestingAgent(name=f'{test_name}_test_agent')
  parent_ctx = await _create_parent_invocation_context(
      test_name, agent, branch='parent_branch'
  )

  events = [event async for event in agent.run_live(parent_ctx)]

  assert len(events) == 1
  only_event = events[0]
  assert only_event.author == agent.name
  assert only_event.content.parts[0].text == 'Hello, live!'
  assert only_event.branch == 'parent_branch'
@pytest.mark.asyncio
async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
  """run_live raises NotImplementedError if `_run_live_impl` is missing."""
  test_name = request.function.__name__
  agent = _IncompleteAgent(name=f'{test_name}_test_agent')
  parent_ctx = await _create_parent_invocation_context(test_name, agent)
  with pytest.raises(NotImplementedError):
    _ = [event async for event in agent.run_live(parent_ctx)]
def test_set_parent_agent_for_sub_agents(request: pytest.FixtureRequest):
  """Constructing a parent sets `parent_agent` on each of its sub-agents."""
  prefix = request.function.__name__
  children: list[BaseAgent] = [
      _TestingAgent(name=f'{prefix}_sub_agent_{index}') for index in (1, 2)
  ]
  parent = _TestingAgent(name=f'{prefix}_parent', sub_agents=children)
  for child in children:
    assert child.parent_agent == parent
def test_find_agent(request: pytest.FixtureRequest):
  """find_agent matches the agent itself or any agent in its subtree.

  Lookups do not reach a sibling's subtree, and unknown names yield None.
  Each sub-agent is checked symmetrically (self, own child, sibling's child).
  """
  grand_sub_agent_1 = _TestingAgent(
      name=f'{request.function.__name__}__grand_sub_agent_1'
  )
  grand_sub_agent_2 = _TestingAgent(
      name=f'{request.function.__name__}__grand_sub_agent_2'
  )
  sub_agent_1 = _TestingAgent(
      name=f'{request.function.__name__}_sub_agent_1',
      sub_agents=[grand_sub_agent_1],
  )
  sub_agent_2 = _TestingAgent(
      name=f'{request.function.__name__}_sub_agent_2',
      sub_agents=[grand_sub_agent_2],
  )
  parent = _TestingAgent(
      name=f'{request.function.__name__}_parent',
      sub_agents=[sub_agent_1, sub_agent_2],
  )

  # Lookups from the root reach every agent in the tree.
  assert parent.find_agent(parent.name) == parent
  assert parent.find_agent(sub_agent_1.name) == sub_agent_1
  assert parent.find_agent(sub_agent_2.name) == sub_agent_2
  assert parent.find_agent(grand_sub_agent_1.name) == grand_sub_agent_1
  assert parent.find_agent(grand_sub_agent_2.name) == grand_sub_agent_2

  # Lookups from a sub-agent cover itself and its own subtree only.
  assert sub_agent_1.find_agent(sub_agent_1.name) == sub_agent_1
  assert sub_agent_1.find_agent(grand_sub_agent_1.name) == grand_sub_agent_1
  assert sub_agent_1.find_agent(grand_sub_agent_2.name) is None
  assert sub_agent_2.find_agent(sub_agent_2.name) == sub_agent_2
  assert sub_agent_2.find_agent(grand_sub_agent_2.name) == grand_sub_agent_2
  assert sub_agent_2.find_agent(grand_sub_agent_1.name) is None

  assert parent.find_agent('not_exist') is None
def test_find_sub_agent(request: pytest.FixtureRequest):
  """find_sub_agent searches an agent's descendants but never the agent itself."""
  prefix = request.function.__name__
  grand_sub_agent_1 = _TestingAgent(name=f'{prefix}__grand_sub_agent_1')
  grand_sub_agent_2 = _TestingAgent(name=f'{prefix}__grand_sub_agent_2')
  sub_agent_1 = _TestingAgent(
      name=f'{prefix}_sub_agent_1', sub_agents=[grand_sub_agent_1]
  )
  sub_agent_2 = _TestingAgent(
      name=f'{prefix}_sub_agent_2', sub_agents=[grand_sub_agent_2]
  )
  parent = _TestingAgent(
      name=f'{prefix}_parent', sub_agents=[sub_agent_1, sub_agent_2]
  )

  # (agent searched from, name looked up, expected match or None)
  cases = [
      (parent, sub_agent_1.name, sub_agent_1),
      (parent, sub_agent_2.name, sub_agent_2),
      (parent, grand_sub_agent_1.name, grand_sub_agent_1),
      (parent, grand_sub_agent_2.name, grand_sub_agent_2),
      (sub_agent_1, grand_sub_agent_1.name, grand_sub_agent_1),
      (sub_agent_1, grand_sub_agent_2.name, None),
      (sub_agent_2, grand_sub_agent_1.name, None),
      (sub_agent_2, grand_sub_agent_2.name, grand_sub_agent_2),
      (parent, parent.name, None),
      (parent, 'not_exist', None),
  ]
  for searcher, target_name, expected in cases:
    found = searcher.find_sub_agent(target_name)
    if expected is None:
      assert found is None
    else:
      assert found == expected
def test_root_agent(request: pytest.FixtureRequest):
  """Every agent in a tree, including the root, reports the root agent."""
  prefix = request.function.__name__
  grand_sub_agent_1 = _TestingAgent(name=f'{prefix}__grand_sub_agent_1')
  grand_sub_agent_2 = _TestingAgent(name=f'{prefix}__grand_sub_agent_2')
  sub_agent_1 = _TestingAgent(
      name=f'{prefix}_sub_agent_1', sub_agents=[grand_sub_agent_1]
  )
  sub_agent_2 = _TestingAgent(
      name=f'{prefix}_sub_agent_2', sub_agents=[grand_sub_agent_2]
  )
  parent = _TestingAgent(
      name=f'{prefix}_parent', sub_agents=[sub_agent_1, sub_agent_2]
  )

  tree_members = (
      parent,
      sub_agent_1,
      sub_agent_2,
      grand_sub_agent_1,
      grand_sub_agent_2,
  )
  for member in tree_members:
    assert member.root_agent == parent
def test_set_parent_agent_for_sub_agent_twice(
    request: pytest.FixtureRequest,
):
  """Attaching an agent that already has a parent to a second parent fails."""
  prefix = request.function.__name__
  shared_child = _TestingAgent(name=f'{prefix}_sub_agent')
  first_parent = _TestingAgent(
      name=f'{prefix}_parent_1', sub_agents=[shared_child]
  )
  with pytest.raises(ValueError):
    second_parent = _TestingAgent(
        name=f'{prefix}_parent_2', sub_agents=[shared_child]
    )