Support chaining for model callbacks

(before/after) model callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync.

PiperOrigin-RevId: 755565583
This commit is contained in:
Selcuk Gun 2025-05-06 16:15:33 -07:00 committed by Copybara-Service
parent 794a70edcd
commit e4317c9eb7
5 changed files with 731 additions and 23 deletions

View File

@ -47,15 +47,26 @@ from .readonly_context import ReadonlyContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_SingleBeforeModelCallback: TypeAlias = Callable[
BeforeModelCallback: TypeAlias = Callable[
[CallbackContext, LlmRequest], [CallbackContext, LlmRequest],
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]], Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
] ]
AfterModelCallback: TypeAlias = Callable[
BeforeModelCallback: TypeAlias = Union[
_SingleBeforeModelCallback,
list[_SingleBeforeModelCallback],
]
_SingleAfterModelCallback: TypeAlias = Callable[
[CallbackContext, LlmResponse], [CallbackContext, LlmResponse],
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]], Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
] ]
AfterModelCallback: TypeAlias = Union[
_SingleAfterModelCallback,
list[_SingleAfterModelCallback],
]
BeforeToolCallback: TypeAlias = Callable[ BeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext], [BaseTool, dict[str, Any], ToolContext],
Union[Awaitable[Optional[dict]], Optional[dict]], Union[Awaitable[Optional[dict]], Optional[dict]],
@ -174,7 +185,11 @@ class LlmAgent(BaseAgent):
# Callbacks - Start # Callbacks - Start
before_model_callback: Optional[BeforeModelCallback] = None before_model_callback: Optional[BeforeModelCallback] = None
"""Called before calling the LLM. """Callback or list of callbacks to be called before calling the LLM.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args: Args:
callback_context: CallbackContext, callback_context: CallbackContext,
llm_request: LlmRequest, The raw model request. Callback can mutate the llm_request: LlmRequest, The raw model request. Callback can mutate the
@ -185,7 +200,10 @@ class LlmAgent(BaseAgent):
skipped and the provided content will be returned to user. skipped and the provided content will be returned to user.
""" """
after_model_callback: Optional[AfterModelCallback] = None after_model_callback: Optional[AfterModelCallback] = None
"""Called after calling LLM. """Callback or list of callbacks to be called after calling the LLM.
When a list of callbacks is provided, the callbacks will be called in the
order they are listed until a callback does not return None.
Args: Args:
callback_context: CallbackContext, callback_context: CallbackContext,
@ -285,6 +303,32 @@ class LlmAgent(BaseAgent):
""" """
return [_convert_tool_union_to_tool(tool) for tool in self.tools] return [_convert_tool_union_to_tool(tool) for tool in self.tools]
@property
def canonical_before_model_callbacks(
self,
) -> list[_SingleBeforeModelCallback]:
"""The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback.
This method is only for use by Agent Development Kit.
"""
if not self.before_model_callback:
return []
if isinstance(self.before_model_callback, list):
return self.before_model_callback
return [self.before_model_callback]
@property
def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]:
"""The resolved self.after_model_callback field as a list of _SingleAfterModelCallback.
This method is only for use by Agent Development Kit.
"""
if not self.after_model_callback:
return []
if isinstance(self.after_model_callback, list):
return self.after_model_callback
return [self.after_model_callback]
@property @property
def _llm_flow(self) -> BaseLlmFlow: def _llm_flow(self) -> BaseLlmFlow:
if ( if (

View File

@ -193,8 +193,9 @@ class BaseLlmFlow(ABC):
"""Receive data from model and process events using BaseLlmConnection.""" """Receive data from model and process events using BaseLlmConnection."""
def get_author(llm_response): def get_author(llm_response):
"""Get the author of the event. """Get the author of the event.
When the model returns transcription, the author is "user". Otherwise, the author is the agent. When the model returns transcription, the author is "user". Otherwise, the
author is the agent.
""" """
if llm_response and llm_response.content and llm_response.content.role == "user": if llm_response and llm_response.content and llm_response.content.role == "user":
return "user" return "user"
@ -509,20 +510,21 @@ class BaseLlmFlow(ABC):
if not isinstance(agent, LlmAgent): if not isinstance(agent, LlmAgent):
return return
if not agent.before_model_callback: if not agent.canonical_before_model_callbacks:
return return
callback_context = CallbackContext( callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions invocation_context, event_actions=model_response_event.actions
) )
before_model_callback_content = agent.before_model_callback(
callback_context=callback_context, llm_request=llm_request
)
if inspect.isawaitable(before_model_callback_content): for callback in agent.canonical_before_model_callbacks:
before_model_callback_content = await before_model_callback_content before_model_callback_content = callback(
callback_context=callback_context, llm_request=llm_request
return before_model_callback_content )
if inspect.isawaitable(before_model_callback_content):
before_model_callback_content = await before_model_callback_content
if before_model_callback_content:
return before_model_callback_content
async def _handle_after_model_callback( async def _handle_after_model_callback(
self, self,
@ -536,20 +538,21 @@ class BaseLlmFlow(ABC):
if not isinstance(agent, LlmAgent): if not isinstance(agent, LlmAgent):
return return
if not agent.after_model_callback: if not agent.canonical_after_model_callbacks:
return return
callback_context = CallbackContext( callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions invocation_context, event_actions=model_response_event.actions
) )
after_model_callback_content = agent.after_model_callback(
callback_context=callback_context, llm_response=llm_response
)
if inspect.isawaitable(after_model_callback_content): for callback in agent.canonical_after_model_callbacks:
after_model_callback_content = await after_model_callback_content after_model_callback_content = callback(
callback_context=callback_context, llm_response=llm_response
return after_model_callback_content )
if inspect.isawaitable(after_model_callback_content):
after_model_callback_content = await after_model_callback_content
if after_model_callback_content:
return after_model_callback_content
def _finalize_model_response_event( def _finalize_model_response_event(
self, self,

View File

@ -0,0 +1,209 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Optional
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.genai import types
from google.genai import types
from pydantic import BaseModel
import pytest
from .. import utils
class MockAgentCallback(BaseModel):
mock_response: str
def __call__(
self,
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text=self.mock_response)])
class MockAsyncAgentCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text=self.mock_response)])
def noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
@pytest.mark.asyncio
async def test_before_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=MockAgentCallback(
mock_response='before_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'before_agent_callback'),
]
@pytest.mark.asyncio
async def test_after_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_agent_callback=MockAgentCallback(
mock_response='after_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
('root_agent', 'after_agent_callback'),
]
@pytest.mark.asyncio
async def test_before_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]
@pytest.mark.asyncio
async def test_after_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]
@pytest.mark.asyncio
async def test_async_before_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=MockAsyncAgentCallback(
mock_response='async_before_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_before_agent_callback'),
]
@pytest.mark.asyncio
async def test_async_after_agent_callback():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_agent_callback=MockAsyncAgentCallback(
mock_response='async_after_agent_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
('root_agent', 'async_after_agent_callback'),
]
@pytest.mark.asyncio
async def test_async_before_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]
@pytest.mark.asyncio
async def test_async_after_agent_callback_noop():
responses = ['agent_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_agent_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'agent_response'),
]

