mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 09:51:25 -06:00
Currently if a model calls a FunctionTool without all the mandatory parameters, the code will just break. This change basically adds the capability for the FunctionTool to identify if the model is missing required arguments, and in that case, instead of breaking the execution, it provides a error message to the model so it could fix the request and retry.
PiperOrigin-RevId: 751023475
This commit is contained in:
parent
e6109b1dd6
commit
f872577f68
@ -59,6 +59,23 @@ class FunctionTool(BaseTool):
|
|||||||
if 'tool_context' in signature.parameters:
|
if 'tool_context' in signature.parameters:
|
||||||
args_to_call['tool_context'] = tool_context
|
args_to_call['tool_context'] = tool_context
|
||||||
|
|
||||||
|
# Before invoking the function, we check for if the list of args passed in
|
||||||
|
# has all the mandatory arguments or not.
|
||||||
|
# If the check fails, then we don't invoke the tool and let the Agent know
|
||||||
|
# that there was a missing a input parameter. This will basically help
|
||||||
|
# the underlying model fix the issue and retry.
|
||||||
|
mandatory_args = self._get_mandatory_args()
|
||||||
|
missing_mandatory_args = [
|
||||||
|
arg for arg in mandatory_args if arg not in args_to_call
|
||||||
|
]
|
||||||
|
|
||||||
|
if missing_mandatory_args:
|
||||||
|
missing_mandatory_args_str = '\n'.join(missing_mandatory_args)
|
||||||
|
error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present:
|
||||||
|
{missing_mandatory_args_str}
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
return {'error': error_str}
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(self.func):
|
if inspect.iscoroutinefunction(self.func):
|
||||||
return await self.func(**args_to_call) or {}
|
return await self.func(**args_to_call) or {}
|
||||||
else:
|
else:
|
||||||
@ -85,3 +102,28 @@ class FunctionTool(BaseTool):
|
|||||||
args_to_call['tool_context'] = tool_context
|
args_to_call['tool_context'] = tool_context
|
||||||
async for item in self.func(**args_to_call):
|
async for item in self.func(**args_to_call):
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
def _get_mandatory_args(
|
||||||
|
self,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Identifies mandatory parameters (those without default values) for a function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of strings, where each string is the name of a mandatory parameter.
|
||||||
|
"""
|
||||||
|
signature = inspect.signature(self.func)
|
||||||
|
mandatory_params = []
|
||||||
|
|
||||||
|
for name, param in signature.parameters.items():
|
||||||
|
# A parameter is mandatory if:
|
||||||
|
# 1. It has no default value (param.default is inspect.Parameter.empty)
|
||||||
|
# 2. It's not a variable positional (*args) or variable keyword (**kwargs) parameter
|
||||||
|
#
|
||||||
|
# For more refer to: https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind
|
||||||
|
if param.default == inspect.Parameter.empty and param.kind not in (
|
||||||
|
inspect.Parameter.VAR_POSITIONAL,
|
||||||
|
inspect.Parameter.VAR_KEYWORD,
|
||||||
|
):
|
||||||
|
mandatory_params.append(name)
|
||||||
|
|
||||||
|
return mandatory_params
|
||||||
|
238
tests/unittests/tools/test_function_tool.py
Normal file
238
tests/unittests/tools/test_function_tool.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from google.adk.tools.function_tool import FunctionTool
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def function_for_testing_with_no_args():
|
||||||
|
"""Function for testing with no args."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def async_function_for_testing_with_1_arg_and_tool_context(
|
||||||
|
arg1, tool_context
|
||||||
|
):
|
||||||
|
"""Async function for testing with 1 arge and tool context."""
|
||||||
|
assert arg1
|
||||||
|
assert tool_context
|
||||||
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
|
async def async_function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
|
||||||
|
"""Async function for testing with 2 arge and no tool context."""
|
||||||
|
assert arg1
|
||||||
|
assert arg2
|
||||||
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
|
def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context):
|
||||||
|
"""Function for testing with 1 arge and tool context."""
|
||||||
|
assert arg1
|
||||||
|
assert tool_context
|
||||||
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
|
def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
|
||||||
|
"""Function for testing with 2 arge and no tool context."""
|
||||||
|
assert arg1
|
||||||
|
assert arg2
|
||||||
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
|
async def async_function_for_testing_with_4_arg_and_no_tool_context(
|
||||||
|
arg1, arg2, arg3, arg4
|
||||||
|
):
|
||||||
|
"""Async function for testing with 4 args."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def function_for_testing_with_4_arg_and_no_tool_context(arg1, arg2, arg3, arg4):
|
||||||
|
"""Function for testing with 4 args."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_init():
|
||||||
|
"""Test that the FunctionTool is initialized correctly."""
|
||||||
|
tool = FunctionTool(function_for_testing_with_no_args)
|
||||||
|
assert tool.name == "function_for_testing_with_no_args"
|
||||||
|
assert tool.description == "Function for testing with no args."
|
||||||
|
assert tool.func == function_for_testing_with_no_args
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_tool_context_async_func():
|
||||||
|
"""Test that run_async calls the function with tool_context when tool_context is in signature (async function)."""
|
||||||
|
|
||||||
|
tool = FunctionTool(async_function_for_testing_with_1_arg_and_tool_context)
|
||||||
|
args = {"arg1": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_without_tool_context_async_func():
|
||||||
|
"""Test that run_async calls the function without tool_context when tool_context is not in signature (async function)."""
|
||||||
|
tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context)
|
||||||
|
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_tool_context_sync_func():
|
||||||
|
"""Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function)."""
|
||||||
|
tool = FunctionTool(function_for_testing_with_1_arg_and_tool_context)
|
||||||
|
args = {"arg1": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_without_tool_context_sync_func():
|
||||||
|
"""Test that run_async calls the function without tool_context when tool_context is not in signature (synchronous function)."""
|
||||||
|
tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context)
|
||||||
|
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_1_missing_arg_sync_func():
|
||||||
|
"""Test that run_async calls the function with 1 missing arg in signature (synchronous function)."""
|
||||||
|
tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context)
|
||||||
|
args = {"arg1": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == {
|
||||||
|
"error": (
|
||||||
|
"""Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
|
||||||
|
arg2
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_1_missing_arg_async_func():
|
||||||
|
"""Test that run_async calls the function with 1 missing arg in signature (async function)."""
|
||||||
|
tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context)
|
||||||
|
args = {"arg2": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == {
|
||||||
|
"error": (
|
||||||
|
"""Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
|
||||||
|
arg1
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_3_missing_arg_sync_func():
|
||||||
|
"""Test that run_async calls the function with 3 missing args in signature (synchronous function)."""
|
||||||
|
tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context)
|
||||||
|
args = {"arg2": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == {
|
||||||
|
"error": (
|
||||||
|
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
|
||||||
|
arg1
|
||||||
|
arg3
|
||||||
|
arg4
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_3_missing_arg_async_func():
|
||||||
|
"""Test that run_async calls the function with 3 missing args in signature (async function)."""
|
||||||
|
tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context)
|
||||||
|
args = {"arg3": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == {
|
||||||
|
"error": (
|
||||||
|
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
|
||||||
|
arg1
|
||||||
|
arg2
|
||||||
|
arg4
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_missing_all_arg_sync_func():
|
||||||
|
"""Test that run_async calls the function with all missing args in signature (synchronous function)."""
|
||||||
|
tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context)
|
||||||
|
args = {}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == {
|
||||||
|
"error": (
|
||||||
|
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
|
||||||
|
arg1
|
||||||
|
arg2
|
||||||
|
arg3
|
||||||
|
arg4
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_missing_all_arg_async_func():
|
||||||
|
"""Test that run_async calls the function with all missing args in signature (async function)."""
|
||||||
|
tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context)
|
||||||
|
args = {}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == {
|
||||||
|
"error": (
|
||||||
|
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
|
||||||
|
arg1
|
||||||
|
arg2
|
||||||
|
arg3
|
||||||
|
arg4
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_optional_args_not_set_sync_func():
|
||||||
|
"""Test that run_async calls the function for sync funciton with optional args not set."""
|
||||||
|
|
||||||
|
def func_with_optional_args(arg1, arg2=None, *, arg3, arg4=None, **kwargs):
|
||||||
|
return f"{arg1},{arg3}"
|
||||||
|
|
||||||
|
tool = FunctionTool(func_with_optional_args)
|
||||||
|
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1,test_value_3"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_optional_args_not_set_async_func():
|
||||||
|
"""Test that run_async calls the function for async funciton with optional args not set."""
|
||||||
|
|
||||||
|
async def async_func_with_optional_args(
|
||||||
|
arg1, arg2=None, *, arg3, arg4=None, **kwargs
|
||||||
|
):
|
||||||
|
return f"{arg1},{arg3}"
|
||||||
|
|
||||||
|
tool = FunctionTool(async_func_with_optional_args)
|
||||||
|
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1,test_value_3"
|
Loading…
Reference in New Issue
Block a user