diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 96ffe67..b1cd8e4 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -33,7 +33,17 @@ class FunctionTool(BaseTool): """ def __init__(self, func: Callable[..., Any]): - super().__init__(name=func.__name__, description=func.__doc__) + """Extract metadata from a callable object.""" + if inspect.isfunction(func) or inspect.ismethod(func): + # Handle regular functions and methods + name = func.__name__ + doc = func.__doc__ or '' + else: + # Handle objects with __call__ method + call_method = func.__call__ + name = func.__class__.__name__ + doc = call_method.__doc__ or func.__doc__ or '' + super().__init__(name=name, description=doc) self.func = func @override @@ -76,7 +86,14 @@ class FunctionTool(BaseTool): You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" return {'error': error_str} - if inspect.iscoroutinefunction(self.func): + # Functions are callable objects, but not all callable objects are functions + # checking coroutine function is not enough. We also need to check whether + # Callable's __call__ function is a coroutine funciton + if ( + inspect.iscoroutinefunction(self.func) + or hasattr(self.func, '__call__') + and inspect.iscoroutinefunction(self.func.__call__) + ): return await self.func(**args_to_call) or {} else: return self.func(**args_to_call) or {} diff --git a/tests/unittests/tools/test_function_tool.py b/tests/unittests/tools/test_function_tool.py index 60c432f..950e388 100644 --- a/tests/unittests/tools/test_function_tool.py +++ b/tests/unittests/tools/test_function_tool.py @@ -39,6 +39,14 @@ async def async_function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2): return arg1 +class AsyncCallableWith2ArgsAndNoToolContext: + + async def __call__(self, arg1, arg2): + assert arg1 + assert arg2 + return arg1 + + def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context): """Function for testing with 1 arge and tool context.""" assert arg1 @@ -46,6 +54,14 @@ def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context): return arg1 +class AsyncCallableWith1ArgAndToolContext: + + async def __call__(self, arg1, tool_context): + assert arg1 + assert tool_context + return arg1 + + def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2): """Function for testing with 2 arge and no tool context.""" assert arg1 @@ -83,6 +99,16 @@ async def test_run_async_with_tool_context_async_func(): assert result == "test_value_1" +@pytest.mark.asyncio +async def test_run_async_with_tool_context_async_callable(): + """Test that run_async calls the callable with tool_context when tool_context is in signature (async callable).""" + + tool = FunctionTool(AsyncCallableWith1ArgAndToolContext()) + args = {"arg1": "test_value_1"} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == "test_value_1" + + @pytest.mark.asyncio async def test_run_async_without_tool_context_async_func(): """Test that run_async calls the function without tool_context when tool_context is not in signature (async function).""" @@ -92,6 +118,15 @@ async def test_run_async_without_tool_context_async_func(): assert result == "test_value_1" +@pytest.mark.asyncio +async def test_run_async_without_tool_context_async_callable(): + """Test that run_async calls the callable without tool_context when tool_context is not in signature (async callable).""" + tool = FunctionTool(AsyncCallableWith2ArgsAndNoToolContext()) + args = {"arg1": "test_value_1", "arg2": "test_value_2"} + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert result == "test_value_1" + + @pytest.mark.asyncio async def test_run_async_with_tool_context_sync_func(): """Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function)."""