mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 11:22:22 -06:00
Support chaining for tool callbacks
(before/after) tool callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync. PiperOrigin-RevId: 756526507
This commit is contained in:
committed by
Copybara-Service
parent
0299020cc4
commit
2cbbf88135
@@ -12,20 +12,31 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytest
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_async
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_async
|
||||
from google.adk.events.event import Event
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
class CallbackType(Enum):
|
||||
SYNC = 1
|
||||
ASYNC = 2
|
||||
|
||||
|
||||
class AsyncBeforeToolCallback:
|
||||
|
||||
def __init__(self, mock_response: Dict[str, Any]):
|
||||
@@ -107,3 +118,184 @@ async def test_async_after_tool_callback():
|
||||
assert result_event is not None
|
||||
part = result_event.content.parts[0]
|
||||
assert part.function_response.response == mock_resp
|
||||
|
||||
|
||||
def mock_async_before_cb_side_effect(
|
||||
tool: FunctionTool,
|
||||
args: Dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
ret_value: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if ret_value:
|
||||
return ret_value
|
||||
return None
|
||||
|
||||
|
||||
def mock_sync_before_cb_side_effect(
|
||||
tool: FunctionTool,
|
||||
args: Dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
ret_value: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if ret_value:
|
||||
return ret_value
|
||||
return None
|
||||
|
||||
|
||||
async def mock_async_after_cb_side_effect(
|
||||
tool: FunctionTool,
|
||||
args: Dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
tool_response: Dict[str, Any],
|
||||
ret_value: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if ret_value:
|
||||
return ret_value
|
||||
return None
|
||||
|
||||
|
||||
def mock_sync_after_cb_side_effect(
|
||||
tool: FunctionTool,
|
||||
args: Dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
tool_response: Dict[str, Any],
|
||||
ret_value: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if ret_value:
|
||||
return ret_value
|
||||
return None
|
||||
|
||||
|
||||
CALLBACK_PARAMS = [
|
||||
pytest.param(
|
||||
[
|
||||
(None, CallbackType.SYNC),
|
||||
({"test": "callback_2_response"}, CallbackType.ASYNC),
|
||||
({"test": "callback_3_response"}, CallbackType.SYNC),
|
||||
(None, CallbackType.ASYNC),
|
||||
],
|
||||
{"test": "callback_2_response"},
|
||||
[1, 1, 0, 0],
|
||||
id="middle_async_callback_returns",
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
(None, CallbackType.SYNC),
|
||||
(None, CallbackType.ASYNC),
|
||||
(None, CallbackType.SYNC),
|
||||
(None, CallbackType.ASYNC),
|
||||
],
|
||||
{"initial": "response"},
|
||||
[1, 1, 1, 1],
|
||||
id="all_callbacks_return_none",
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
({"test": "callback_1_response"}, CallbackType.SYNC),
|
||||
({"test": "callback_2_response"}, CallbackType.ASYNC),
|
||||
],
|
||||
{"test": "callback_1_response"},
|
||||
[1, 0],
|
||||
id="first_sync_callback_returns",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"callbacks, expected_response, expected_calls",
|
||||
CALLBACK_PARAMS,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_tool_callbacks_chain(
|
||||
callbacks: List[tuple[Optional[Dict[str, Any]], int]],
|
||||
expected_response: Dict[str, Any],
|
||||
expected_calls: List[int],
|
||||
):
|
||||
mock_before_cbs = []
|
||||
for response, callback_type in callbacks:
|
||||
if callback_type == CallbackType.ASYNC:
|
||||
mock_cb = mock.AsyncMock(
|
||||
side_effect=partial(
|
||||
mock_async_before_cb_side_effect, ret_value=response
|
||||
)
|
||||
)
|
||||
else:
|
||||
mock_cb = mock.Mock(
|
||||
side_effect=partial(
|
||||
mock_sync_before_cb_side_effect, ret_value=response
|
||||
)
|
||||
)
|
||||
mock_before_cbs.append(mock_cb)
|
||||
result_event = await invoke_tool_with_callbacks(before_cb=mock_before_cbs)
|
||||
assert result_event is not None
|
||||
part = result_event.content.parts[0]
|
||||
assert part.function_response.response == expected_response
|
||||
|
||||
# Assert that the callbacks were called the expected number of times
|
||||
for i, mock_cb in enumerate(mock_before_cbs):
|
||||
expected_calls_count = expected_calls[i]
|
||||
if expected_calls_count == 1:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_awaited_once()
|
||||
else:
|
||||
mock_cb.assert_called_once()
|
||||
elif expected_calls_count == 0:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_not_awaited()
|
||||
else:
|
||||
mock_cb.assert_not_called()
|
||||
else:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_awaited(expected_calls_count)
|
||||
else:
|
||||
mock_cb.assert_called(expected_calls_count)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"callbacks, expected_response, expected_calls",
|
||||
CALLBACK_PARAMS,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_tool_callbacks_chain(
|
||||
callbacks: List[tuple[Optional[Dict[str, Any]], int]],
|
||||
expected_response: Dict[str, Any],
|
||||
expected_calls: List[int],
|
||||
):
|
||||
mock_after_cbs = []
|
||||
for response, callback_type in callbacks:
|
||||
if callback_type == CallbackType.ASYNC:
|
||||
mock_cb = mock.AsyncMock(
|
||||
side_effect=partial(
|
||||
mock_async_after_cb_side_effect, ret_value=response
|
||||
)
|
||||
)
|
||||
else:
|
||||
mock_cb = mock.Mock(
|
||||
side_effect=partial(
|
||||
mock_sync_after_cb_side_effect, ret_value=response
|
||||
)
|
||||
)
|
||||
mock_after_cbs.append(mock_cb)
|
||||
result_event = await invoke_tool_with_callbacks(after_cb=mock_after_cbs)
|
||||
assert result_event is not None
|
||||
part = result_event.content.parts[0]
|
||||
assert part.function_response.response == expected_response
|
||||
|
||||
# Assert that the callbacks were called the expected number of times
|
||||
for i, mock_cb in enumerate(mock_after_cbs):
|
||||
expected_calls_count = expected_calls[i]
|
||||
if expected_calls_count == 1:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_awaited_once()
|
||||
else:
|
||||
mock_cb.assert_called_once()
|
||||
elif expected_calls_count == 0:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_not_awaited()
|
||||
else:
|
||||
mock_cb.assert_not_called()
|
||||
else:
|
||||
if isinstance(mock_cb, mock.AsyncMock):
|
||||
mock_cb.assert_awaited(expected_calls_count)
|
||||
else:
|
||||
mock_cb.assert_called(expected_calls_count)
|
||||
|
||||
Reference in New Issue
Block a user