mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-16 04:02:55 -06:00
fix: support Callable that has __call__ as coroutine function in FunctionTool
PiperOrigin-RevId: 760913537
This commit is contained in:
parent
5115474f2b
commit
f67ccf32c3
@ -33,7 +33,17 @@ class FunctionTool(BaseTool):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, func: Callable[..., Any]):
|
def __init__(self, func: Callable[..., Any]):
|
||||||
super().__init__(name=func.__name__, description=func.__doc__)
|
"""Extract metadata from a callable object."""
|
||||||
|
if inspect.isfunction(func) or inspect.ismethod(func):
|
||||||
|
# Handle regular functions and methods
|
||||||
|
name = func.__name__
|
||||||
|
doc = func.__doc__ or ''
|
||||||
|
else:
|
||||||
|
# Handle objects with __call__ method
|
||||||
|
call_method = func.__call__
|
||||||
|
name = func.__class__.__name__
|
||||||
|
doc = call_method.__doc__ or func.__doc__ or ''
|
||||||
|
super().__init__(name=name, description=doc)
|
||||||
self.func = func
|
self.func = func
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -76,7 +86,14 @@ class FunctionTool(BaseTool):
|
|||||||
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
return {'error': error_str}
|
return {'error': error_str}
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(self.func):
|
# Functions are callable objects, but not all callable objects are functions
|
||||||
|
# checking coroutine function is not enough. We also need to check whether
|
||||||
|
# Callable's __call__ function is a coroutine funciton
|
||||||
|
if (
|
||||||
|
inspect.iscoroutinefunction(self.func)
|
||||||
|
or hasattr(self.func, '__call__')
|
||||||
|
and inspect.iscoroutinefunction(self.func.__call__)
|
||||||
|
):
|
||||||
return await self.func(**args_to_call) or {}
|
return await self.func(**args_to_call) or {}
|
||||||
else:
|
else:
|
||||||
return self.func(**args_to_call) or {}
|
return self.func(**args_to_call) or {}
|
||||||
|
@ -39,6 +39,14 @@ async def async_function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
|
|||||||
return arg1
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCallableWith2ArgsAndNoToolContext:
|
||||||
|
|
||||||
|
async def __call__(self, arg1, arg2):
|
||||||
|
assert arg1
|
||||||
|
assert arg2
|
||||||
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context):
|
def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context):
|
||||||
"""Function for testing with 1 arge and tool context."""
|
"""Function for testing with 1 arge and tool context."""
|
||||||
assert arg1
|
assert arg1
|
||||||
@ -46,6 +54,14 @@ def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context):
|
|||||||
return arg1
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCallableWith1ArgAndToolContext:
|
||||||
|
|
||||||
|
async def __call__(self, arg1, tool_context):
|
||||||
|
assert arg1
|
||||||
|
assert tool_context
|
||||||
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
|
def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
|
||||||
"""Function for testing with 2 arge and no tool context."""
|
"""Function for testing with 2 arge and no tool context."""
|
||||||
assert arg1
|
assert arg1
|
||||||
@ -83,6 +99,16 @@ async def test_run_async_with_tool_context_async_func():
|
|||||||
assert result == "test_value_1"
|
assert result == "test_value_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_tool_context_async_callable():
|
||||||
|
"""Test that run_async calls the callable with tool_context when tool_context is in signature (async callable)."""
|
||||||
|
|
||||||
|
tool = FunctionTool(AsyncCallableWith1ArgAndToolContext())
|
||||||
|
args = {"arg1": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_async_without_tool_context_async_func():
|
async def test_run_async_without_tool_context_async_func():
|
||||||
"""Test that run_async calls the function without tool_context when tool_context is not in signature (async function)."""
|
"""Test that run_async calls the function without tool_context when tool_context is not in signature (async function)."""
|
||||||
@ -92,6 +118,15 @@ async def test_run_async_without_tool_context_async_func():
|
|||||||
assert result == "test_value_1"
|
assert result == "test_value_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_without_tool_context_async_callable():
|
||||||
|
"""Test that run_async calls the callable without tool_context when tool_context is not in signature (async callable)."""
|
||||||
|
tool = FunctionTool(AsyncCallableWith2ArgsAndNoToolContext())
|
||||||
|
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_async_with_tool_context_sync_func():
|
async def test_run_async_with_tool_context_sync_func():
|
||||||
"""Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function)."""
|
"""Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function)."""
|
||||||
|
Loading…
Reference in New Issue
Block a user