Currently if a model calls a FunctionTool without all the mandatory parameters, the code will just break. This change basically adds the capability for the FunctionTool to identify if the model is missing required arguments, and in that case, instead of breaking the execution, it provides a error message to the model so it could fix the request and retry.

PiperOrigin-RevId: 751023475
This commit is contained in:
Ankur Sharma
2025-04-24 09:31:04 -07:00
committed by Copybara-Service
parent e6109b1dd6
commit f872577f68
2 changed files with 280 additions and 0 deletions

View File

@@ -59,6 +59,23 @@ class FunctionTool(BaseTool):
if 'tool_context' in signature.parameters:
args_to_call['tool_context'] = tool_context
# Before invoking the function, we check for if the list of args passed in
# has all the mandatory arguments or not.
# If the check fails, then we don't invoke the tool and let the Agent know
# that there was a missing a input parameter. This will basically help
# the underlying model fix the issue and retry.
mandatory_args = self._get_mandatory_args()
missing_mandatory_args = [
arg for arg in mandatory_args if arg not in args_to_call
]
if missing_mandatory_args:
missing_mandatory_args_str = '\n'.join(missing_mandatory_args)
error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present:
{missing_mandatory_args_str}
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
return {'error': error_str}
if inspect.iscoroutinefunction(self.func):
return await self.func(**args_to_call) or {}
else:
@@ -85,3 +102,28 @@ class FunctionTool(BaseTool):
args_to_call['tool_context'] = tool_context
async for item in self.func(**args_to_call):
yield item
def _get_mandatory_args(
self,
) -> list[str]:
"""Identifies mandatory parameters (those without default values) for a function.
Returns:
A list of strings, where each string is the name of a mandatory parameter.
"""
signature = inspect.signature(self.func)
mandatory_params = []
for name, param in signature.parameters.items():
# A parameter is mandatory if:
# 1. It has no default value (param.default is inspect.Parameter.empty)
# 2. It's not a variable positional (*args) or variable keyword (**kwargs) parameter
#
# For more refer to: https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind
if param.default == inspect.Parameter.empty and param.kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
mandatory_params.append(name)
return mandatory_params