fix: support Callable that has __call__ as coroutine function in FunctionTool

PiperOrigin-RevId: 760913537
This commit is contained in:
Xiang (Sean) Zhou 2025-05-19 22:09:08 -07:00 committed by Copybara-Service
parent 5115474f2b
commit f67ccf32c3
2 changed files with 54 additions and 2 deletions

View File

@ -33,7 +33,17 @@ class FunctionTool(BaseTool):
""" """
def __init__(self, func: Callable[..., Any]): def __init__(self, func: Callable[..., Any]):
super().__init__(name=func.__name__, description=func.__doc__) """Extract metadata from a callable object."""
if inspect.isfunction(func) or inspect.ismethod(func):
# Handle regular functions and methods
name = func.__name__
doc = func.__doc__ or ''
else:
# Handle objects with __call__ method
call_method = func.__call__
name = func.__class__.__name__
doc = call_method.__doc__ or func.__doc__ or ''
super().__init__(name=name, description=doc)
self.func = func self.func = func
@override @override
@ -76,7 +86,14 @@ class FunctionTool(BaseTool):
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
return {'error': error_str} return {'error': error_str}
if inspect.iscoroutinefunction(self.func): # Functions are callable objects, but not all callable objects are functions
# checking coroutine function is not enough. We also need to check whether
# Callable's __call__ function is a coroutine funciton
if (
inspect.iscoroutinefunction(self.func)
or hasattr(self.func, '__call__')
and inspect.iscoroutinefunction(self.func.__call__)
):
return await self.func(**args_to_call) or {} return await self.func(**args_to_call) or {}
else: else:
return self.func(**args_to_call) or {} return self.func(**args_to_call) or {}

View File

@ -39,6 +39,14 @@ async def async_function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
return arg1 return arg1
class AsyncCallableWith2ArgsAndNoToolContext:
async def __call__(self, arg1, arg2):
assert arg1
assert arg2
return arg1
def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context): def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context):
"""Function for testing with 1 arge and tool context.""" """Function for testing with 1 arge and tool context."""
assert arg1 assert arg1
@ -46,6 +54,14 @@ def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context):
return arg1 return arg1
class AsyncCallableWith1ArgAndToolContext:
async def __call__(self, arg1, tool_context):
assert arg1
assert tool_context
return arg1
def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2): def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
"""Function for testing with 2 arge and no tool context.""" """Function for testing with 2 arge and no tool context."""
assert arg1 assert arg1
@ -83,6 +99,16 @@ async def test_run_async_with_tool_context_async_func():
assert result == "test_value_1" assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_with_tool_context_async_callable():
"""Test that run_async calls the callable with tool_context when tool_context is in signature (async callable)."""
tool = FunctionTool(AsyncCallableWith1ArgAndToolContext())
args = {"arg1": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_without_tool_context_async_func(): async def test_run_async_without_tool_context_async_func():
"""Test that run_async calls the function without tool_context when tool_context is not in signature (async function).""" """Test that run_async calls the function without tool_context when tool_context is not in signature (async function)."""
@ -92,6 +118,15 @@ async def test_run_async_without_tool_context_async_func():
assert result == "test_value_1" assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_without_tool_context_async_callable():
"""Test that run_async calls the callable without tool_context when tool_context is not in signature (async callable)."""
tool = FunctionTool(AsyncCallableWith2ArgsAndNoToolContext())
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_with_tool_context_sync_func(): async def test_run_async_with_tool_context_sync_func():
"""Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function).""" """Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function)."""