mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 09:51:25 -06:00
Support chaining for model callbacks
(before/after) model callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 755565583
This commit is contained in:
parent
794a70edcd
commit
e4317c9eb7
@ -47,15 +47,26 @@ from .readonly_context import ReadonlyContext
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SingleBeforeModelCallback: TypeAlias = Callable[
|
||||||
BeforeModelCallback: TypeAlias = Callable[
|
|
||||||
[CallbackContext, LlmRequest],
|
[CallbackContext, LlmRequest],
|
||||||
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
|
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
|
||||||
]
|
]
|
||||||
AfterModelCallback: TypeAlias = Callable[
|
|
||||||
|
BeforeModelCallback: TypeAlias = Union[
|
||||||
|
_SingleBeforeModelCallback,
|
||||||
|
list[_SingleBeforeModelCallback],
|
||||||
|
]
|
||||||
|
|
||||||
|
_SingleAfterModelCallback: TypeAlias = Callable[
|
||||||
[CallbackContext, LlmResponse],
|
[CallbackContext, LlmResponse],
|
||||||
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
|
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
AfterModelCallback: TypeAlias = Union[
|
||||||
|
_SingleAfterModelCallback,
|
||||||
|
list[_SingleAfterModelCallback],
|
||||||
|
]
|
||||||
|
|
||||||
BeforeToolCallback: TypeAlias = Callable[
|
BeforeToolCallback: TypeAlias = Callable[
|
||||||
[BaseTool, dict[str, Any], ToolContext],
|
[BaseTool, dict[str, Any], ToolContext],
|
||||||
Union[Awaitable[Optional[dict]], Optional[dict]],
|
Union[Awaitable[Optional[dict]], Optional[dict]],
|
||||||
@ -174,7 +185,11 @@ class LlmAgent(BaseAgent):
|
|||||||
|
|
||||||
# Callbacks - Start
|
# Callbacks - Start
|
||||||
before_model_callback: Optional[BeforeModelCallback] = None
|
before_model_callback: Optional[BeforeModelCallback] = None
|
||||||
"""Called before calling the LLM.
|
"""Callback or list of callbacks to be called before calling the LLM.
|
||||||
|
|
||||||
|
When a list of callbacks is provided, the callbacks will be called in the
|
||||||
|
order they are listed until a callback does not return None.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callback_context: CallbackContext,
|
callback_context: CallbackContext,
|
||||||
llm_request: LlmRequest, The raw model request. Callback can mutate the
|
llm_request: LlmRequest, The raw model request. Callback can mutate the
|
||||||
@ -185,7 +200,10 @@ class LlmAgent(BaseAgent):
|
|||||||
skipped and the provided content will be returned to user.
|
skipped and the provided content will be returned to user.
|
||||||
"""
|
"""
|
||||||
after_model_callback: Optional[AfterModelCallback] = None
|
after_model_callback: Optional[AfterModelCallback] = None
|
||||||
"""Called after calling LLM.
|
"""Callback or list of callbacks to be called after calling the LLM.
|
||||||
|
|
||||||
|
When a list of callbacks is provided, the callbacks will be called in the
|
||||||
|
order they are listed until a callback does not return None.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callback_context: CallbackContext,
|
callback_context: CallbackContext,
|
||||||
@ -285,6 +303,32 @@ class LlmAgent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
return [_convert_tool_union_to_tool(tool) for tool in self.tools]
|
return [_convert_tool_union_to_tool(tool) for tool in self.tools]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def canonical_before_model_callbacks(
|
||||||
|
self,
|
||||||
|
) -> list[_SingleBeforeModelCallback]:
|
||||||
|
"""The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback.
|
||||||
|
|
||||||
|
This method is only for use by Agent Development Kit.
|
||||||
|
"""
|
||||||
|
if not self.before_model_callback:
|
||||||
|
return []
|
||||||
|
if isinstance(self.before_model_callback, list):
|
||||||
|
return self.before_model_callback
|
||||||
|
return [self.before_model_callback]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]:
|
||||||
|
"""The resolved self.after_model_callback field as a list of _SingleAfterModelCallback.
|
||||||
|
|
||||||
|
This method is only for use by Agent Development Kit.
|
||||||
|
"""
|
||||||
|
if not self.after_model_callback:
|
||||||
|
return []
|
||||||
|
if isinstance(self.after_model_callback, list):
|
||||||
|
return self.after_model_callback
|
||||||
|
return [self.after_model_callback]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_flow(self) -> BaseLlmFlow:
|
def _llm_flow(self) -> BaseLlmFlow:
|
||||||
if (
|
if (
|
||||||
|
@ -194,7 +194,8 @@ class BaseLlmFlow(ABC):
|
|||||||
def get_author(llm_response):
|
def get_author(llm_response):
|
||||||
"""Get the author of the event.
|
"""Get the author of the event.
|
||||||
|
|
||||||
When the model returns transcription, the author is "user". Otherwise, the author is the agent.
|
When the model returns transcription, the author is "user". Otherwise, the
|
||||||
|
author is the agent.
|
||||||
"""
|
"""
|
||||||
if llm_response and llm_response.content and llm_response.content.role == "user":
|
if llm_response and llm_response.content and llm_response.content.role == "user":
|
||||||
return "user"
|
return "user"
|
||||||
@ -509,20 +510,21 @@ class BaseLlmFlow(ABC):
|
|||||||
if not isinstance(agent, LlmAgent):
|
if not isinstance(agent, LlmAgent):
|
||||||
return
|
return
|
||||||
|
|
||||||
if not agent.before_model_callback:
|
if not agent.canonical_before_model_callbacks:
|
||||||
return
|
return
|
||||||
|
|
||||||
callback_context = CallbackContext(
|
callback_context = CallbackContext(
|
||||||
invocation_context, event_actions=model_response_event.actions
|
invocation_context, event_actions=model_response_event.actions
|
||||||
)
|
)
|
||||||
before_model_callback_content = agent.before_model_callback(
|
|
||||||
callback_context=callback_context, llm_request=llm_request
|
|
||||||
)
|
|
||||||
|
|
||||||
if inspect.isawaitable(before_model_callback_content):
|
for callback in agent.canonical_before_model_callbacks:
|
||||||
before_model_callback_content = await before_model_callback_content
|
before_model_callback_content = callback(
|
||||||
|
callback_context=callback_context, llm_request=llm_request
|
||||||
return before_model_callback_content
|
)
|
||||||
|
if inspect.isawaitable(before_model_callback_content):
|
||||||
|
before_model_callback_content = await before_model_callback_content
|
||||||
|
if before_model_callback_content:
|
||||||
|
return before_model_callback_content
|
||||||
|
|
||||||
async def _handle_after_model_callback(
|
async def _handle_after_model_callback(
|
||||||
self,
|
self,
|
||||||
@ -536,20 +538,21 @@ class BaseLlmFlow(ABC):
|
|||||||
if not isinstance(agent, LlmAgent):
|
if not isinstance(agent, LlmAgent):
|
||||||
return
|
return
|
||||||
|
|
||||||
if not agent.after_model_callback:
|
if not agent.canonical_after_model_callbacks:
|
||||||
return
|
return
|
||||||
|
|
||||||
callback_context = CallbackContext(
|
callback_context = CallbackContext(
|
||||||
invocation_context, event_actions=model_response_event.actions
|
invocation_context, event_actions=model_response_event.actions
|
||||||
)
|
)
|
||||||
after_model_callback_content = agent.after_model_callback(
|
|
||||||
callback_context=callback_context, llm_response=llm_response
|
|
||||||
)
|
|
||||||
|
|
||||||
if inspect.isawaitable(after_model_callback_content):
|
for callback in agent.canonical_after_model_callbacks:
|
||||||
after_model_callback_content = await after_model_callback_content
|
after_model_callback_content = callback(
|
||||||
|
callback_context=callback_context, llm_response=llm_response
|
||||||
return after_model_callback_content
|
)
|
||||||
|
if inspect.isawaitable(after_model_callback_content):
|
||||||
|
after_model_callback_content = await after_model_callback_content
|
||||||
|
if after_model_callback_content:
|
||||||
|
return after_model_callback_content
|
||||||
|
|
||||||
def _finalize_model_response_event(
|
def _finalize_model_response_event(
|
||||||
self,
|
self,
|
||||||
|
209
tests/unittests/agents/test_agent_callbacks.py
Normal file
209
tests/unittests/agents/test_agent_callbacks.py
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from google.adk.agents.callback_context import CallbackContext
|
||||||
|
from google.adk.agents.llm_agent import Agent
|
||||||
|
from google.adk.models import LlmRequest
|
||||||
|
from google.adk.models import LlmResponse
|
||||||
|
from google.genai import types
|
||||||
|
from google.genai import types
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from .. import utils
|
||||||
|
|
||||||
|
|
||||||
|
class MockAgentCallback(BaseModel):
|
||||||
|
mock_response: str
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
) -> types.Content:
|
||||||
|
return types.Content(parts=[types.Part(text=self.mock_response)])
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncAgentCallback(BaseModel):
|
||||||
|
mock_response: str
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
) -> types.Content:
|
||||||
|
return types.Content(parts=[types.Part(text=self.mock_response)])
|
||||||
|
|
||||||
|
|
||||||
|
def noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_before_agent_callback():
|
||||||
|
responses = ['agent_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
before_agent_callback=MockAgentCallback(
|
||||||
|
mock_response='before_agent_callback'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'before_agent_callback'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_after_agent_callback():
|
||||||
|
responses = ['agent_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
after_agent_callback=MockAgentCallback(
|
||||||
|
mock_response='after_agent_callback'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'agent_response'),
|
||||||
|
('root_agent', 'after_agent_callback'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_before_agent_callback_noop():
|
||||||
|
responses = ['agent_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
before_agent_callback=noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'agent_response'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_after_agent_callback_noop():
|
||||||
|
responses = ['agent_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
before_agent_callback=noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'agent_response'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_before_agent_callback():
|
||||||
|
responses = ['agent_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
before_agent_callback=MockAsyncAgentCallback(
|
||||||
|
mock_response='async_before_agent_callback'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'async_before_agent_callback'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_after_agent_callback():
|
||||||
|
responses = ['agent_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
after_agent_callback=MockAsyncAgentCallback(
|
||||||
|
mock_response='async_after_agent_callback'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'agent_response'),
|
||||||
|
('root_agent', 'async_after_agent_callback'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_before_agent_callback_noop():
|
||||||
|
responses = ['agent_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
before_agent_callback=async_noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'agent_response'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_after_agent_callback_noop():
|
||||||
|
responses = ['agent_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
before_agent_callback=async_noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'agent_response'),
|
||||||
|
]
|
242
tests/unittests/agents/test_model_callback_chain.py
Normal file
242
tests/unittests/agents/test_model_callback_chain.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from google.adk.agents.callback_context import CallbackContext
|
||||||
|
from google.adk.agents.llm_agent import Agent
|
||||||
|
from google.adk.models import LlmRequest
|
||||||
|
from google.adk.models import LlmResponse
|
||||||
|
from google.genai import types
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from .. import utils
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackType(Enum):
|
||||||
|
SYNC = 1
|
||||||
|
ASYNC = 2
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_async_before_cb_side_effect(
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
llm_request: LlmRequest,
|
||||||
|
ret_value=None,
|
||||||
|
):
|
||||||
|
if ret_value:
|
||||||
|
return LlmResponse(
|
||||||
|
content=utils.ModelContent([types.Part.from_text(text=ret_value)])
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def mock_sync_before_cb_side_effect(
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
llm_request: LlmRequest,
|
||||||
|
ret_value=None,
|
||||||
|
):
|
||||||
|
if ret_value:
|
||||||
|
return LlmResponse(
|
||||||
|
content=utils.ModelContent([types.Part.from_text(text=ret_value)])
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_async_after_cb_side_effect(
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
llm_response: LlmResponse,
|
||||||
|
ret_value=None,
|
||||||
|
):
|
||||||
|
if ret_value:
|
||||||
|
return LlmResponse(
|
||||||
|
content=utils.ModelContent([types.Part.from_text(text=ret_value)])
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def mock_sync_after_cb_side_effect(
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
llm_response: LlmResponse,
|
||||||
|
ret_value=None,
|
||||||
|
):
|
||||||
|
if ret_value:
|
||||||
|
return LlmResponse(
|
||||||
|
content=utils.ModelContent([types.Part.from_text(text=ret_value)])
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
CALLBACK_PARAMS = [
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
(None, CallbackType.SYNC),
|
||||||
|
("callback_2_response", CallbackType.ASYNC),
|
||||||
|
("callback_3_response", CallbackType.SYNC),
|
||||||
|
(None, CallbackType.ASYNC),
|
||||||
|
],
|
||||||
|
"callback_2_response",
|
||||||
|
[1, 1, 0, 0],
|
||||||
|
id="middle_async_callback_returns",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
(None, CallbackType.SYNC),
|
||||||
|
(None, CallbackType.ASYNC),
|
||||||
|
(None, CallbackType.SYNC),
|
||||||
|
(None, CallbackType.ASYNC),
|
||||||
|
],
|
||||||
|
"model_response",
|
||||||
|
[1, 1, 1, 1],
|
||||||
|
id="all_callbacks_return_none",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
("callback_1_response", CallbackType.SYNC),
|
||||||
|
("callback_2_response", CallbackType.ASYNC),
|
||||||
|
],
|
||||||
|
"callback_1_response",
|
||||||
|
[1, 0],
|
||||||
|
id="first_sync_callback_returns",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"callbacks, expected_response, expected_calls",
|
||||||
|
CALLBACK_PARAMS,
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_before_model_callbacks_chain(
|
||||||
|
callbacks: List[tuple[str, int]],
|
||||||
|
expected_response: str,
|
||||||
|
expected_calls: List[int],
|
||||||
|
):
|
||||||
|
responses = ["model_response"]
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
|
||||||
|
mock_cbs = []
|
||||||
|
for response, callback_type in callbacks:
|
||||||
|
|
||||||
|
if callback_type == CallbackType.ASYNC:
|
||||||
|
mock_cb = mock.AsyncMock(
|
||||||
|
side_effect=partial(
|
||||||
|
mock_async_before_cb_side_effect, ret_value=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mock_cb = mock.Mock(
|
||||||
|
side_effect=partial(
|
||||||
|
mock_sync_before_cb_side_effect, ret_value=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_cbs.append(mock_cb)
|
||||||
|
# Create agent with multiple callbacks
|
||||||
|
agent = Agent(
|
||||||
|
name="root_agent",
|
||||||
|
model=mock_model,
|
||||||
|
before_model_callback=[mock_cb for mock_cb in mock_cbs],
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
result = await runner.run_async_with_new_session("test")
|
||||||
|
assert utils.simplify_events(result) == [
|
||||||
|
("root_agent", expected_response),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Assert that the callbacks were called the expected number of times
|
||||||
|
for i, mock_cb in enumerate(mock_cbs):
|
||||||
|
expected_calls_count = expected_calls[i]
|
||||||
|
if expected_calls_count == 1:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_awaited_once()
|
||||||
|
else:
|
||||||
|
mock_cb.assert_called_once()
|
||||||
|
elif expected_calls_count == 0:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_not_awaited()
|
||||||
|
else:
|
||||||
|
mock_cb.assert_not_called()
|
||||||
|
else:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_awaited(expected_calls_count)
|
||||||
|
else:
|
||||||
|
mock_cb.assert_called(expected_calls_count)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"callbacks, expected_response, expected_calls",
|
||||||
|
CALLBACK_PARAMS,
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_after_model_callbacks_chain(
|
||||||
|
callbacks: List[tuple[str, int]],
|
||||||
|
expected_response: str,
|
||||||
|
expected_calls: List[int],
|
||||||
|
):
|
||||||
|
responses = ["model_response"]
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
|
||||||
|
mock_cbs = []
|
||||||
|
for response, callback_type in callbacks:
|
||||||
|
|
||||||
|
if callback_type == CallbackType.ASYNC:
|
||||||
|
mock_cb = mock.AsyncMock(
|
||||||
|
side_effect=partial(
|
||||||
|
mock_async_after_cb_side_effect, ret_value=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mock_cb = mock.Mock(
|
||||||
|
side_effect=partial(
|
||||||
|
mock_sync_after_cb_side_effect, ret_value=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_cbs.append(mock_cb)
|
||||||
|
# Create agent with multiple callbacks
|
||||||
|
agent = Agent(
|
||||||
|
name="root_agent",
|
||||||
|
model=mock_model,
|
||||||
|
after_model_callback=[mock_cb for mock_cb in mock_cbs],
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
result = await runner.run_async_with_new_session("test")
|
||||||
|
assert utils.simplify_events(result) == [
|
||||||
|
("root_agent", expected_response),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Assert that the callbacks were called the expected number of times
|
||||||
|
for i, mock_cb in enumerate(mock_cbs):
|
||||||
|
expected_calls_count = expected_calls[i]
|
||||||
|
if expected_calls_count == 1:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_awaited_once()
|
||||||
|
else:
|
||||||
|
mock_cb.assert_called_once()
|
||||||
|
elif expected_calls_count == 0:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_not_awaited()
|
||||||
|
else:
|
||||||
|
mock_cb.assert_not_called()
|
||||||
|
else:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_awaited(expected_calls_count)
|
||||||
|
else:
|
||||||
|
mock_cb.assert_called(expected_calls_count)
|
210
tests/unittests/agents/test_model_callbacks.py
Normal file
210
tests/unittests/agents/test_model_callbacks.py
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from google.adk.agents.callback_context import CallbackContext
|
||||||
|
from google.adk.agents.llm_agent import Agent
|
||||||
|
from google.adk.models import LlmRequest
|
||||||
|
from google.adk.models import LlmResponse
|
||||||
|
from google.genai import types
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from .. import utils
|
||||||
|
|
||||||
|
|
||||||
|
class MockBeforeModelCallback(BaseModel):
|
||||||
|
mock_response: str
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
llm_request: LlmRequest,
|
||||||
|
) -> LlmResponse:
|
||||||
|
return LlmResponse(
|
||||||
|
content=utils.ModelContent(
|
||||||
|
[types.Part.from_text(text=self.mock_response)]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockAfterModelCallback(BaseModel):
|
||||||
|
mock_response: str
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
llm_response: LlmResponse,
|
||||||
|
) -> LlmResponse:
|
||||||
|
return LlmResponse(
|
||||||
|
content=utils.ModelContent(
|
||||||
|
[types.Part.from_text(text=self.mock_response)]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncBeforeModelCallback(BaseModel):
|
||||||
|
mock_response: str
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
llm_request: LlmRequest,
|
||||||
|
) -> LlmResponse:
|
||||||
|
return LlmResponse(
|
||||||
|
content=utils.ModelContent(
|
||||||
|
[types.Part.from_text(text=self.mock_response)]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncAfterModelCallback(BaseModel):
|
||||||
|
mock_response: str
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
llm_response: LlmResponse,
|
||||||
|
) -> LlmResponse:
|
||||||
|
return LlmResponse(
|
||||||
|
content=utils.ModelContent(
|
||||||
|
[types.Part.from_text(text=self.mock_response)]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_before_model_callback():
|
||||||
|
responses = ['model_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
before_model_callback=MockBeforeModelCallback(
|
||||||
|
mock_response='before_model_callback'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'before_model_callback'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_before_model_callback_noop():
|
||||||
|
responses = ['model_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
before_model_callback=noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'model_response'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_after_model_callback():
|
||||||
|
responses = ['model_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
after_model_callback=MockAfterModelCallback(
|
||||||
|
mock_response='after_model_callback'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'after_model_callback'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_before_model_callback():
|
||||||
|
responses = ['model_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
before_model_callback=MockAsyncBeforeModelCallback(
|
||||||
|
mock_response='async_before_model_callback'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'async_before_model_callback'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_before_model_callback_noop():
|
||||||
|
responses = ['model_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
before_model_callback=async_noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'model_response'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_after_model_callback():
|
||||||
|
responses = ['model_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
after_model_callback=MockAsyncAfterModelCallback(
|
||||||
|
mock_response='async_after_model_callback'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'async_after_model_callback'),
|
||||||
|
]
|
Loading…
Reference in New Issue
Block a user