adk-python/tests/unittests/agents/test_agent_callbacks.py
Selcuk Gun e4317c9eb7 Support chaining for model callbacks
Model callbacks (before/after) are invoked in order along the provided chain until one callback returns a non-None value. Callbacks can be either sync or async.

PiperOrigin-RevId: 755565583
2025-05-06 16:16:02 -07:00

210 lines
5.4 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Optional
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.genai import types
from google.genai import types
from pydantic import BaseModel
import pytest
from .. import utils
class MockAgentCallback(BaseModel):
  """Synchronous agent-callback stub that answers with fixed text.

  Calling an instance returns a `types.Content` holding a single text part
  whose text is `mock_response`.
  """

  # Text placed in the single part of the returned content.
  mock_response: str

  def __call__(
      self,
      callback_context: CallbackContext,
  ) -> types.Content:
    """Returns the canned content; `callback_context` is not inspected."""
    reply_part = types.Part(text=self.mock_response)
    return types.Content(parts=[reply_part])
class MockAsyncAgentCallback(BaseModel):
  """Asynchronous agent-callback stub that answers with fixed text.

  Awaiting a call on an instance yields a `types.Content` holding a single
  text part whose text is `mock_response`.
  """

  # Text placed in the single part of the returned content.
  mock_response: str

  async def __call__(
      self,
      callback_context: CallbackContext,
  ) -> types.Content:
    """Returns the canned content; `callback_context` is not inspected."""
    reply_part = types.Part(text=self.mock_response)
    return types.Content(parts=[reply_part])
def noop_callback(**kwargs) -> Optional[LlmResponse]:
  """Synchronous callback stub that ignores its arguments and returns None.

  Accepts any keyword arguments so it can be registered as a callback
  regardless of which keyword names the caller passes.
  """
  del kwargs  # Intentionally unused.
  return None
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
  """Asynchronous callback stub that ignores its arguments and returns None.

  Accepts any keyword arguments so it can be registered as a callback
  regardless of which keyword names the caller passes.
  """
  del kwargs  # Intentionally unused.
  return None
@pytest.mark.asyncio
async def test_before_agent_callback():
  """A sync before_agent_callback's content is the only event produced."""
  expected = [
      ('root_agent', 'before_agent_callback'),
  ]
  agent = Agent(
      name='root_agent',
      model=utils.MockModel.create(responses=['agent_response']),
      before_agent_callback=MockAgentCallback(
          mock_response='before_agent_callback'
      ),
  )
  runner = utils.TestInMemoryRunner(agent)

  events = await runner.run_async_with_new_session('test')
  assert utils.simplify_events(events) == expected
@pytest.mark.asyncio
async def test_after_agent_callback():
  """A sync after_agent_callback's content follows the model's reply."""
  expected = [
      ('root_agent', 'agent_response'),
      ('root_agent', 'after_agent_callback'),
  ]
  agent = Agent(
      name='root_agent',
      model=utils.MockModel.create(responses=['agent_response']),
      after_agent_callback=MockAgentCallback(
          mock_response='after_agent_callback'
      ),
  )
  runner = utils.TestInMemoryRunner(agent)

  events = await runner.run_async_with_new_session('test')
  assert utils.simplify_events(events) == expected
@pytest.mark.asyncio
async def test_before_agent_callback_noop():
  """A sync before_agent_callback returning None leaves the reply intact."""
  expected = [
      ('root_agent', 'agent_response'),
  ]
  agent = Agent(
      name='root_agent',
      model=utils.MockModel.create(responses=['agent_response']),
      before_agent_callback=noop_callback,
  )
  runner = utils.TestInMemoryRunner(agent)

  events = await runner.run_async_with_new_session('test')
  assert utils.simplify_events(events) == expected
@pytest.mark.asyncio
async def test_after_agent_callback_noop():
  """A sync after_agent_callback returning None adds no extra event.

  The callback must be registered as `after_agent_callback`. Registering it as
  `before_agent_callback` would duplicate test_before_agent_callback_noop and
  leave the after-callback path untested.
  """
  responses = ['agent_response']
  mock_model = utils.MockModel.create(responses=responses)
  agent = Agent(
      name='root_agent',
      model=mock_model,
      after_agent_callback=noop_callback,
  )
  runner = utils.TestInMemoryRunner(agent)
  assert utils.simplify_events(
      await runner.run_async_with_new_session('test')
  ) == [
      ('root_agent', 'agent_response'),
  ]
@pytest.mark.asyncio
async def test_async_before_agent_callback():
  """An async before_agent_callback's content is the only event produced."""
  expected = [
      ('root_agent', 'async_before_agent_callback'),
  ]
  agent = Agent(
      name='root_agent',
      model=utils.MockModel.create(responses=['agent_response']),
      before_agent_callback=MockAsyncAgentCallback(
          mock_response='async_before_agent_callback'
      ),
  )
  runner = utils.TestInMemoryRunner(agent)

  events = await runner.run_async_with_new_session('test')
  assert utils.simplify_events(events) == expected
@pytest.mark.asyncio
async def test_async_after_agent_callback():
  """An async after_agent_callback's content follows the model's reply."""
  expected = [
      ('root_agent', 'agent_response'),
      ('root_agent', 'async_after_agent_callback'),
  ]
  agent = Agent(
      name='root_agent',
      model=utils.MockModel.create(responses=['agent_response']),
      after_agent_callback=MockAsyncAgentCallback(
          mock_response='async_after_agent_callback'
      ),
  )
  runner = utils.TestInMemoryRunner(agent)

  events = await runner.run_async_with_new_session('test')
  assert utils.simplify_events(events) == expected
@pytest.mark.asyncio
async def test_async_before_agent_callback_noop():
  """An async before_agent_callback returning None leaves the reply intact."""
  expected = [
      ('root_agent', 'agent_response'),
  ]
  agent = Agent(
      name='root_agent',
      model=utils.MockModel.create(responses=['agent_response']),
      before_agent_callback=async_noop_callback,
  )
  runner = utils.TestInMemoryRunner(agent)

  events = await runner.run_async_with_new_session('test')
  assert utils.simplify_events(events) == expected
@pytest.mark.asyncio
async def test_async_after_agent_callback_noop():
  """An async after_agent_callback returning None adds no extra event.

  The callback must be registered as `after_agent_callback`. Registering it as
  `before_agent_callback` would duplicate
  test_async_before_agent_callback_noop and leave the async after-callback
  path untested.
  """
  responses = ['agent_response']
  mock_model = utils.MockModel.create(responses=responses)
  agent = Agent(
      name='root_agent',
      model=mock_model,
      after_agent_callback=async_noop_callback,
  )
  runner = utils.TestInMemoryRunner(agent)
  assert utils.simplify_events(
      await runner.run_async_with_new_session('test')
  ) == [
      ('root_agent', 'agent_response'),
  ]