mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
Support chaining for agent callbacks
(before/after) agent callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 756359693
This commit is contained in:
parent
a61d20e3df
commit
d45084f311
@ -86,6 +86,37 @@ async def after_model_callback(callback_context, llm_response):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def after_agent_cb1(callback_context):
|
||||||
|
print('@after_agent_cb1')
|
||||||
|
|
||||||
|
|
||||||
|
def after_agent_cb2(callback_context):
|
||||||
|
print('@after_agent_cb2')
|
||||||
|
return types.Content(
|
||||||
|
parts=[
|
||||||
|
types.Part(
|
||||||
|
text='(stopped) after_agent_cb2',
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def after_agent_cb3(callback_context):
|
||||||
|
print('@after_agent_cb3')
|
||||||
|
|
||||||
|
|
||||||
|
def before_agent_cb1(callback_context):
|
||||||
|
print('@before_agent_cb1')
|
||||||
|
|
||||||
|
|
||||||
|
def before_agent_cb2(callback_context):
|
||||||
|
print('@before_agent_cb2')
|
||||||
|
|
||||||
|
|
||||||
|
def before_agent_cb3(callback_context):
|
||||||
|
print('@before_agent_cb3')
|
||||||
|
|
||||||
|
|
||||||
root_agent = Agent(
|
root_agent = Agent(
|
||||||
model='gemini-2.0-flash-exp',
|
model='gemini-2.0-flash-exp',
|
||||||
name='data_processing_agent',
|
name='data_processing_agent',
|
||||||
@ -127,8 +158,12 @@ root_agent = Agent(
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
before_agent_callback=before_agent_callback,
|
before_agent_callback=[
|
||||||
after_agent_callback=after_agent_callback,
|
before_agent_cb1,
|
||||||
|
before_agent_cb2,
|
||||||
|
before_agent_cb3,
|
||||||
|
],
|
||||||
|
after_agent_callback=[after_agent_cb1, after_agent_cb2, after_agent_cb3],
|
||||||
before_model_callback=before_model_callback,
|
before_model_callback=before_model_callback,
|
||||||
after_model_callback=after_model_callback,
|
after_model_callback=after_model_callback,
|
||||||
)
|
)
|
||||||
|
@ -15,12 +15,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Awaitable, Union
|
from typing import Any
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
from typing import Awaitable
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from typing import final
|
from typing import final
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
@ -29,6 +31,7 @@ from pydantic import ConfigDict
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from ..events.event import Event
|
from ..events.event import Event
|
||||||
from .callback_context import CallbackContext
|
from .callback_context import CallbackContext
|
||||||
@ -38,14 +41,19 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
tracer = trace.get_tracer('gcp.vertex.agent')
|
tracer = trace.get_tracer('gcp.vertex.agent')
|
||||||
|
|
||||||
BeforeAgentCallback = Callable[
|
_SingleAgentCallback: TypeAlias = Callable[
|
||||||
[CallbackContext],
|
[CallbackContext],
|
||||||
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
|
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
|
||||||
]
|
]
|
||||||
|
|
||||||
AfterAgentCallback = Callable[
|
BeforeAgentCallback: TypeAlias = Union[
|
||||||
[CallbackContext],
|
_SingleAgentCallback,
|
||||||
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
|
list[_SingleAgentCallback],
|
||||||
|
]
|
||||||
|
|
||||||
|
AfterAgentCallback: TypeAlias = Union[
|
||||||
|
_SingleAgentCallback,
|
||||||
|
list[_SingleAgentCallback],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -85,7 +93,10 @@ class BaseAgent(BaseModel):
|
|||||||
"""The sub-agents of this agent."""
|
"""The sub-agents of this agent."""
|
||||||
|
|
||||||
before_agent_callback: Optional[BeforeAgentCallback] = None
|
before_agent_callback: Optional[BeforeAgentCallback] = None
|
||||||
"""Callback signature that is invoked before the agent run.
|
"""Callback or list of callbacks to be invoked before the agent run.
|
||||||
|
|
||||||
|
When a list of callbacks is provided, the callbacks will be called in the
|
||||||
|
order they are listed until a callback does not return None.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callback_context: MUST be named 'callback_context' (enforced).
|
callback_context: MUST be named 'callback_context' (enforced).
|
||||||
@ -96,7 +107,10 @@ class BaseAgent(BaseModel):
|
|||||||
provided content will be returned to user.
|
provided content will be returned to user.
|
||||||
"""
|
"""
|
||||||
after_agent_callback: Optional[AfterAgentCallback] = None
|
after_agent_callback: Optional[AfterAgentCallback] = None
|
||||||
"""Callback signature that is invoked after the agent run.
|
"""Callback or list of callbacks to be invoked after the agent run.
|
||||||
|
|
||||||
|
When a list of callbacks is provided, the callbacks will be called in the
|
||||||
|
order they are listed until a callback does not return None.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callback_context: MUST be named 'callback_context' (enforced).
|
callback_context: MUST be named 'callback_context' (enforced).
|
||||||
@ -236,6 +250,30 @@ class BaseAgent(BaseModel):
|
|||||||
invocation_context.branch = f'{parent_context.branch}.{self.name}'
|
invocation_context.branch = f'{parent_context.branch}.{self.name}'
|
||||||
return invocation_context
|
return invocation_context
|
||||||
|
|
||||||
|
@property
|
||||||
|
def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]:
|
||||||
|
"""The resolved self.before_agent_callback field as a list of _SingleAgentCallback.
|
||||||
|
|
||||||
|
This method is only for use by Agent Development Kit.
|
||||||
|
"""
|
||||||
|
if not self.before_agent_callback:
|
||||||
|
return []
|
||||||
|
if isinstance(self.before_agent_callback, list):
|
||||||
|
return self.before_agent_callback
|
||||||
|
return [self.before_agent_callback]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]:
|
||||||
|
"""The resolved self.after_agent_callback field as a list of _SingleAgentCallback.
|
||||||
|
|
||||||
|
This method is only for use by Agent Development Kit.
|
||||||
|
"""
|
||||||
|
if not self.after_agent_callback:
|
||||||
|
return []
|
||||||
|
if isinstance(self.after_agent_callback, list):
|
||||||
|
return self.after_agent_callback
|
||||||
|
return [self.after_agent_callback]
|
||||||
|
|
||||||
async def __handle_before_agent_callback(
|
async def __handle_before_agent_callback(
|
||||||
self, ctx: InvocationContext
|
self, ctx: InvocationContext
|
||||||
) -> Optional[Event]:
|
) -> Optional[Event]:
|
||||||
@ -246,17 +284,17 @@ class BaseAgent(BaseModel):
|
|||||||
"""
|
"""
|
||||||
ret_event = None
|
ret_event = None
|
||||||
|
|
||||||
if not isinstance(self.before_agent_callback, Callable):
|
if not self.canonical_before_agent_callbacks:
|
||||||
return ret_event
|
return ret_event
|
||||||
|
|
||||||
callback_context = CallbackContext(ctx)
|
callback_context = CallbackContext(ctx)
|
||||||
before_agent_callback_content = self.before_agent_callback(
|
|
||||||
|
for callback in self.canonical_before_agent_callbacks:
|
||||||
|
before_agent_callback_content = callback(
|
||||||
callback_context=callback_context
|
callback_context=callback_context
|
||||||
)
|
)
|
||||||
|
|
||||||
if inspect.isawaitable(before_agent_callback_content):
|
if inspect.isawaitable(before_agent_callback_content):
|
||||||
before_agent_callback_content = await before_agent_callback_content
|
before_agent_callback_content = await before_agent_callback_content
|
||||||
|
|
||||||
if before_agent_callback_content:
|
if before_agent_callback_content:
|
||||||
ret_event = Event(
|
ret_event = Event(
|
||||||
invocation_id=ctx.invocation_id,
|
invocation_id=ctx.invocation_id,
|
||||||
@ -288,18 +326,26 @@ class BaseAgent(BaseModel):
|
|||||||
"""
|
"""
|
||||||
ret_event = None
|
ret_event = None
|
||||||
|
|
||||||
if not isinstance(self.after_agent_callback, Callable):
|
if not self.canonical_after_agent_callbacks:
|
||||||
return ret_event
|
return ret_event
|
||||||
|
|
||||||
callback_context = CallbackContext(invocation_context)
|
callback_context = CallbackContext(invocation_context)
|
||||||
after_agent_callback_content = self.after_agent_callback(
|
|
||||||
callback_context=callback_context
|
|
||||||
)
|
|
||||||
|
|
||||||
|
for callback in self.canonical_after_agent_callbacks:
|
||||||
|
after_agent_callback_content = callback(callback_context=callback_context)
|
||||||
if inspect.isawaitable(after_agent_callback_content):
|
if inspect.isawaitable(after_agent_callback_content):
|
||||||
after_agent_callback_content = await after_agent_callback_content
|
after_agent_callback_content = await after_agent_callback_content
|
||||||
|
if after_agent_callback_content:
|
||||||
|
ret_event = Event(
|
||||||
|
invocation_id=invocation_context.invocation_id,
|
||||||
|
author=self.name,
|
||||||
|
branch=invocation_context.branch,
|
||||||
|
content=after_agent_callback_content,
|
||||||
|
actions=callback_context._event_actions,
|
||||||
|
)
|
||||||
|
return ret_event
|
||||||
|
|
||||||
if after_agent_callback_content or callback_context.state.has_delta():
|
if callback_context.state.has_delta():
|
||||||
ret_event = Event(
|
ret_event = Event(
|
||||||
invocation_id=invocation_context.invocation_id,
|
invocation_id=invocation_context.invocation_id,
|
||||||
author=self.name,
|
author=self.name,
|
||||||
|
@ -191,14 +191,19 @@ class BaseLlmFlow(ABC):
|
|||||||
llm_request: LlmRequest,
|
llm_request: LlmRequest,
|
||||||
) -> AsyncGenerator[Event, None]:
|
) -> AsyncGenerator[Event, None]:
|
||||||
"""Receive data from model and process events using BaseLlmConnection."""
|
"""Receive data from model and process events using BaseLlmConnection."""
|
||||||
|
|
||||||
def get_author(llm_response):
|
def get_author(llm_response):
|
||||||
"""Get the author of the event.
|
"""Get the author of the event.
|
||||||
|
|
||||||
When the model returns transcription, the author is "user". Otherwise, the
|
When the model returns transcription, the author is "user". Otherwise, the
|
||||||
author is the agent.
|
author is the agent.
|
||||||
"""
|
"""
|
||||||
if llm_response and llm_response.content and llm_response.content.role == "user":
|
if (
|
||||||
return "user"
|
llm_response
|
||||||
|
and llm_response.content
|
||||||
|
and llm_response.content.role == 'user'
|
||||||
|
):
|
||||||
|
return 'user'
|
||||||
else:
|
else:
|
||||||
return invocation_context.agent.name
|
return invocation_context.agent.name
|
||||||
|
|
||||||
|
@ -1,209 +0,0 @@
|
|||||||
# Copyright 2025 Google LLC
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from google.adk.agents.callback_context import CallbackContext
|
|
||||||
from google.adk.agents.llm_agent import Agent
|
|
||||||
from google.adk.models import LlmRequest
|
|
||||||
from google.adk.models import LlmResponse
|
|
||||||
from google.genai import types
|
|
||||||
from google.genai import types
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from .. import utils
|
|
||||||
|
|
||||||
|
|
||||||
class MockAgentCallback(BaseModel):
|
|
||||||
mock_response: str
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
callback_context: CallbackContext,
|
|
||||||
) -> types.Content:
|
|
||||||
return types.Content(parts=[types.Part(text=self.mock_response)])
|
|
||||||
|
|
||||||
|
|
||||||
class MockAsyncAgentCallback(BaseModel):
|
|
||||||
mock_response: str
|
|
||||||
|
|
||||||
async def __call__(
|
|
||||||
self,
|
|
||||||
callback_context: CallbackContext,
|
|
||||||
) -> types.Content:
|
|
||||||
return types.Content(parts=[types.Part(text=self.mock_response)])
|
|
||||||
|
|
||||||
|
|
||||||
def noop_callback(**kwargs) -> Optional[LlmResponse]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_before_agent_callback():
|
|
||||||
responses = ['agent_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
before_agent_callback=MockAgentCallback(
|
|
||||||
mock_response='before_agent_callback'
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'before_agent_callback'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_after_agent_callback():
|
|
||||||
responses = ['agent_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
after_agent_callback=MockAgentCallback(
|
|
||||||
mock_response='after_agent_callback'
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'agent_response'),
|
|
||||||
('root_agent', 'after_agent_callback'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_before_agent_callback_noop():
|
|
||||||
responses = ['agent_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
before_agent_callback=noop_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'agent_response'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_after_agent_callback_noop():
|
|
||||||
responses = ['agent_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
before_agent_callback=noop_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'agent_response'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_before_agent_callback():
|
|
||||||
responses = ['agent_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
before_agent_callback=MockAsyncAgentCallback(
|
|
||||||
mock_response='async_before_agent_callback'
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'async_before_agent_callback'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_after_agent_callback():
|
|
||||||
responses = ['agent_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
after_agent_callback=MockAsyncAgentCallback(
|
|
||||||
mock_response='async_after_agent_callback'
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'agent_response'),
|
|
||||||
('root_agent', 'async_after_agent_callback'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_before_agent_callback_noop():
|
|
||||||
responses = ['agent_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
before_agent_callback=async_noop_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'agent_response'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_after_agent_callback_noop():
|
|
||||||
responses = ['agent_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
before_agent_callback=async_noop_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'agent_response'),
|
|
||||||
]
|
|
@ -14,10 +14,13 @@
|
|||||||
|
|
||||||
"""Testings for the BaseAgent."""
|
"""Testings for the BaseAgent."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from functools import partial
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
from typing import List
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
from unittest import mock
|
||||||
from google.adk.agents.base_agent import BaseAgent
|
from google.adk.agents.base_agent import BaseAgent
|
||||||
from google.adk.agents.callback_context import CallbackContext
|
from google.adk.agents.callback_context import CallbackContext
|
||||||
from google.adk.agents.invocation_context import InvocationContext
|
from google.adk.agents.invocation_context import InvocationContext
|
||||||
@ -27,6 +30,7 @@ from google.genai import types
|
|||||||
import pytest
|
import pytest
|
||||||
import pytest_mock
|
import pytest_mock
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
from .. import utils
|
||||||
|
|
||||||
|
|
||||||
def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
|
def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
|
||||||
@ -266,6 +270,220 @@ async def test_run_async_with_async_before_agent_callback_bypass_agent(
|
|||||||
assert events[0].content.parts[0].text == 'agent run is bypassed.'
|
assert events[0].content.parts[0].text == 'agent run is bypassed.'
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackType(Enum):
|
||||||
|
SYNC = 1
|
||||||
|
ASYNC = 2
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_async_agent_cb_side_effect(
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
ret_value=None,
|
||||||
|
):
|
||||||
|
if ret_value:
|
||||||
|
return types.Content(parts=[types.Part(text=ret_value)])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def mock_sync_agent_cb_side_effect(
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
ret_value=None,
|
||||||
|
):
|
||||||
|
if ret_value:
|
||||||
|
return types.Content(parts=[types.Part(text=ret_value)])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
BEFORE_AGENT_CALLBACK_PARAMS = [
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
(None, CallbackType.SYNC),
|
||||||
|
('callback_2_response', CallbackType.ASYNC),
|
||||||
|
('callback_3_response', CallbackType.SYNC),
|
||||||
|
(None, CallbackType.ASYNC),
|
||||||
|
],
|
||||||
|
['callback_2_response'],
|
||||||
|
[1, 1, 0, 0],
|
||||||
|
id='middle_async_callback_returns',
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
(None, CallbackType.SYNC),
|
||||||
|
(None, CallbackType.ASYNC),
|
||||||
|
(None, CallbackType.SYNC),
|
||||||
|
(None, CallbackType.ASYNC),
|
||||||
|
],
|
||||||
|
['Hello, world!'],
|
||||||
|
[1, 1, 1, 1],
|
||||||
|
id='all_callbacks_return_none',
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
('callback_1_response', CallbackType.SYNC),
|
||||||
|
('callback_2_response', CallbackType.ASYNC),
|
||||||
|
],
|
||||||
|
['callback_1_response'],
|
||||||
|
[1, 0],
|
||||||
|
id='first_sync_callback_returns',
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
AFTER_AGENT_CALLBACK_PARAMS = [
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
(None, CallbackType.SYNC),
|
||||||
|
('callback_2_response', CallbackType.ASYNC),
|
||||||
|
('callback_3_response', CallbackType.SYNC),
|
||||||
|
(None, CallbackType.ASYNC),
|
||||||
|
],
|
||||||
|
['Hello, world!', 'callback_2_response'],
|
||||||
|
[1, 1, 0, 0],
|
||||||
|
id='middle_async_callback_returns',
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
(None, CallbackType.SYNC),
|
||||||
|
(None, CallbackType.ASYNC),
|
||||||
|
(None, CallbackType.SYNC),
|
||||||
|
(None, CallbackType.ASYNC),
|
||||||
|
],
|
||||||
|
['Hello, world!'],
|
||||||
|
[1, 1, 1, 1],
|
||||||
|
id='all_callbacks_return_none',
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
('callback_1_response', CallbackType.SYNC),
|
||||||
|
('callback_2_response', CallbackType.ASYNC),
|
||||||
|
],
|
||||||
|
['Hello, world!', 'callback_1_response'],
|
||||||
|
[1, 0],
|
||||||
|
id='first_sync_callback_returns',
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'callbacks, expected_responses, expected_calls',
|
||||||
|
BEFORE_AGENT_CALLBACK_PARAMS,
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_before_agent_callbacks_chain(
|
||||||
|
callbacks: List[tuple[str, int]],
|
||||||
|
expected_responses: List[str],
|
||||||
|
expected_calls: List[int],
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
|
):
|
||||||
|
mock_cbs = []
|
||||||
|
for response, callback_type in callbacks:
|
||||||
|
|
||||||
|
if callback_type == CallbackType.ASYNC:
|
||||||
|
mock_cb = mock.AsyncMock(
|
||||||
|
side_effect=partial(
|
||||||
|
mock_async_agent_cb_side_effect, ret_value=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mock_cb = mock.Mock(
|
||||||
|
side_effect=partial(
|
||||||
|
mock_sync_agent_cb_side_effect, ret_value=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_cbs.append(mock_cb)
|
||||||
|
|
||||||
|
agent = _TestingAgent(
|
||||||
|
name=f'{request.function.__name__}_test_agent',
|
||||||
|
before_agent_callback=[mock_cb for mock_cb in mock_cbs],
|
||||||
|
)
|
||||||
|
parent_ctx = _create_parent_invocation_context(
|
||||||
|
request.function.__name__, agent
|
||||||
|
)
|
||||||
|
result = [e async for e in agent.run_async(parent_ctx)]
|
||||||
|
assert utils.simplify_events(result) == [
|
||||||
|
(f'{request.function.__name__}_test_agent', response)
|
||||||
|
for response in expected_responses
|
||||||
|
]
|
||||||
|
|
||||||
|
# Assert that the callbacks were called the expected number of times
|
||||||
|
for i, mock_cb in enumerate(mock_cbs):
|
||||||
|
expected_calls_count = expected_calls[i]
|
||||||
|
if expected_calls_count == 1:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_awaited_once()
|
||||||
|
else:
|
||||||
|
mock_cb.assert_called_once()
|
||||||
|
elif expected_calls_count == 0:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_not_awaited()
|
||||||
|
else:
|
||||||
|
mock_cb.assert_not_called()
|
||||||
|
else:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_awaited(expected_calls_count)
|
||||||
|
else:
|
||||||
|
mock_cb.assert_called(expected_calls_count)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'callbacks, expected_responses, expected_calls',
|
||||||
|
AFTER_AGENT_CALLBACK_PARAMS,
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_after_agent_callbacks_chain(
|
||||||
|
callbacks: List[tuple[str, int]],
|
||||||
|
expected_responses: List[str],
|
||||||
|
expected_calls: List[int],
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
|
):
|
||||||
|
mock_cbs = []
|
||||||
|
for response, callback_type in callbacks:
|
||||||
|
|
||||||
|
if callback_type == CallbackType.ASYNC:
|
||||||
|
mock_cb = mock.AsyncMock(
|
||||||
|
side_effect=partial(
|
||||||
|
mock_async_agent_cb_side_effect, ret_value=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mock_cb = mock.Mock(
|
||||||
|
side_effect=partial(
|
||||||
|
mock_sync_agent_cb_side_effect, ret_value=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_cbs.append(mock_cb)
|
||||||
|
|
||||||
|
agent = _TestingAgent(
|
||||||
|
name=f'{request.function.__name__}_test_agent',
|
||||||
|
after_agent_callback=[mock_cb for mock_cb in mock_cbs],
|
||||||
|
)
|
||||||
|
parent_ctx = _create_parent_invocation_context(
|
||||||
|
request.function.__name__, agent
|
||||||
|
)
|
||||||
|
result = [e async for e in agent.run_async(parent_ctx)]
|
||||||
|
assert utils.simplify_events(result) == [
|
||||||
|
(f'{request.function.__name__}_test_agent', response)
|
||||||
|
for response in expected_responses
|
||||||
|
]
|
||||||
|
|
||||||
|
# Assert that the callbacks were called the expected number of times
|
||||||
|
for i, mock_cb in enumerate(mock_cbs):
|
||||||
|
expected_calls_count = expected_calls[i]
|
||||||
|
if expected_calls_count == 1:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_awaited_once()
|
||||||
|
else:
|
||||||
|
mock_cb.assert_called_once()
|
||||||
|
elif expected_calls_count == 0:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_not_awaited()
|
||||||
|
else:
|
||||||
|
mock_cb.assert_not_called()
|
||||||
|
else:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_awaited(expected_calls_count)
|
||||||
|
else:
|
||||||
|
mock_cb.assert_called(expected_calls_count)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_async_after_agent_callback_noop(
|
async def test_run_async_after_agent_callback_noop(
|
||||||
request: pytest.FixtureRequest,
|
request: pytest.FixtureRequest,
|
||||||
|
@ -1,210 +0,0 @@
|
|||||||
# Copyright 2025 Google LLC
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from google.adk.agents.callback_context import CallbackContext
|
|
||||||
from google.adk.agents.llm_agent import Agent
|
|
||||||
from google.adk.models import LlmRequest
|
|
||||||
from google.adk.models import LlmResponse
|
|
||||||
from google.genai import types
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from .. import utils
|
|
||||||
|
|
||||||
|
|
||||||
class MockBeforeModelCallback(BaseModel):
|
|
||||||
mock_response: str
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
callback_context: CallbackContext,
|
|
||||||
llm_request: LlmRequest,
|
|
||||||
) -> LlmResponse:
|
|
||||||
return LlmResponse(
|
|
||||||
content=utils.ModelContent(
|
|
||||||
[types.Part.from_text(text=self.mock_response)]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MockAfterModelCallback(BaseModel):
|
|
||||||
mock_response: str
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
callback_context: CallbackContext,
|
|
||||||
llm_response: LlmResponse,
|
|
||||||
) -> LlmResponse:
|
|
||||||
return LlmResponse(
|
|
||||||
content=utils.ModelContent(
|
|
||||||
[types.Part.from_text(text=self.mock_response)]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MockAsyncBeforeModelCallback(BaseModel):
|
|
||||||
mock_response: str
|
|
||||||
|
|
||||||
async def __call__(
|
|
||||||
self,
|
|
||||||
callback_context: CallbackContext,
|
|
||||||
llm_request: LlmRequest,
|
|
||||||
) -> LlmResponse:
|
|
||||||
return LlmResponse(
|
|
||||||
content=utils.ModelContent(
|
|
||||||
[types.Part.from_text(text=self.mock_response)]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MockAsyncAfterModelCallback(BaseModel):
|
|
||||||
mock_response: str
|
|
||||||
|
|
||||||
async def __call__(
|
|
||||||
self,
|
|
||||||
callback_context: CallbackContext,
|
|
||||||
llm_response: LlmResponse,
|
|
||||||
) -> LlmResponse:
|
|
||||||
return LlmResponse(
|
|
||||||
content=utils.ModelContent(
|
|
||||||
[types.Part.from_text(text=self.mock_response)]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def noop_callback(**kwargs) -> Optional[LlmResponse]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_before_model_callback():
|
|
||||||
responses = ['model_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
before_model_callback=MockBeforeModelCallback(
|
|
||||||
mock_response='before_model_callback'
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'before_model_callback'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_before_model_callback_noop():
|
|
||||||
responses = ['model_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
before_model_callback=noop_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'model_response'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_after_model_callback():
|
|
||||||
responses = ['model_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
after_model_callback=MockAfterModelCallback(
|
|
||||||
mock_response='after_model_callback'
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'after_model_callback'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_before_model_callback():
|
|
||||||
responses = ['model_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
before_model_callback=MockAsyncBeforeModelCallback(
|
|
||||||
mock_response='async_before_model_callback'
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'async_before_model_callback'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_before_model_callback_noop():
|
|
||||||
responses = ['model_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
before_model_callback=async_noop_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'model_response'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_after_model_callback():
|
|
||||||
responses = ['model_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
after_model_callback=MockAsyncAfterModelCallback(
|
|
||||||
mock_response='async_after_model_callback'
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'async_after_model_callback'),
|
|
||||||
]
|
|
Loading…
Reference in New Issue
Block a user