mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 15:14:50 -06:00
Support chaining for tool callbacks
(before/after) tool callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 756526507
This commit is contained in:
parent
0299020cc4
commit
2cbbf88135
@ -117,6 +117,31 @@ def before_agent_cb3(callback_context):
|
||||
print('@before_agent_cb3')
|
||||
|
||||
|
||||
def before_tool_cb1(tool, args, tool_context):
|
||||
print('@before_tool_cb1')
|
||||
|
||||
|
||||
def before_tool_cb2(tool, args, tool_context):
|
||||
print('@before_tool_cb2')
|
||||
|
||||
|
||||
def before_tool_cb3(tool, args, tool_context):
|
||||
print('@before_tool_cb3')
|
||||
|
||||
|
||||
def after_tool_cb1(tool, args, tool_context, tool_response):
|
||||
print('@after_tool_cb1')
|
||||
|
||||
|
||||
def after_tool_cb2(tool, args, tool_context, tool_response):
|
||||
print('@after_tool_cb2')
|
||||
return {'test': 'after_tool_cb2', 'response': tool_response}
|
||||
|
||||
|
||||
def after_tool_cb3(tool, args, tool_context, tool_response):
|
||||
print('@after_tool_cb3')
|
||||
|
||||
|
||||
root_agent = Agent(
|
||||
model='gemini-2.0-flash-exp',
|
||||
name='data_processing_agent',
|
||||
@ -166,4 +191,6 @@ root_agent = Agent(
|
||||
after_agent_callback=[after_agent_cb1, after_agent_cb2, after_agent_cb3],
|
||||
before_model_callback=before_model_callback,
|
||||
after_model_callback=after_model_callback,
|
||||
before_tool_callback=[before_tool_cb1, before_tool_cb2, before_tool_cb3],
|
||||
after_tool_callback=[after_tool_cb1, after_tool_cb2, after_tool_cb3],
|
||||
)
|
||||
|
@ -83,7 +83,7 @@ async def main():
|
||||
print('------------------------------------')
|
||||
await run_prompt(session_11, 'Hi')
|
||||
await run_prompt(session_11, 'Roll a die with 100 sides')
|
||||
await run_prompt(session_11, 'Roll a die again.')
|
||||
await run_prompt(session_11, 'Roll a die again with 100 sides.')
|
||||
await run_prompt(session_11, 'What numbers did I got?')
|
||||
await run_prompt_bytes(session_11, 'Hi bytes')
|
||||
print(
|
||||
@ -130,7 +130,7 @@ def main_sync():
|
||||
print('------------------------------------')
|
||||
run_prompt(session_11, 'Hi')
|
||||
run_prompt(session_11, 'Roll a die with 100 sides.')
|
||||
run_prompt(session_11, 'Roll a die again.')
|
||||
run_prompt(session_11, 'Roll a die again with 100 sides.')
|
||||
run_prompt(session_11, 'What numbers did I got?')
|
||||
end_time = time.time()
|
||||
print('------------------------------------')
|
||||
|
@ -67,15 +67,26 @@ AfterModelCallback: TypeAlias = Union[
|
||||
list[_SingleAfterModelCallback],
|
||||
]
|
||||
|
||||
BeforeToolCallback: TypeAlias = Callable[
|
||||
_SingleBeforeToolCallback: TypeAlias = Callable[
|
||||
[BaseTool, dict[str, Any], ToolContext],
|
||||
Union[Awaitable[Optional[dict]], Optional[dict]],
|
||||
]
|
||||
AfterToolCallback: TypeAlias = Callable[
|
||||
|
||||
BeforeToolCallback: TypeAlias = Union[
|
||||
_SingleBeforeToolCallback,
|
||||
list[_SingleBeforeToolCallback],
|
||||
]
|
||||
|
||||
_SingleAfterToolCallback: TypeAlias = Callable[
|
||||
[BaseTool, dict[str, Any], ToolContext, dict],
|
||||
Union[Awaitable[Optional[dict]], Optional[dict]],
|
||||
]
|
||||
|
||||
AfterToolCallback: TypeAlias = Union[
|
||||
_SingleAfterToolCallback,
|
||||
list[_SingleAfterToolCallback],
|
||||
]
|
||||
|
||||
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
|
||||
|
||||
ToolUnion: TypeAlias = Union[Callable, BaseTool]
|
||||
@ -214,7 +225,10 @@ class LlmAgent(BaseAgent):
|
||||
will be ignored and the provided content will be returned to user.
|
||||
"""
|
||||
before_tool_callback: Optional[BeforeToolCallback] = None
|
||||
"""Called before the tool is called.
|
||||
"""Callback or list of callbacks to be called before calling the tool.
|
||||
|
||||
When a list of callbacks is provided, the callbacks will be called in the
|
||||
order they are listed until a callback does not return None.
|
||||
|
||||
Args:
|
||||
tool: The tool to be called.
|
||||
@ -226,7 +240,10 @@ class LlmAgent(BaseAgent):
|
||||
the framework will skip calling the actual tool.
|
||||
"""
|
||||
after_tool_callback: Optional[AfterToolCallback] = None
|
||||
"""Called after the tool is called.
|
||||
"""Callback or list of callbacks to be called after calling the tool.
|
||||
|
||||
When a list of callbacks is provided, the callbacks will be called in the
|
||||
order they are listed until a callback does not return None.
|
||||
|
||||
Args:
|
||||
tool: The tool to be called.
|
||||
@ -329,6 +346,34 @@ class LlmAgent(BaseAgent):
|
||||
return self.after_model_callback
|
||||
return [self.after_model_callback]
|
||||
|
||||
@property
|
||||
def canonical_before_tool_callbacks(
|
||||
self,
|
||||
) -> list[BeforeToolCallback]:
|
||||
"""The resolved self.before_tool_callback field as a list of BeforeToolCallback.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if not self.before_tool_callback:
|
||||
return []
|
||||
if isinstance(self.before_tool_callback, list):
|
||||
return self.before_tool_callback
|
||||
return [self.before_tool_callback]
|
||||
|
||||
@property
|
||||
def canonical_after_tool_callbacks(
|
||||
self,
|
||||
) -> list[AfterToolCallback]:
|
||||
"""The resolved self.after_tool_callback field as a list of AfterToolCallback.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if not self.after_tool_callback:
|
||||
return []
|
||||
if isinstance(self.after_tool_callback, list):
|
||||
return self.after_tool_callback
|
||||
return [self.after_tool_callback]
|
||||
|
||||
@property
|
||||
def _llm_flow(self) -> BaseLlmFlow:
|
||||
if (
|
||||
|
@ -153,22 +153,22 @@ async def handle_function_calls_async(
|
||||
function_args = function_call.args or {}
|
||||
function_response: Optional[dict] = None
|
||||
|
||||
# before_tool_callback (sync or async)
|
||||
if agent.before_tool_callback:
|
||||
function_response = agent.before_tool_callback(
|
||||
for callback in agent.canonical_before_tool_callbacks:
|
||||
function_response = callback(
|
||||
tool=tool, args=function_args, tool_context=tool_context
|
||||
)
|
||||
if inspect.isawaitable(function_response):
|
||||
function_response = await function_response
|
||||
if function_response:
|
||||
break
|
||||
|
||||
if not function_response:
|
||||
function_response = await __call_tool_async(
|
||||
tool, args=function_args, tool_context=tool_context
|
||||
)
|
||||
|
||||
# after_tool_callback (sync or async)
|
||||
if agent.after_tool_callback:
|
||||
altered_function_response = agent.after_tool_callback(
|
||||
for callback in agent.canonical_after_tool_callbacks:
|
||||
altered_function_response = callback(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
tool_context=tool_context,
|
||||
@ -178,6 +178,7 @@ async def handle_function_calls_async(
|
||||
altered_function_response = await altered_function_response
|
||||
if altered_function_response is not None:
|
||||
function_response = altered_function_response
|
||||
break
|
||||
|
||||
if tool.is_long_running:
|
||||
# Allow long running function to return None to not provide function response.
|
||||
|
@ -12,20 +12,31 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytest
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_async
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_async
|
||||
from google.adk.events.event import Event
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
class CallbackType(Enum):
|
||||
SYNC = 1
|
||||
ASYNC = 2
|
||||
|
||||
|
||||
class AsyncBeforeToolCallback:
|
||||
|
||||
def __init__(self, mock_response: Dict[str, Any]):
|
||||
@ -107,3 +118,184 @@ async def test_async_after_tool_callback():
|
||||
assert result_event is not None
|
||||
part = result_event.content.parts[0]
|
||||
assert part.function_response.response == mock_resp
|
||||
|
||||
|
||||
def mock_async_before_cb_side_effect(
|
||||
tool: FunctionTool,
|
||||
args: Dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
ret_value: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if ret_value:
|
||||
return ret_value
|
||||
return None
|
||||
|
||||
|
||||
def mock_sync_before_cb_side_effect(
|
||||
tool: FunctionTool,
|
||||
args: Dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
ret_value: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if ret_value:
|
||||
return ret_value
|
||||
return None
|
||||
|
||||
|
||||
async def mock_async_after_cb_side_effect(
|
||||
tool: FunctionTool,
|
||||
args: Dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
tool_response: Dict[str, Any],
|
||||
ret_value: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if ret_value:
|
||||
return ret_value
|
||||
return None
|
||||
|
||||
|
||||
def mock_sync_after_cb_side_effect(
|
||||
tool: FunctionTool,
|
||||
args: Dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
tool_response: Dict[str, Any],
|
||||
ret_value: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if ret_value:
|
||||
return ret_value
|
||||
return None
|
||||
|
||||
|
||||
CALLBACK_PARAMS = [
|
||||
pytest.param(
|
||||
[
|
||||
(None, CallbackType.SYNC),
|
||||
({"test": "callback_2_response"}, CallbackType.ASYNC),
|
||||
({"test": "callback_3_response"}, CallbackType.SYNC),
|
||||
(None, CallbackType.ASYNC),
|
||||
],
|
||||
{"test": "callback_2_response"},
|
||||
[1, 1, 0, 0],
|
||||
id="middle_async_callback_returns",
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
(None, CallbackType.SYNC),
|
||||
(None, CallbackType.ASYNC),
|
||||
(None, CallbackType.SYNC),
|
||||
(None, CallbackType.ASYNC),
|
||||
],
|
||||
{"initial": "response"},
|
||||
[1, 1, 1, 1],
|
||||
id="all_callbacks_return_none",
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
({"test": "callback_1_response"}, CallbackType.SYNC),
|
||||
({"test": "callback_2_response"}, CallbackType.ASYNC),
|
||||
],
|
||||
{"test": "callback_1_response"},
|
||||
[1, 0],
|
||||
id="first_sync_callback_returns",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"callbacks, expected_response, expected_calls",
|
||||
CALLBACK_PARAMS,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_tool_callbacks_chain(
|
||||
callbacks: List[tuple[Optional[Dict[str, Any]], int]],
|
||||
expected_response: Dict[str, Any],
|
||||
expected_calls: List[int],
|
||||
):
|
||||
mock_before_cbs = []
|
||||
for response, callback_type in callbacks:
|
||||
if callback_type == CallbackType.ASYNC:
|
||||
mock_cb = mock.AsyncMock(
|
||||
side_effect=partial(
|
||||
mock_async_before_cb_side_effect, ret_value=response
|
||||
)
|
||||
)
|
||||
else:
|
||||
mock_cb = mock.Mock(
|
||||
side_effect=partial(
|
||||
mock_sync_before_cb_side_effect, ret_value=response
|
||||
)
|
||||
)
|
||||
mock_before_cbs.append(mock_cb)
|
||||
result_event = await invoke_tool_with_callbacks(before_cb=mock_before_cbs)
|
||||
assert result_event is not None
|
||||
part = result_event.content.parts[0]
|
||||
assert part.function_response.response == expected_response
|
||||
|
||||
# Assert that the callbacks were called the expected number of times
|
||||
for i, mock_cb in enumerate(mock_before_cbs):
|
||||
expected_calls_count = expected_calls[i]
|
||||
if expected_calls_count == 1:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_awaited_once()
|
||||
else:
|
||||
mock_cb.assert_called_once()
|
||||
elif expected_calls_count == 0:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_not_awaited()
|
||||
else:
|
||||
mock_cb.assert_not_called()
|
||||
else:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_awaited(expected_calls_count)
|
||||
else:
|
||||
mock_cb.assert_called(expected_calls_count)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"callbacks, expected_response, expected_calls",
|
||||
CALLBACK_PARAMS,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_tool_callbacks_chain(
|
||||
callbacks: List[tuple[Optional[Dict[str, Any]], int]],
|
||||
expected_response: Dict[str, Any],
|
||||
expected_calls: List[int],
|
||||
):
|
||||
mock_after_cbs = []
|
||||
for response, callback_type in callbacks:
|
||||
if callback_type == CallbackType.ASYNC:
|
||||
mock_cb = mock.AsyncMock(
|
||||
side_effect=partial(
|
||||
mock_async_after_cb_side_effect, ret_value=response
|
||||
)
|
||||
)
|
||||
else:
|
||||
mock_cb = mock.Mock(
|
||||
side_effect=partial(
|
||||
mock_sync_after_cb_side_effect, ret_value=response
|
||||
)
|
||||
)
|
||||
mock_after_cbs.append(mock_cb)
|
||||
result_event = await invoke_tool_with_callbacks(after_cb=mock_after_cbs)
|
||||
assert result_event is not None
|
||||
part = result_event.content.parts[0]
|
||||
assert part.function_response.response == expected_response
|
||||
|
||||
# Assert that the callbacks were called the expected number of times
|
||||
for i, mock_cb in enumerate(mock_after_cbs):
|
||||
expected_calls_count = expected_calls[i]
|
||||
if expected_calls_count == 1:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_awaited_once()
|
||||
else:
|
||||
mock_cb.assert_called_once()
|
||||
elif expected_calls_count == 0:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_not_awaited()
|
||||
else:
|
||||
mock_cb.assert_not_called()
|
||||
else:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_awaited(expected_calls_count)
|
||||
else:
|
||||
mock_cb.assert_called(expected_calls_count)
|
||||
|
Loading…
Reference in New Issue
Block a user