mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
Support chaining for agent callbacks
(before/after) agent callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 756359693
This commit is contained in:
parent
a61d20e3df
commit
d45084f311
@ -86,6 +86,37 @@ async def after_model_callback(callback_context, llm_response):
|
||||
return None
|
||||
|
||||
|
||||
def after_agent_cb1(callback_context):
|
||||
print('@after_agent_cb1')
|
||||
|
||||
|
||||
def after_agent_cb2(callback_context):
|
||||
print('@after_agent_cb2')
|
||||
return types.Content(
|
||||
parts=[
|
||||
types.Part(
|
||||
text='(stopped) after_agent_cb2',
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def after_agent_cb3(callback_context):
|
||||
print('@after_agent_cb3')
|
||||
|
||||
|
||||
def before_agent_cb1(callback_context):
|
||||
print('@before_agent_cb1')
|
||||
|
||||
|
||||
def before_agent_cb2(callback_context):
|
||||
print('@before_agent_cb2')
|
||||
|
||||
|
||||
def before_agent_cb3(callback_context):
|
||||
print('@before_agent_cb3')
|
||||
|
||||
|
||||
root_agent = Agent(
|
||||
model='gemini-2.0-flash-exp',
|
||||
name='data_processing_agent',
|
||||
@ -127,8 +158,12 @@ root_agent = Agent(
|
||||
),
|
||||
]
|
||||
),
|
||||
before_agent_callback=before_agent_callback,
|
||||
after_agent_callback=after_agent_callback,
|
||||
before_agent_callback=[
|
||||
before_agent_cb1,
|
||||
before_agent_cb2,
|
||||
before_agent_cb3,
|
||||
],
|
||||
after_agent_callback=[after_agent_cb1, after_agent_cb2, after_agent_cb3],
|
||||
before_model_callback=before_model_callback,
|
||||
after_model_callback=after_model_callback,
|
||||
)
|
||||
|
@ -15,12 +15,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any, Awaitable, Union
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Awaitable
|
||||
from typing import Callable
|
||||
from typing import final
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
from opentelemetry import trace
|
||||
@ -29,6 +31,7 @@ from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from typing_extensions import override
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from ..events.event import Event
|
||||
from .callback_context import CallbackContext
|
||||
@ -38,14 +41,19 @@ if TYPE_CHECKING:
|
||||
|
||||
tracer = trace.get_tracer('gcp.vertex.agent')
|
||||
|
||||
BeforeAgentCallback = Callable[
|
||||
_SingleAgentCallback: TypeAlias = Callable[
|
||||
[CallbackContext],
|
||||
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
|
||||
]
|
||||
|
||||
AfterAgentCallback = Callable[
|
||||
[CallbackContext],
|
||||
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
|
||||
BeforeAgentCallback: TypeAlias = Union[
|
||||
_SingleAgentCallback,
|
||||
list[_SingleAgentCallback],
|
||||
]
|
||||
|
||||
AfterAgentCallback: TypeAlias = Union[
|
||||
_SingleAgentCallback,
|
||||
list[_SingleAgentCallback],
|
||||
]
|
||||
|
||||
|
||||
@ -85,7 +93,10 @@ class BaseAgent(BaseModel):
|
||||
"""The sub-agents of this agent."""
|
||||
|
||||
before_agent_callback: Optional[BeforeAgentCallback] = None
|
||||
"""Callback signature that is invoked before the agent run.
|
||||
"""Callback or list of callbacks to be invoked before the agent run.
|
||||
|
||||
When a list of callbacks is provided, the callbacks will be called in the
|
||||
order they are listed until a callback does not return None.
|
||||
|
||||
Args:
|
||||
callback_context: MUST be named 'callback_context' (enforced).
|
||||
@ -96,7 +107,10 @@ class BaseAgent(BaseModel):
|
||||
provided content will be returned to user.
|
||||
"""
|
||||
after_agent_callback: Optional[AfterAgentCallback] = None
|
||||
"""Callback signature that is invoked after the agent run.
|
||||
"""Callback or list of callbacks to be invoked after the agent run.
|
||||
|
||||
When a list of callbacks is provided, the callbacks will be called in the
|
||||
order they are listed until a callback does not return None.
|
||||
|
||||
Args:
|
||||
callback_context: MUST be named 'callback_context' (enforced).
|
||||
@ -236,6 +250,30 @@ class BaseAgent(BaseModel):
|
||||
invocation_context.branch = f'{parent_context.branch}.{self.name}'
|
||||
return invocation_context
|
||||
|
||||
@property
|
||||
def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]:
|
||||
"""The resolved self.before_agent_callback field as a list of _SingleAgentCallback.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if not self.before_agent_callback:
|
||||
return []
|
||||
if isinstance(self.before_agent_callback, list):
|
||||
return self.before_agent_callback
|
||||
return [self.before_agent_callback]
|
||||
|
||||
@property
|
||||
def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]:
|
||||
"""The resolved self.after_agent_callback field as a list of _SingleAgentCallback.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if not self.after_agent_callback:
|
||||
return []
|
||||
if isinstance(self.after_agent_callback, list):
|
||||
return self.after_agent_callback
|
||||
return [self.after_agent_callback]
|
||||
|
||||
async def __handle_before_agent_callback(
|
||||
self, ctx: InvocationContext
|
||||
) -> Optional[Event]:
|
||||
@ -246,17 +284,17 @@ class BaseAgent(BaseModel):
|
||||
"""
|
||||
ret_event = None
|
||||
|
||||
if not isinstance(self.before_agent_callback, Callable):
|
||||
if not self.canonical_before_agent_callbacks:
|
||||
return ret_event
|
||||
|
||||
callback_context = CallbackContext(ctx)
|
||||
before_agent_callback_content = self.before_agent_callback(
|
||||
|
||||
for callback in self.canonical_before_agent_callbacks:
|
||||
before_agent_callback_content = callback(
|
||||
callback_context=callback_context
|
||||
)
|
||||
|
||||
if inspect.isawaitable(before_agent_callback_content):
|
||||
before_agent_callback_content = await before_agent_callback_content
|
||||
|
||||
if before_agent_callback_content:
|
||||
ret_event = Event(
|
||||
invocation_id=ctx.invocation_id,
|
||||
@ -288,18 +326,26 @@ class BaseAgent(BaseModel):
|
||||
"""
|
||||
ret_event = None
|
||||
|
||||
if not isinstance(self.after_agent_callback, Callable):
|
||||
if not self.canonical_after_agent_callbacks:
|
||||
return ret_event
|
||||
|
||||
callback_context = CallbackContext(invocation_context)
|
||||
after_agent_callback_content = self.after_agent_callback(
|
||||
callback_context=callback_context
|
||||
)
|
||||
|
||||
for callback in self.canonical_after_agent_callbacks:
|
||||
after_agent_callback_content = callback(callback_context=callback_context)
|
||||
if inspect.isawaitable(after_agent_callback_content):
|
||||
after_agent_callback_content = await after_agent_callback_content
|
||||
if after_agent_callback_content:
|
||||
ret_event = Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=self.name,
|
||||
branch=invocation_context.branch,
|
||||
content=after_agent_callback_content,
|
||||
actions=callback_context._event_actions,
|
||||
)
|
||||
return ret_event
|
||||
|
||||
if after_agent_callback_content or callback_context.state.has_delta():
|
||||
if callback_context.state.has_delta():
|
||||
ret_event = Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=self.name,
|
||||
|
@ -191,14 +191,19 @@ class BaseLlmFlow(ABC):
|
||||
llm_request: LlmRequest,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Receive data from model and process events using BaseLlmConnection."""
|
||||
|
||||
def get_author(llm_response):
|
||||
"""Get the author of the event.
|
||||
|
||||
When the model returns transcription, the author is "user". Otherwise, the
|
||||
author is the agent.
|
||||
"""
|
||||
if llm_response and llm_response.content and llm_response.content.role == "user":
|
||||
return "user"
|
||||
if (
|
||||
llm_response
|
||||
and llm_response.content
|
||||
and llm_response.content.role == 'user'
|
||||
):
|
||||
return 'user'
|
||||
else:
|
||||
return invocation_context.agent.name
|
||||
|
||||
|
@ -1,209 +0,0 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.models import LlmRequest
|
||||
from google.adk.models import LlmResponse
|
||||
from google.genai import types
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
from .. import utils
|
||||
|
||||
|
||||
class MockAgentCallback(BaseModel):
|
||||
mock_response: str
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
) -> types.Content:
|
||||
return types.Content(parts=[types.Part(text=self.mock_response)])
|
||||
|
||||
|
||||
class MockAsyncAgentCallback(BaseModel):
|
||||
mock_response: str
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
) -> types.Content:
|
||||
return types.Content(parts=[types.Part(text=self.mock_response)])
|
||||
|
||||
|
||||
def noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||
pass
|
||||
|
||||
|
||||
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_agent_callback():
|
||||
responses = ['agent_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_agent_callback=MockAgentCallback(
|
||||
mock_response='before_agent_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'before_agent_callback'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_agent_callback():
|
||||
responses = ['agent_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_agent_callback=MockAgentCallback(
|
||||
mock_response='after_agent_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'agent_response'),
|
||||
('root_agent', 'after_agent_callback'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_agent_callback_noop():
|
||||
responses = ['agent_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_agent_callback=noop_callback,
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'agent_response'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_agent_callback_noop():
|
||||
responses = ['agent_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_agent_callback=noop_callback,
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'agent_response'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_before_agent_callback():
|
||||
responses = ['agent_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_agent_callback=MockAsyncAgentCallback(
|
||||
mock_response='async_before_agent_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'async_before_agent_callback'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_after_agent_callback():
|
||||
responses = ['agent_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_agent_callback=MockAsyncAgentCallback(
|
||||
mock_response='async_after_agent_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'agent_response'),
|
||||
('root_agent', 'async_after_agent_callback'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_before_agent_callback_noop():
|
||||
responses = ['agent_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_agent_callback=async_noop_callback,
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'agent_response'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_after_agent_callback_noop():
|
||||
responses = ['agent_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_agent_callback=async_noop_callback,
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'agent_response'),
|
||||
]
|
@ -14,10 +14,13 @@
|
||||
|
||||
"""Testings for the BaseAgent."""
|
||||
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import AsyncGenerator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from unittest import mock
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
@ -27,6 +30,7 @@ from google.genai import types
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from typing_extensions import override
|
||||
from .. import utils
|
||||
|
||||
|
||||
def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
|
||||
@ -266,6 +270,220 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent(
|
||||
assert events[0].content.parts[0].text == 'agent run is bypassed.'
|
||||
|
||||
|
||||
class CallbackType(Enum):
|
||||
SYNC = 1
|
||||
ASYNC = 2
|
||||
|
||||
|
||||
async def mock_async_agent_cb_side_effect(
|
||||
callback_context: CallbackContext,
|
||||
ret_value=None,
|
||||
):
|
||||
if ret_value:
|
||||
return types.Content(parts=[types.Part(text=ret_value)])
|
||||
return None
|
||||
|
||||
|
||||
def mock_sync_agent_cb_side_effect(
|
||||
callback_context: CallbackContext,
|
||||
ret_value=None,
|
||||
):
|
||||
if ret_value:
|
||||
return types.Content(parts=[types.Part(text=ret_value)])
|
||||
return None
|
||||
|
||||
|
||||
BEFORE_AGENT_CALLBACK_PARAMS = [
|
||||
pytest.param(
|
||||
[
|
||||
(None, CallbackType.SYNC),
|
||||
('callback_2_response', CallbackType.ASYNC),
|
||||
('callback_3_response', CallbackType.SYNC),
|
||||
(None, CallbackType.ASYNC),
|
||||
],
|
||||
['callback_2_response'],
|
||||
[1, 1, 0, 0],
|
||||
id='middle_async_callback_returns',
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
(None, CallbackType.SYNC),
|
||||
(None, CallbackType.ASYNC),
|
||||
(None, CallbackType.SYNC),
|
||||
(None, CallbackType.ASYNC),
|
||||
],
|
||||
['Hello, world!'],
|
||||
[1, 1, 1, 1],
|
||||
id='all_callbacks_return_none',
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
('callback_1_response', CallbackType.SYNC),
|
||||
('callback_2_response', CallbackType.ASYNC),
|
||||
],
|
||||
['callback_1_response'],
|
||||
[1, 0],
|
||||
id='first_sync_callback_returns',
|
||||
),
|
||||
]
|
||||
|
||||
AFTER_AGENT_CALLBACK_PARAMS = [
|
||||
pytest.param(
|
||||
[
|
||||
(None, CallbackType.SYNC),
|
||||
('callback_2_response', CallbackType.ASYNC),
|
||||
('callback_3_response', CallbackType.SYNC),
|
||||
(None, CallbackType.ASYNC),
|
||||
],
|
||||
['Hello, world!', 'callback_2_response'],
|
||||
[1, 1, 0, 0],
|
||||
id='middle_async_callback_returns',
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
(None, CallbackType.SYNC),
|
||||
(None, CallbackType.ASYNC),
|
||||
(None, CallbackType.SYNC),
|
||||
(None, CallbackType.ASYNC),
|
||||
],
|
||||
['Hello, world!'],
|
||||
[1, 1, 1, 1],
|
||||
id='all_callbacks_return_none',
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
('callback_1_response', CallbackType.SYNC),
|
||||
('callback_2_response', CallbackType.ASYNC),
|
||||
],
|
||||
['Hello, world!', 'callback_1_response'],
|
||||
[1, 0],
|
||||
id='first_sync_callback_returns',
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'callbacks, expected_responses, expected_calls',
|
||||
BEFORE_AGENT_CALLBACK_PARAMS,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_agent_callbacks_chain(
|
||||
callbacks: List[tuple[str, int]],
|
||||
expected_responses: List[str],
|
||||
expected_calls: List[int],
|
||||
request: pytest.FixtureRequest,
|
||||
):
|
||||
mock_cbs = []
|
||||
for response, callback_type in callbacks:
|
||||
|
||||
if callback_type == CallbackType.ASYNC:
|
||||
mock_cb = mock.AsyncMock(
|
||||
side_effect=partial(
|
||||
mock_async_agent_cb_side_effect, ret_value=response
|
||||
)
|
||||
)
|
||||
else:
|
||||
mock_cb = mock.Mock(
|
||||
side_effect=partial(
|
||||
mock_sync_agent_cb_side_effect, ret_value=response
|
||||
)
|
||||
)
|
||||
mock_cbs.append(mock_cb)
|
||||
|
||||
agent = _TestingAgent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=[mock_cb for mock_cb in mock_cbs],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
result = [e async for e in agent.run_async(parent_ctx)]
|
||||
assert utils.simplify_events(result) == [
|
||||
(f'{request.function.__name__}_test_agent', response)
|
||||
for response in expected_responses
|
||||
]
|
||||
|
||||
# Assert that the callbacks were called the expected number of times
|
||||
for i, mock_cb in enumerate(mock_cbs):
|
||||
expected_calls_count = expected_calls[i]
|
||||
if expected_calls_count == 1:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_awaited_once()
|
||||
else:
|
||||
mock_cb.assert_called_once()
|
||||
elif expected_calls_count == 0:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_not_awaited()
|
||||
else:
|
||||
mock_cb.assert_not_called()
|
||||
else:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_awaited(expected_calls_count)
|
||||
else:
|
||||
mock_cb.assert_called(expected_calls_count)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'callbacks, expected_responses, expected_calls',
|
||||
AFTER_AGENT_CALLBACK_PARAMS,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_agent_callbacks_chain(
|
||||
callbacks: List[tuple[str, int]],
|
||||
expected_responses: List[str],
|
||||
expected_calls: List[int],
|
||||
request: pytest.FixtureRequest,
|
||||
):
|
||||
mock_cbs = []
|
||||
for response, callback_type in callbacks:
|
||||
|
||||
if callback_type == CallbackType.ASYNC:
|
||||
mock_cb = mock.AsyncMock(
|
||||
side_effect=partial(
|
||||
mock_async_agent_cb_side_effect, ret_value=response
|
||||
)
|
||||
)
|
||||
else:
|
||||
mock_cb = mock.Mock(
|
||||
side_effect=partial(
|
||||
mock_sync_agent_cb_side_effect, ret_value=response
|
||||
)
|
||||
)
|
||||
mock_cbs.append(mock_cb)
|
||||
|
||||
agent = _TestingAgent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=[mock_cb for mock_cb in mock_cbs],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
result = [e async for e in agent.run_async(parent_ctx)]
|
||||
assert utils.simplify_events(result) == [
|
||||
(f'{request.function.__name__}_test_agent', response)
|
||||
for response in expected_responses
|
||||
]
|
||||
|
||||
# Assert that the callbacks were called the expected number of times
|
||||
for i, mock_cb in enumerate(mock_cbs):
|
||||
expected_calls_count = expected_calls[i]
|
||||
if expected_calls_count == 1:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_awaited_once()
|
||||
else:
|
||||
mock_cb.assert_called_once()
|
||||
elif expected_calls_count == 0:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_not_awaited()
|
||||
else:
|
||||
mock_cb.assert_not_called()
|
||||
else:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_awaited(expected_calls_count)
|
||||
else:
|
||||
mock_cb.assert_called(expected_calls_count)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_after_agent_callback_noop(
|
||||
request: pytest.FixtureRequest,
|
||||
|
@ -1,210 +0,0 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.models import LlmRequest
|
||||
from google.adk.models import LlmResponse
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
from .. import utils
|
||||
|
||||
|
||||
class MockBeforeModelCallback(BaseModel):
|
||||
mock_response: str
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> LlmResponse:
|
||||
return LlmResponse(
|
||||
content=utils.ModelContent(
|
||||
[types.Part.from_text(text=self.mock_response)]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockAfterModelCallback(BaseModel):
|
||||
mock_response: str
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
llm_response: LlmResponse,
|
||||
) -> LlmResponse:
|
||||
return LlmResponse(
|
||||
content=utils.ModelContent(
|
||||
[types.Part.from_text(text=self.mock_response)]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockAsyncBeforeModelCallback(BaseModel):
|
||||
mock_response: str
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> LlmResponse:
|
||||
return LlmResponse(
|
||||
content=utils.ModelContent(
|
||||
[types.Part.from_text(text=self.mock_response)]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockAsyncAfterModelCallback(BaseModel):
|
||||
mock_response: str
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
llm_response: LlmResponse,
|
||||
) -> LlmResponse:
|
||||
return LlmResponse(
|
||||
content=utils.ModelContent(
|
||||
[types.Part.from_text(text=self.mock_response)]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||
pass
|
||||
|
||||
|
||||
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_model_callback():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=MockBeforeModelCallback(
|
||||
mock_response='before_model_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'before_model_callback'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_model_callback_noop():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=noop_callback,
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'model_response'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_model_callback():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_model_callback=MockAfterModelCallback(
|
||||
mock_response='after_model_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'after_model_callback'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_before_model_callback():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=MockAsyncBeforeModelCallback(
|
||||
mock_response='async_before_model_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'async_before_model_callback'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_before_model_callback_noop():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=async_noop_callback,
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'model_response'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_after_model_callback():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_model_callback=MockAsyncAfterModelCallback(
|
||||
mock_response='async_after_model_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'async_after_model_callback'),
|
||||
]
|
Loading…
Reference in New Issue
Block a user