Currently if a model calls a FunctionTool without all the mandatory parameters, the code will just break. This change basically adds the capability for the FunctionTool to identify if the model is missing required arguments, and in that case, instead of breaking the execution, it provides a error message to the model so it could fix the request and retry.

PiperOrigin-RevId: 751023475
This commit is contained in:
Ankur Sharma 2025-04-24 09:31:04 -07:00 committed by Copybara-Service
parent e6109b1dd6
commit f872577f68
2 changed files with 280 additions and 0 deletions

View File

@ -59,6 +59,23 @@ class FunctionTool(BaseTool):
if 'tool_context' in signature.parameters: if 'tool_context' in signature.parameters:
args_to_call['tool_context'] = tool_context args_to_call['tool_context'] = tool_context
# Before invoking the function, we check for if the list of args passed in
# has all the mandatory arguments or not.
# If the check fails, then we don't invoke the tool and let the Agent know
# that there was a missing a input parameter. This will basically help
# the underlying model fix the issue and retry.
mandatory_args = self._get_mandatory_args()
missing_mandatory_args = [
arg for arg in mandatory_args if arg not in args_to_call
]
if missing_mandatory_args:
missing_mandatory_args_str = '\n'.join(missing_mandatory_args)
error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present:
{missing_mandatory_args_str}
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
return {'error': error_str}
if inspect.iscoroutinefunction(self.func): if inspect.iscoroutinefunction(self.func):
return await self.func(**args_to_call) or {} return await self.func(**args_to_call) or {}
else: else:
@ -85,3 +102,28 @@ class FunctionTool(BaseTool):
args_to_call['tool_context'] = tool_context args_to_call['tool_context'] = tool_context
async for item in self.func(**args_to_call): async for item in self.func(**args_to_call):
yield item yield item
def _get_mandatory_args(
self,
) -> list[str]:
"""Identifies mandatory parameters (those without default values) for a function.
Returns:
A list of strings, where each string is the name of a mandatory parameter.
"""
signature = inspect.signature(self.func)
mandatory_params = []
for name, param in signature.parameters.items():
# A parameter is mandatory if:
# 1. It has no default value (param.default is inspect.Parameter.empty)
# 2. It's not a variable positional (*args) or variable keyword (**kwargs) parameter
#
# For more refer to: https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind
if param.default == inspect.Parameter.empty and param.kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
mandatory_params.append(name)
return mandatory_params

View File

@ -0,0 +1,238 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
from google.adk.tools.function_tool import FunctionTool
import pytest
def function_for_testing_with_no_args():
"""Function for testing with no args."""
pass
async def async_function_for_testing_with_1_arg_and_tool_context(
arg1, tool_context
):
"""Async function for testing with 1 arge and tool context."""
assert arg1
assert tool_context
return arg1
async def async_function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
"""Async function for testing with 2 arge and no tool context."""
assert arg1
assert arg2
return arg1
def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context):
"""Function for testing with 1 arge and tool context."""
assert arg1
assert tool_context
return arg1
def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
"""Function for testing with 2 arge and no tool context."""
assert arg1
assert arg2
return arg1
async def async_function_for_testing_with_4_arg_and_no_tool_context(
arg1, arg2, arg3, arg4
):
"""Async function for testing with 4 args."""
pass
def function_for_testing_with_4_arg_and_no_tool_context(arg1, arg2, arg3, arg4):
"""Function for testing with 4 args."""
pass
def test_init():
"""Test that the FunctionTool is initialized correctly."""
tool = FunctionTool(function_for_testing_with_no_args)
assert tool.name == "function_for_testing_with_no_args"
assert tool.description == "Function for testing with no args."
assert tool.func == function_for_testing_with_no_args
@pytest.mark.asyncio
async def test_run_async_with_tool_context_async_func():
"""Test that run_async calls the function with tool_context when tool_context is in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_1_arg_and_tool_context)
args = {"arg1": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_without_tool_context_async_func():
"""Test that run_async calls the function without tool_context when tool_context is not in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context)
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_with_tool_context_sync_func():
"""Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_1_arg_and_tool_context)
args = {"arg1": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_without_tool_context_sync_func():
"""Test that run_async calls the function without tool_context when tool_context is not in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context)
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_1_missing_arg_sync_func():
"""Test that run_async calls the function with 1 missing arg in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context)
args = {"arg1": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg2
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_1_missing_arg_async_func():
"""Test that run_async calls the function with 1 missing arg in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context)
args = {"arg2": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_3_missing_arg_sync_func():
"""Test that run_async calls the function with 3 missing args in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context)
args = {"arg2": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg3
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_3_missing_arg_async_func():
"""Test that run_async calls the function with 3 missing args in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context)
args = {"arg3": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg2
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_missing_all_arg_sync_func():
"""Test that run_async calls the function with all missing args in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context)
args = {}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg2
arg3
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_missing_all_arg_async_func():
"""Test that run_async calls the function with all missing args in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context)
args = {}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg2
arg3
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_with_optional_args_not_set_sync_func():
"""Test that run_async calls the function for sync funciton with optional args not set."""
def func_with_optional_args(arg1, arg2=None, *, arg3, arg4=None, **kwargs):
return f"{arg1},{arg3}"
tool = FunctionTool(func_with_optional_args)
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1,test_value_3"
@pytest.mark.asyncio
async def test_run_async_with_optional_args_not_set_async_func():
"""Test that run_async calls the function for async funciton with optional args not set."""
async def async_func_with_optional_args(
arg1, arg2=None, *, arg3, arg4=None, **kwargs
):
return f"{arg1},{arg3}"
tool = FunctionTool(async_func_with_optional_args)
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1,test_value_3"