diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index b1cd8e4..069108c 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -34,15 +34,27 @@ class FunctionTool(BaseTool): def __init__(self, func: Callable[..., Any]): """Extract metadata from a callable object.""" - if inspect.isfunction(func) or inspect.ismethod(func): - # Handle regular functions and methods + name = '' + doc = '' + # Handle different types of callables + if hasattr(func, '__name__'): + # Regular functions, unbound methods, etc. name = func.__name__ - doc = func.__doc__ or '' - else: - # Handle objects with __call__ method - call_method = func.__call__ + elif hasattr(func, '__class__'): + # Callable objects, bound methods, etc. name = func.__class__.__name__ - doc = call_method.__doc__ or func.__doc__ or '' + + # Get documentation (prioritize direct __doc__ if available) + if hasattr(func, '__doc__') and func.__doc__: + doc = func.__doc__ + elif ( + hasattr(func, '__call__') + and hasattr(func.__call__, '__doc__') + and func.__call__.__doc__ + ): + # For callable objects, try to get docstring from __call__ method + doc = func.__call__.__doc__ + super().__init__(name=name, description=doc) self.func = func diff --git a/tests/unittests/tools/test_function_tool.py b/tests/unittests/tools/test_function_tool.py index 950e388..ece1f53 100644 --- a/tests/unittests/tools/test_function_tool.py +++ b/tests/unittests/tools/test_function_tool.py @@ -41,6 +41,10 @@ async def async_function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2): class AsyncCallableWith2ArgsAndNoToolContext: + def __init__(self): + self.__name__ = "Async callable name" + self.__doc__ = "Async callable doc" + async def __call__(self, arg1, arg2): assert arg1 assert arg2 @@ -57,6 +61,7 @@ def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context): class AsyncCallableWith1ArgAndToolContext: async def __call__(self, arg1, tool_context): + """Async call doc""" assert arg1 assert tool_context return arg1 @@ -107,6 +112,8 @@ async def test_run_async_with_tool_context_async_callable(): args = {"arg1": "test_value_1"} result = await tool.run_async(args=args, tool_context=MagicMock()) assert result == "test_value_1" + assert tool.name == "AsyncCallableWith1ArgAndToolContext" + assert tool.description == "Async call doc" @pytest.mark.asyncio @@ -125,6 +132,8 @@ async def test_run_async_without_tool_context_async_callable(): args = {"arg1": "test_value_1", "arg2": "test_value_2"} result = await tool.run_async(args=args, tool_context=MagicMock()) assert result == "test_value_1" + assert tool.name == "Async callable name" + assert tool.description == "Async callable doc" @pytest.mark.asyncio