mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
Support chaining for tool callbacks
(before/after) tool callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 756526507
This commit is contained in:
parent
0299020cc4
commit
2cbbf88135
@ -117,6 +117,31 @@ def before_agent_cb3(callback_context):
|
|||||||
print('@before_agent_cb3')
|
print('@before_agent_cb3')
|
||||||
|
|
||||||
|
|
||||||
|
def before_tool_cb1(tool, args, tool_context):
|
||||||
|
print('@before_tool_cb1')
|
||||||
|
|
||||||
|
|
||||||
|
def before_tool_cb2(tool, args, tool_context):
|
||||||
|
print('@before_tool_cb2')
|
||||||
|
|
||||||
|
|
||||||
|
def before_tool_cb3(tool, args, tool_context):
|
||||||
|
print('@before_tool_cb3')
|
||||||
|
|
||||||
|
|
||||||
|
def after_tool_cb1(tool, args, tool_context, tool_response):
|
||||||
|
print('@after_tool_cb1')
|
||||||
|
|
||||||
|
|
||||||
|
def after_tool_cb2(tool, args, tool_context, tool_response):
|
||||||
|
print('@after_tool_cb2')
|
||||||
|
return {'test': 'after_tool_cb2', 'response': tool_response}
|
||||||
|
|
||||||
|
|
||||||
|
def after_tool_cb3(tool, args, tool_context, tool_response):
|
||||||
|
print('@after_tool_cb3')
|
||||||
|
|
||||||
|
|
||||||
root_agent = Agent(
|
root_agent = Agent(
|
||||||
model='gemini-2.0-flash-exp',
|
model='gemini-2.0-flash-exp',
|
||||||
name='data_processing_agent',
|
name='data_processing_agent',
|
||||||
@ -166,4 +191,6 @@ root_agent = Agent(
|
|||||||
after_agent_callback=[after_agent_cb1, after_agent_cb2, after_agent_cb3],
|
after_agent_callback=[after_agent_cb1, after_agent_cb2, after_agent_cb3],
|
||||||
before_model_callback=before_model_callback,
|
before_model_callback=before_model_callback,
|
||||||
after_model_callback=after_model_callback,
|
after_model_callback=after_model_callback,
|
||||||
|
before_tool_callback=[before_tool_cb1, before_tool_cb2, before_tool_cb3],
|
||||||
|
after_tool_callback=[after_tool_cb1, after_tool_cb2, after_tool_cb3],
|
||||||
)
|
)
|
||||||
|
@ -83,7 +83,7 @@ async def main():
|
|||||||
print('------------------------------------')
|
print('------------------------------------')
|
||||||
await run_prompt(session_11, 'Hi')
|
await run_prompt(session_11, 'Hi')
|
||||||
await run_prompt(session_11, 'Roll a die with 100 sides')
|
await run_prompt(session_11, 'Roll a die with 100 sides')
|
||||||
await run_prompt(session_11, 'Roll a die again.')
|
await run_prompt(session_11, 'Roll a die again with 100 sides.')
|
||||||
await run_prompt(session_11, 'What numbers did I got?')
|
await run_prompt(session_11, 'What numbers did I got?')
|
||||||
await run_prompt_bytes(session_11, 'Hi bytes')
|
await run_prompt_bytes(session_11, 'Hi bytes')
|
||||||
print(
|
print(
|
||||||
@ -130,7 +130,7 @@ def main_sync():
|
|||||||
print('------------------------------------')
|
print('------------------------------------')
|
||||||
run_prompt(session_11, 'Hi')
|
run_prompt(session_11, 'Hi')
|
||||||
run_prompt(session_11, 'Roll a die with 100 sides.')
|
run_prompt(session_11, 'Roll a die with 100 sides.')
|
||||||
run_prompt(session_11, 'Roll a die again.')
|
run_prompt(session_11, 'Roll a die again with 100 sides.')
|
||||||
run_prompt(session_11, 'What numbers did I got?')
|
run_prompt(session_11, 'What numbers did I got?')
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print('------------------------------------')
|
print('------------------------------------')
|
||||||
|
@ -67,15 +67,26 @@ AfterModelCallback: TypeAlias = Union[
|
|||||||
list[_SingleAfterModelCallback],
|
list[_SingleAfterModelCallback],
|
||||||
]
|
]
|
||||||
|
|
||||||
BeforeToolCallback: TypeAlias = Callable[
|
_SingleBeforeToolCallback: TypeAlias = Callable[
|
||||||
[BaseTool, dict[str, Any], ToolContext],
|
[BaseTool, dict[str, Any], ToolContext],
|
||||||
Union[Awaitable[Optional[dict]], Optional[dict]],
|
Union[Awaitable[Optional[dict]], Optional[dict]],
|
||||||
]
|
]
|
||||||
AfterToolCallback: TypeAlias = Callable[
|
|
||||||
|
BeforeToolCallback: TypeAlias = Union[
|
||||||
|
_SingleBeforeToolCallback,
|
||||||
|
list[_SingleBeforeToolCallback],
|
||||||
|
]
|
||||||
|
|
||||||
|
_SingleAfterToolCallback: TypeAlias = Callable[
|
||||||
[BaseTool, dict[str, Any], ToolContext, dict],
|
[BaseTool, dict[str, Any], ToolContext, dict],
|
||||||
Union[Awaitable[Optional[dict]], Optional[dict]],
|
Union[Awaitable[Optional[dict]], Optional[dict]],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
AfterToolCallback: TypeAlias = Union[
|
||||||
|
_SingleAfterToolCallback,
|
||||||
|
list[_SingleAfterToolCallback],
|
||||||
|
]
|
||||||
|
|
||||||
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
|
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
|
||||||
|
|
||||||
ToolUnion: TypeAlias = Union[Callable, BaseTool]
|
ToolUnion: TypeAlias = Union[Callable, BaseTool]
|
||||||
@ -214,7 +225,10 @@ class LlmAgent(BaseAgent):
|
|||||||
will be ignored and the provided content will be returned to user.
|
will be ignored and the provided content will be returned to user.
|
||||||
"""
|
"""
|
||||||
before_tool_callback: Optional[BeforeToolCallback] = None
|
before_tool_callback: Optional[BeforeToolCallback] = None
|
||||||
"""Called before the tool is called.
|
"""Callback or list of callbacks to be called before calling the tool.
|
||||||
|
|
||||||
|
When a list of callbacks is provided, the callbacks will be called in the
|
||||||
|
order they are listed until a callback does not return None.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool: The tool to be called.
|
tool: The tool to be called.
|
||||||
@ -226,7 +240,10 @@ class LlmAgent(BaseAgent):
|
|||||||
the framework will skip calling the actual tool.
|
the framework will skip calling the actual tool.
|
||||||
"""
|
"""
|
||||||
after_tool_callback: Optional[AfterToolCallback] = None
|
after_tool_callback: Optional[AfterToolCallback] = None
|
||||||
"""Called after the tool is called.
|
"""Callback or list of callbacks to be called after calling the tool.
|
||||||
|
|
||||||
|
When a list of callbacks is provided, the callbacks will be called in the
|
||||||
|
order they are listed until a callback does not return None.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool: The tool to be called.
|
tool: The tool to be called.
|
||||||
@ -329,6 +346,34 @@ class LlmAgent(BaseAgent):
|
|||||||
return self.after_model_callback
|
return self.after_model_callback
|
||||||
return [self.after_model_callback]
|
return [self.after_model_callback]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def canonical_before_tool_callbacks(
|
||||||
|
self,
|
||||||
|
) -> list[BeforeToolCallback]:
|
||||||
|
"""The resolved self.before_tool_callback field as a list of BeforeToolCallback.
|
||||||
|
|
||||||
|
This method is only for use by Agent Development Kit.
|
||||||
|
"""
|
||||||
|
if not self.before_tool_callback:
|
||||||
|
return []
|
||||||
|
if isinstance(self.before_tool_callback, list):
|
||||||
|
return self.before_tool_callback
|
||||||
|
return [self.before_tool_callback]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def canonical_after_tool_callbacks(
|
||||||
|
self,
|
||||||
|
) -> list[AfterToolCallback]:
|
||||||
|
"""The resolved self.after_tool_callback field as a list of AfterToolCallback.
|
||||||
|
|
||||||
|
This method is only for use by Agent Development Kit.
|
||||||
|
"""
|
||||||
|
if not self.after_tool_callback:
|
||||||
|
return []
|
||||||
|
if isinstance(self.after_tool_callback, list):
|
||||||
|
return self.after_tool_callback
|
||||||
|
return [self.after_tool_callback]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_flow(self) -> BaseLlmFlow:
|
def _llm_flow(self) -> BaseLlmFlow:
|
||||||
if (
|
if (
|
||||||
|
@ -153,22 +153,22 @@ async def handle_function_calls_async(
|
|||||||
function_args = function_call.args or {}
|
function_args = function_call.args or {}
|
||||||
function_response: Optional[dict] = None
|
function_response: Optional[dict] = None
|
||||||
|
|
||||||
# before_tool_callback (sync or async)
|
for callback in agent.canonical_before_tool_callbacks:
|
||||||
if agent.before_tool_callback:
|
function_response = callback(
|
||||||
function_response = agent.before_tool_callback(
|
|
||||||
tool=tool, args=function_args, tool_context=tool_context
|
tool=tool, args=function_args, tool_context=tool_context
|
||||||
)
|
)
|
||||||
if inspect.isawaitable(function_response):
|
if inspect.isawaitable(function_response):
|
||||||
function_response = await function_response
|
function_response = await function_response
|
||||||
|
if function_response:
|
||||||
|
break
|
||||||
|
|
||||||
if not function_response:
|
if not function_response:
|
||||||
function_response = await __call_tool_async(
|
function_response = await __call_tool_async(
|
||||||
tool, args=function_args, tool_context=tool_context
|
tool, args=function_args, tool_context=tool_context
|
||||||
)
|
)
|
||||||
|
|
||||||
# after_tool_callback (sync or async)
|
for callback in agent.canonical_after_tool_callbacks:
|
||||||
if agent.after_tool_callback:
|
altered_function_response = callback(
|
||||||
altered_function_response = agent.after_tool_callback(
|
|
||||||
tool=tool,
|
tool=tool,
|
||||||
args=function_args,
|
args=function_args,
|
||||||
tool_context=tool_context,
|
tool_context=tool_context,
|
||||||
@ -178,6 +178,7 @@ async def handle_function_calls_async(
|
|||||||
altered_function_response = await altered_function_response
|
altered_function_response = await altered_function_response
|
||||||
if altered_function_response is not None:
|
if altered_function_response is not None:
|
||||||
function_response = altered_function_response
|
function_response = altered_function_response
|
||||||
|
break
|
||||||
|
|
||||||
if tool.is_long_running:
|
if tool.is_long_running:
|
||||||
# Allow long running function to return None to not provide function response.
|
# Allow long running function to return None to not provide function response.
|
||||||
|
@ -12,20 +12,31 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from enum import Enum
|
||||||
|
from functools import partial
|
||||||
import pytest
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from google.adk.agents import Agent
|
from google.adk.agents import Agent
|
||||||
|
from google.adk.agents.callback_context import CallbackContext
|
||||||
|
from google.adk.events.event import Event
|
||||||
|
from google.adk.flows.llm_flows.functions import handle_function_calls_async
|
||||||
from google.adk.tools.function_tool import FunctionTool
|
from google.adk.tools.function_tool import FunctionTool
|
||||||
from google.adk.tools.tool_context import ToolContext
|
from google.adk.tools.tool_context import ToolContext
|
||||||
from google.adk.flows.llm_flows.functions import handle_function_calls_async
|
|
||||||
from google.adk.events.event import Event
|
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
import pytest
|
||||||
|
|
||||||
from ... import utils
|
from ... import utils
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackType(Enum):
|
||||||
|
SYNC = 1
|
||||||
|
ASYNC = 2
|
||||||
|
|
||||||
|
|
||||||
class AsyncBeforeToolCallback:
|
class AsyncBeforeToolCallback:
|
||||||
|
|
||||||
def __init__(self, mock_response: Dict[str, Any]):
|
def __init__(self, mock_response: Dict[str, Any]):
|
||||||
@ -107,3 +118,184 @@ async def test_async_after_tool_callback():
|
|||||||
assert result_event is not None
|
assert result_event is not None
|
||||||
part = result_event.content.parts[0]
|
part = result_event.content.parts[0]
|
||||||
assert part.function_response.response == mock_resp
|
assert part.function_response.response == mock_resp
|
||||||
|
|
||||||
|
|
||||||
|
def mock_async_before_cb_side_effect(
|
||||||
|
tool: FunctionTool,
|
||||||
|
args: Dict[str, Any],
|
||||||
|
tool_context: ToolContext,
|
||||||
|
ret_value: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
if ret_value:
|
||||||
|
return ret_value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def mock_sync_before_cb_side_effect(
|
||||||
|
tool: FunctionTool,
|
||||||
|
args: Dict[str, Any],
|
||||||
|
tool_context: ToolContext,
|
||||||
|
ret_value: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
if ret_value:
|
||||||
|
return ret_value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_async_after_cb_side_effect(
|
||||||
|
tool: FunctionTool,
|
||||||
|
args: Dict[str, Any],
|
||||||
|
tool_context: ToolContext,
|
||||||
|
tool_response: Dict[str, Any],
|
||||||
|
ret_value: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
if ret_value:
|
||||||
|
return ret_value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def mock_sync_after_cb_side_effect(
|
||||||
|
tool: FunctionTool,
|
||||||
|
args: Dict[str, Any],
|
||||||
|
tool_context: ToolContext,
|
||||||
|
tool_response: Dict[str, Any],
|
||||||
|
ret_value: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
if ret_value:
|
||||||
|
return ret_value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
CALLBACK_PARAMS = [
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
(None, CallbackType.SYNC),
|
||||||
|
({"test": "callback_2_response"}, CallbackType.ASYNC),
|
||||||
|
({"test": "callback_3_response"}, CallbackType.SYNC),
|
||||||
|
(None, CallbackType.ASYNC),
|
||||||
|
],
|
||||||
|
{"test": "callback_2_response"},
|
||||||
|
[1, 1, 0, 0],
|
||||||
|
id="middle_async_callback_returns",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
(None, CallbackType.SYNC),
|
||||||
|
(None, CallbackType.ASYNC),
|
||||||
|
(None, CallbackType.SYNC),
|
||||||
|
(None, CallbackType.ASYNC),
|
||||||
|
],
|
||||||
|
{"initial": "response"},
|
||||||
|
[1, 1, 1, 1],
|
||||||
|
id="all_callbacks_return_none",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
({"test": "callback_1_response"}, CallbackType.SYNC),
|
||||||
|
({"test": "callback_2_response"}, CallbackType.ASYNC),
|
||||||
|
],
|
||||||
|
{"test": "callback_1_response"},
|
||||||
|
[1, 0],
|
||||||
|
id="first_sync_callback_returns",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"callbacks, expected_response, expected_calls",
|
||||||
|
CALLBACK_PARAMS,
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_before_tool_callbacks_chain(
|
||||||
|
callbacks: List[tuple[Optional[Dict[str, Any]], int]],
|
||||||
|
expected_response: Dict[str, Any],
|
||||||
|
expected_calls: List[int],
|
||||||
|
):
|
||||||
|
mock_before_cbs = []
|
||||||
|
for response, callback_type in callbacks:
|
||||||
|
if callback_type == CallbackType.ASYNC:
|
||||||
|
mock_cb = mock.AsyncMock(
|
||||||
|
side_effect=partial(
|
||||||
|
mock_async_before_cb_side_effect, ret_value=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mock_cb = mock.Mock(
|
||||||
|
side_effect=partial(
|
||||||
|
mock_sync_before_cb_side_effect, ret_value=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_before_cbs.append(mock_cb)
|
||||||
|
result_event = await invoke_tool_with_callbacks(before_cb=mock_before_cbs)
|
||||||
|
assert result_event is not None
|
||||||
|
part = result_event.content.parts[0]
|
||||||
|
assert part.function_response.response == expected_response
|
||||||
|
|
||||||
|
# Assert that the callbacks were called the expected number of times
|
||||||
|
for i, mock_cb in enumerate(mock_before_cbs):
|
||||||
|
expected_calls_count = expected_calls[i]
|
||||||
|
if expected_calls_count == 1:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_awaited_once()
|
||||||
|
else:
|
||||||
|
mock_cb.assert_called_once()
|
||||||
|
elif expected_calls_count == 0:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_not_awaited()
|
||||||
|
else:
|
||||||
|
mock_cb.assert_not_called()
|
||||||
|
else:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_awaited(expected_calls_count)
|
||||||
|
else:
|
||||||
|
mock_cb.assert_called(expected_calls_count)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"callbacks, expected_response, expected_calls",
|
||||||
|
CALLBACK_PARAMS,
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_after_tool_callbacks_chain(
|
||||||
|
callbacks: List[tuple[Optional[Dict[str, Any]], int]],
|
||||||
|
expected_response: Dict[str, Any],
|
||||||
|
expected_calls: List[int],
|
||||||
|
):
|
||||||
|
mock_after_cbs = []
|
||||||
|
for response, callback_type in callbacks:
|
||||||
|
if callback_type == CallbackType.ASYNC:
|
||||||
|
mock_cb = mock.AsyncMock(
|
||||||
|
side_effect=partial(
|
||||||
|
mock_async_after_cb_side_effect, ret_value=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mock_cb = mock.Mock(
|
||||||
|
side_effect=partial(
|
||||||
|
mock_sync_after_cb_side_effect, ret_value=response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_after_cbs.append(mock_cb)
|
||||||
|
result_event = await invoke_tool_with_callbacks(after_cb=mock_after_cbs)
|
||||||
|
assert result_event is not None
|
||||||
|
part = result_event.content.parts[0]
|
||||||
|
assert part.function_response.response == expected_response
|
||||||
|
|
||||||
|
# Assert that the callbacks were called the expected number of times
|
||||||
|
for i, mock_cb in enumerate(mock_after_cbs):
|
||||||
|
expected_calls_count = expected_calls[i]
|
||||||
|
if expected_calls_count == 1:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_awaited_once()
|
||||||
|
else:
|
||||||
|
mock_cb.assert_called_once()
|
||||||
|
elif expected_calls_count == 0:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_not_awaited()
|
||||||
|
else:
|
||||||
|
mock_cb.assert_not_called()
|
||||||
|
else:
|
||||||
|
if isinstance(mock_cb, mock.AsyncMock):
|
||||||
|
mock_cb.assert_awaited(expected_calls_count)
|
||||||
|
else:
|
||||||
|
mock_cb.assert_called(expected_calls_count)
|
||||||
|
Loading…
Reference in New Issue
Block a user