Support chaining for model callbacks

(before/after) model callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync.

PiperOrigin-RevId: 755565583
This commit is contained in:
Selcuk Gun
2025-05-06 16:15:33 -07:00
committed by Copybara-Service
parent 794a70edcd
commit e4317c9eb7
5 changed files with 731 additions and 23 deletions
@@ -0,0 +1,209 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Optional
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.genai import types
from google.genai import types
from pydantic import BaseModel
import pytest
from .. import utils
class MockAgentCallback(BaseModel):
mock_response: str
def __call__(
self,
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text=self.mock_response)])
class MockAsyncAgentCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text=self.mock_response)])
def noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
@pytest.mark.asyncio
async def test_before_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=MockAgentCallback(
mock_response='before_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'before_agent_callback'),
]
@pytest.mark.asyncio
async def test_after_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_agent_callback=MockAgentCallback(
mock_response='after_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
('root_agent', 'after_agent_callback'),
]
@pytest.mark.asyncio
async def test_before_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]
@pytest.mark.asyncio
async def test_after_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]
@pytest.mark.asyncio
async def test_async_before_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=MockAsyncAgentCallback(
mock_response='async_before_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_before_agent_callback'),
]
@pytest.mark.asyncio
async def test_async_after_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_agent_callback=MockAsyncAgentCallback(
mock_response='async_after_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
('root_agent', 'async_after_agent_callback'),
]
@pytest.mark.asyncio
async def test_async_before_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]
@pytest.mark.asyncio
async def test_async_after_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]