diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 15b9c6b..96ffe67 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -59,6 +59,23 @@ class FunctionTool(BaseTool): if 'tool_context' in signature.parameters: args_to_call['tool_context'] = tool_context + # Before invoking the function, we check for if the list of args passed in + # has all the mandatory arguments or not. + # If the check fails, then we don't invoke the tool and let the Agent know + # that there was a missing a input parameter. This will basically help + # the underlying model fix the issue and retry. + mandatory_args = self._get_mandatory_args() + missing_mandatory_args = [ + arg for arg in mandatory_args if arg not in args_to_call + ] + + if missing_mandatory_args: + missing_mandatory_args_str = '\n'.join(missing_mandatory_args) + error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present: +{missing_mandatory_args_str} +You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" + return {'error': error_str} + if inspect.iscoroutinefunction(self.func): return await self.func(**args_to_call) or {} else: @@ -85,3 +102,28 @@ class FunctionTool(BaseTool): args_to_call['tool_context'] = tool_context async for item in self.func(**args_to_call): yield item + + def _get_mandatory_args( + self, + ) -> list[str]: + """Identifies mandatory parameters (those without default values) for a function. + + Returns: + A list of strings, where each string is the name of a mandatory parameter. + """ + signature = inspect.signature(self.func) + mandatory_params = [] + + for name, param in signature.parameters.items(): + # A parameter is mandatory if: + # 1. It has no default value (param.default is inspect.Parameter.empty) + # 2. It's not a variable positional (*args) or variable keyword (**kwargs) parameter + # + # For more refer to: https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind + if param.default == inspect.Parameter.empty and param.kind not in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + mandatory_params.append(name) + + return mandatory_params diff --git a/tests/unittests/tools/test_function_tool.py b/tests/unittests/tools/test_function_tool.py new file mode 100644 index 0000000..60c432f --- /dev/null +++ b/tests/unittests/tools/test_function_tool.py @@ -0,0 +1,238 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +from google.adk.tools.function_tool import FunctionTool +import pytest + + +def function_for_testing_with_no_args(): + """Function for testing with no args.""" + pass + + +async def async_function_for_testing_with_1_arg_and_tool_context( + arg1, tool_context +): + """Async function for testing with 1 arge and tool context.""" + assert arg1 + assert tool_context + return arg1 + + +async def async_function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2): + """Async function for testing with 2 arge and no tool context.""" + assert arg1 + assert arg2 + return arg1 + + +def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context): + """Function for testing with 1 arge and tool context.""" + assert arg1 + assert tool_context + return arg1 + + +def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2): + """Function for testing with 2 arge and no tool context.""" + assert arg1 + assert arg2 + return arg1 + + +async def async_function_for_testing_with_4_arg_and_no_tool_context( + arg1, arg2, arg3, arg4 +): + """Async function for testing with 4 args.""" + pass + + +def function_for_testing_with_4_arg_and_no_tool_context(arg1, arg2, arg3, arg4): + """Function for testing with 4 args.""" + pass + + +def test_init(): + """Test that the FunctionTool is initialized correctly.""" + tool = FunctionTool(function_for_testing_with_no_args) + assert tool.name == "function_for_testing_with_no_args" + assert tool.description == "Function for testing with no args." + assert tool.func == function_for_testing_with_no_args + + +@pytest.mark.asyncio +async def test_run_async_with_tool_context_async_func(): + """Test that run_async calls the function with tool_context when tool_context is in signature (async function).""" + + tool = FunctionTool(async_function_for_testing_with_1_arg_and_tool_context) + args = {"arg1": "test_value_1"} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == "test_value_1" + + +@pytest.mark.asyncio +async def test_run_async_without_tool_context_async_func(): + """Test that run_async calls the function without tool_context when tool_context is not in signature (async function).""" + tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context) + args = {"arg1": "test_value_1", "arg2": "test_value_2"} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == "test_value_1" + + +@pytest.mark.asyncio +async def test_run_async_with_tool_context_sync_func(): + """Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function).""" + tool = FunctionTool(function_for_testing_with_1_arg_and_tool_context) + args = {"arg1": "test_value_1"} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == "test_value_1" + + +@pytest.mark.asyncio +async def test_run_async_without_tool_context_sync_func(): + """Test that run_async calls the function without tool_context when tool_context is not in signature (synchronous function).""" + tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context) + args = {"arg1": "test_value_1", "arg2": "test_value_2"} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == "test_value_1" + + +@pytest.mark.asyncio +async def test_run_async_1_missing_arg_sync_func(): + """Test that run_async calls the function with 1 missing arg in signature (synchronous function).""" + tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context) + args = {"arg1": "test_value_1"} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == { + "error": ( + """Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present: +arg2 +You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" + ) + } + + +@pytest.mark.asyncio +async def test_run_async_1_missing_arg_async_func(): + """Test that run_async calls the function with 1 missing arg in signature (async function).""" + tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context) + args = {"arg2": "test_value_1"} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == { + "error": ( + """Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present: +arg1 +You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" + ) + } + + +@pytest.mark.asyncio +async def test_run_async_3_missing_arg_sync_func(): + """Test that run_async calls the function with 3 missing args in signature (synchronous function).""" + tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context) + args = {"arg2": "test_value_1"} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == { + "error": ( + """Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present: +arg1 +arg3 +arg4 +You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" + ) + } + + +@pytest.mark.asyncio +async def test_run_async_3_missing_arg_async_func(): + """Test that run_async calls the function with 3 missing args in signature (async function).""" + tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context) + args = {"arg3": "test_value_1"} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == { + "error": ( + """Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present: +arg1 +arg2 +arg4 +You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" + ) + } + + +@pytest.mark.asyncio +async def test_run_async_missing_all_arg_sync_func(): + """Test that run_async calls the function with all missing args in signature (synchronous function).""" + tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context) + args = {} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == { + "error": ( + """Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present: +arg1 +arg2 +arg3 +arg4 +You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" + ) + } + + +@pytest.mark.asyncio +async def test_run_async_missing_all_arg_async_func(): + """Test that run_async calls the function with all missing args in signature (async function).""" + tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context) + args = {} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == { + "error": ( + """Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present: +arg1 +arg2 +arg3 +arg4 +You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" + ) + } + + +@pytest.mark.asyncio +async def test_run_async_with_optional_args_not_set_sync_func(): + """Test that run_async calls the function for sync funciton with optional args not set.""" + + def func_with_optional_args(arg1, arg2=None, *, arg3, arg4=None, **kwargs): + return f"{arg1},{arg3}" + + tool = FunctionTool(func_with_optional_args) + args = {"arg1": "test_value_1", "arg3": "test_value_3"} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == "test_value_1,test_value_3" + + +@pytest.mark.asyncio +async def test_run_async_with_optional_args_not_set_async_func(): + """Test that run_async calls the function for async funciton with optional args not set.""" + + async def async_func_with_optional_args( + arg1, arg2=None, *, arg3, arg4=None, **kwargs + ): + return f"{arg1},{arg3}" + + tool = FunctionTool(async_func_with_optional_args) + args = {"arg1": "test_value_1", "arg3": "test_value_3"} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == "test_value_1,test_value_3"