View File

@ -0,0 +1,242 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from functools import partial
from typing import Any
from typing import List
from typing import Optional
from unittest import mock
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.genai import types
from pydantic import BaseModel
import pytest
from .. import utils
class CallbackType(Enum):
SYNC = 1
ASYNC = 2
async def mock_async_before_cb_side_effect(
callback_context: CallbackContext,
llm_request: LlmRequest,
ret_value=None,
):
if ret_value:
return LlmResponse(
content=utils.ModelContent([types.Part.from_text(text=ret_value)])
)
return None
def mock_sync_before_cb_side_effect(
callback_context: CallbackContext,
llm_request: LlmRequest,
ret_value=None,
):
if ret_value:
return LlmResponse(
content=utils.ModelContent([types.Part.from_text(text=ret_value)])
)
return None
async def mock_async_after_cb_side_effect(
callback_context: CallbackContext,
llm_response: LlmResponse,
ret_value=None,
):
if ret_value:
return LlmResponse(
content=utils.ModelContent([types.Part.from_text(text=ret_value)])
)
return None
def mock_sync_after_cb_side_effect(
callback_context: CallbackContext,
llm_response: LlmResponse,
ret_value=None,
):
if ret_value:
return LlmResponse(
content=utils.ModelContent([types.Part.from_text(text=ret_value)])
)
return None
CALLBACK_PARAMS = [
pytest.param(
[
(None, CallbackType.SYNC),
("callback_2_response", CallbackType.ASYNC),
("callback_3_response", CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
"callback_2_response",
[1, 1, 0, 0],
id="middle_async_callback_returns",
),
pytest.param(
[
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
(None, CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
"model_response",
[1, 1, 1, 1],
id="all_callbacks_return_none",
),
pytest.param(
[
("callback_1_response", CallbackType.SYNC),
("callback_2_response", CallbackType.ASYNC),
],
"callback_1_response",
[1, 0],
id="first_sync_callback_returns",
),
]
@pytest.mark.parametrize(
"callbacks, expected_response, expected_calls",
CALLBACK_PARAMS,
)
@pytest.mark.asyncio
async def test_before_model_callbacks_chain(
callbacks: List[tuple[str, int]],
expected_response: str,
expected_calls: List[int],
):
responses = ["model_response"]
mock_model = utils.MockModel.create(responses=responses)
mock_cbs = []
for response, callback_type in callbacks:
if callback_type == CallbackType.ASYNC:
mock_cb = mock.AsyncMock(
side_effect=partial(
mock_async_before_cb_side_effect, ret_value=response
)
)
else:
mock_cb = mock.Mock(
side_effect=partial(
mock_sync_before_cb_side_effect, ret_value=response
)
)
mock_cbs.append(mock_cb)
# Create agent with multiple callbacks
agent = Agent(
name="root_agent",
model=mock_model,
before_model_callback=[mock_cb for mock_cb in mock_cbs],
)
runner = utils.TestInMemoryRunner(agent)
result = await runner.run_async_with_new_session("test")
assert utils.simplify_events(result) == [
("root_agent", expected_response),
]
# Assert that the callbacks were called the expected number of times
for i, mock_cb in enumerate(mock_cbs):
expected_calls_count = expected_calls[i]
if expected_calls_count == 1:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited_once()
else:
mock_cb.assert_called_once()
elif expected_calls_count == 0:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_not_awaited()
else:
mock_cb.assert_not_called()
else:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited(expected_calls_count)
else:
mock_cb.assert_called(expected_calls_count)
@pytest.mark.parametrize(
"callbacks, expected_response, expected_calls",
CALLBACK_PARAMS,
)
@pytest.mark.asyncio
async def test_after_model_callbacks_chain(
callbacks: List[tuple[str, int]],
expected_response: str,
expected_calls: List[int],
):
responses = ["model_response"]
mock_model = utils.MockModel.create(responses=responses)
mock_cbs = []
for response, callback_type in callbacks:
if callback_type == CallbackType.ASYNC:
mock_cb = mock.AsyncMock(
side_effect=partial(
mock_async_after_cb_side_effect, ret_value=response
)
)
else:
mock_cb = mock.Mock(
side_effect=partial(
mock_sync_after_cb_side_effect, ret_value=response
)
)
mock_cbs.append(mock_cb)
# Create agent with multiple callbacks
agent = Agent(
name="root_agent",
model=mock_model,
after_model_callback=[mock_cb for mock_cb in mock_cbs],
)
runner = utils.TestInMemoryRunner(agent)
result = await runner.run_async_with_new_session("test")
assert utils.simplify_events(result) == [
("root_agent", expected_response),
]
# Assert that the callbacks were called the expected number of times
for i, mock_cb in enumerate(mock_cbs):
expected_calls_count = expected_calls[i]
if expected_calls_count == 1:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited_once()
else:
mock_cb.assert_called_once()
elif expected_calls_count == 0:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_not_awaited()
else:
mock_cb.assert_not_called()
else:
if isinstance(mock_cb, mock.AsyncMock):
mock_cb.assert_awaited(expected_calls_count)
else:
mock_cb.assert_called(expected_calls_count)

View File

@ -0,0 +1,210 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Optional
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.genai import types
from pydantic import BaseModel
import pytest
from .. import utils
class MockBeforeModelCallback(BaseModel):
mock_response: str
def __call__(
self,
callback_context: CallbackContext,
llm_request: LlmRequest,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
class MockAfterModelCallback(BaseModel):
mock_response: str
def __call__(
self,
callback_context: CallbackContext,
llm_response: LlmResponse,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
class MockAsyncBeforeModelCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
llm_request: LlmRequest,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
class MockAsyncAfterModelCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
llm_response: LlmResponse,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
def noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
@pytest.mark.asyncio
async def test_before_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=MockBeforeModelCallback(
mock_response='before_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'before_model_callback'),
]
@pytest.mark.asyncio
async def test_before_model_callback_noop():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'model_response'),
]
@pytest.mark.asyncio
async def test_after_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_model_callback=MockAfterModelCallback(
mock_response='after_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'after_model_callback'),
]
@pytest.mark.asyncio
async def test_async_before_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=MockAsyncBeforeModelCallback(
mock_response='async_before_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_before_model_callback'),
]
@pytest.mark.asyncio
async def test_async_before_model_callback_noop():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'model_response'),
]
@pytest.mark.asyncio
async def test_async_after_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_model_callback=MockAsyncAfterModelCallback(
mock_response='async_after_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_after_model_callback'),
]