mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-16 04:02:55 -06:00
refactor: update callback type signatures to support sync and async responses
This commit is contained in:
parent
504aa6ba06
commit
fcbf57466e
@ -57,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[
|
|||||||
]
|
]
|
||||||
BeforeToolCallback: TypeAlias = Callable[
|
BeforeToolCallback: TypeAlias = Callable[
|
||||||
[BaseTool, dict[str, Any], ToolContext],
|
[BaseTool, dict[str, Any], ToolContext],
|
||||||
Awaitable[Optional[dict]],
|
Union[Awaitable[Optional[dict]], Optional[dict]],
|
||||||
]
|
]
|
||||||
AfterToolCallback: TypeAlias = Callable[
|
AfterToolCallback: TypeAlias = Callable[
|
||||||
[BaseTool, dict[str, Any], ToolContext, dict],
|
[BaseTool, dict[str, Any], ToolContext, dict],
|
||||||
Awaitable[Optional[dict]],
|
Union[Awaitable[Optional[dict]], Optional[dict]],
|
||||||
]
|
]
|
||||||
|
|
||||||
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
|
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
|
||||||
|
@ -151,36 +151,33 @@ async def handle_function_calls_async(
|
|||||||
# do not use "args" as the variable name, because it is a reserved keyword
|
# do not use "args" as the variable name, because it is a reserved keyword
|
||||||
# in python debugger.
|
# in python debugger.
|
||||||
function_args = function_call.args or {}
|
function_args = function_call.args or {}
|
||||||
function_response = None
|
function_response: Optional[dict] = None
|
||||||
# # Calls the tool if before_tool_callback does not exist or returns None.
|
|
||||||
# if agent.before_tool_callback:
|
# before_tool_callback (sync or async)
|
||||||
# function_response = agent.before_tool_callback(
|
|
||||||
# tool=tool, args=function_args, tool_context=tool_context
|
|
||||||
# )
|
|
||||||
# Short-circuit via before_tool_callback (sync *or* async)
|
|
||||||
if agent.before_tool_callback:
|
if agent.before_tool_callback:
|
||||||
_maybe = agent.before_tool_callback(
|
function_response = agent.before_tool_callback(
|
||||||
tool=tool, args=function_args, tool_context=tool_context
|
tool=tool, args=function_args, tool_context=tool_context
|
||||||
)
|
)
|
||||||
if inspect.isawaitable(_maybe):
|
if inspect.isawaitable(function_response):
|
||||||
_maybe = await _maybe
|
function_response = await function_response
|
||||||
function_response = _maybe
|
|
||||||
if not function_response:
|
if not function_response:
|
||||||
function_response = await __call_tool_async(
|
function_response = await __call_tool_async(
|
||||||
tool, args=function_args, tool_context=tool_context
|
tool, args=function_args, tool_context=tool_context
|
||||||
)
|
)
|
||||||
# Calls after_tool_callback if it exists.
|
|
||||||
|
# after_tool_callback (sync or async)
|
||||||
if agent.after_tool_callback:
|
if agent.after_tool_callback:
|
||||||
_maybe2 = agent.after_tool_callback(
|
altered_function_response = agent.after_tool_callback(
|
||||||
tool=tool,
|
tool=tool,
|
||||||
args=function_args,
|
args=function_args,
|
||||||
tool_context=tool_context,
|
tool_context=tool_context,
|
||||||
tool_response=function_response,
|
tool_response=function_response,
|
||||||
)
|
)
|
||||||
if inspect.isawaitable(_maybe2):
|
if inspect.isawaitable(altered_function_response):
|
||||||
_maybe2 = await _maybe2
|
altered_function_response = await altered_function_response
|
||||||
if _maybe2 is not None:
|
if altered_function_response is not None:
|
||||||
function_response = _maybe2
|
function_response = altered_function_response
|
||||||
|
|
||||||
if tool.is_long_running:
|
if tool.is_long_running:
|
||||||
# Allow long running function to return None to not provide function response.
|
# Allow long running function to return None to not provide function response.
|
||||||
|
@ -11,7 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -27,81 +27,83 @@ from ... import utils
|
|||||||
|
|
||||||
|
|
||||||
class AsyncBeforeToolCallback:
|
class AsyncBeforeToolCallback:
|
||||||
def __init__(self, mock_response: Dict[str, Any]):
|
|
||||||
self.mock_response = mock_response
|
|
||||||
|
|
||||||
async def __call__(
|
def __init__(self, mock_response: Dict[str, Any]):
|
||||||
self,
|
self.mock_response = mock_response
|
||||||
tool: FunctionTool,
|
|
||||||
args: Dict[str, Any],
|
async def __call__(
|
||||||
tool_context: ToolContext,
|
self,
|
||||||
) -> Optional[Dict[str, Any]]:
|
tool: FunctionTool,
|
||||||
return self.mock_response
|
args: Dict[str, Any],
|
||||||
|
tool_context: ToolContext,
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
return self.mock_response
|
||||||
|
|
||||||
|
|
||||||
class AsyncAfterToolCallback:
|
class AsyncAfterToolCallback:
|
||||||
def __init__(self, mock_response: Dict[str, Any]):
|
|
||||||
self.mock_response = mock_response
|
|
||||||
|
|
||||||
async def __call__(
|
def __init__(self, mock_response: Dict[str, Any]):
|
||||||
self,
|
self.mock_response = mock_response
|
||||||
tool: FunctionTool,
|
|
||||||
args: Dict[str, Any],
|
async def __call__(
|
||||||
tool_context: ToolContext,
|
self,
|
||||||
tool_response: Dict[str, Any],
|
tool: FunctionTool,
|
||||||
) -> Optional[Dict[str, Any]]:
|
args: Dict[str, Any],
|
||||||
return self.mock_response
|
tool_context: ToolContext,
|
||||||
|
tool_response: Dict[str, Any],
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
return self.mock_response
|
||||||
|
|
||||||
|
|
||||||
async def invoke_tool_with_callbacks(
|
async def invoke_tool_with_callbacks(
|
||||||
before_cb=None, after_cb=None
|
before_cb=None, after_cb=None
|
||||||
) -> Optional[Event]:
|
) -> Optional[Event]:
|
||||||
def simple_fn(**kwargs) -> Dict[str, Any]:
|
def simple_fn(**kwargs) -> Dict[str, Any]:
|
||||||
return {"initial": "response"}
|
return {"initial": "response"}
|
||||||
|
|
||||||
tool = FunctionTool(simple_fn)
|
tool = FunctionTool(simple_fn)
|
||||||
model = utils.MockModel.create(responses=[])
|
model = utils.MockModel.create(responses=[])
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
name="agent",
|
name="agent",
|
||||||
model=model,
|
model=model,
|
||||||
tools=[tool],
|
tools=[tool],
|
||||||
before_tool_callback=before_cb,
|
before_tool_callback=before_cb,
|
||||||
after_tool_callback=after_cb,
|
after_tool_callback=after_cb,
|
||||||
)
|
)
|
||||||
invocation_context = utils.create_invocation_context(
|
invocation_context = utils.create_invocation_context(
|
||||||
agent=agent, user_content=""
|
agent=agent, user_content=""
|
||||||
)
|
)
|
||||||
# Build function call event
|
# Build function call event
|
||||||
function_call = types.FunctionCall(name=tool.name, args={})
|
function_call = types.FunctionCall(name=tool.name, args={})
|
||||||
content = types.Content(parts=[types.Part(function_call=function_call)])
|
content = types.Content(parts=[types.Part(function_call=function_call)])
|
||||||
event = Event(
|
event = Event(
|
||||||
invocation_id=invocation_context.invocation_id,
|
invocation_id=invocation_context.invocation_id,
|
||||||
author=agent.name,
|
author=agent.name,
|
||||||
content=content,
|
content=content,
|
||||||
)
|
)
|
||||||
tools_dict = {tool.name: tool}
|
tools_dict = {tool.name: tool}
|
||||||
return await handle_function_calls_async(
|
return await handle_function_calls_async(
|
||||||
invocation_context,
|
invocation_context,
|
||||||
event,
|
event,
|
||||||
tools_dict,
|
tools_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_before_tool_callback():
|
async def test_async_before_tool_callback():
|
||||||
mock_resp = {"test": "before_tool_callback"}
|
mock_resp = {"test": "before_tool_callback"}
|
||||||
before_cb = AsyncBeforeToolCallback(mock_resp)
|
before_cb = AsyncBeforeToolCallback(mock_resp)
|
||||||
result_event = await invoke_tool_with_callbacks(before_cb=before_cb)
|
result_event = await invoke_tool_with_callbacks(before_cb=before_cb)
|
||||||
assert result_event is not None
|
assert result_event is not None
|
||||||
part = result_event.content.parts[0]
|
part = result_event.content.parts[0]
|
||||||
assert part.function_response.response == mock_resp
|
assert part.function_response.response == mock_resp
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_after_tool_callback():
|
async def test_async_after_tool_callback():
|
||||||
mock_resp = {"test": "after_tool_callback"}
|
mock_resp = {"test": "after_tool_callback"}
|
||||||
after_cb = AsyncAfterToolCallback(mock_resp)
|
after_cb = AsyncAfterToolCallback(mock_resp)
|
||||||
result_event = await invoke_tool_with_callbacks(after_cb=after_cb)
|
result_event = await invoke_tool_with_callbacks(after_cb=after_cb)
|
||||||
assert result_event is not None
|
assert result_event is not None
|
||||||
part = result_event.content.parts[0]
|
part = result_event.content.parts[0]
|
||||||
assert part.function_response.response == mock_resp
|
assert part.function_response.response == mock_resp
|
||||||
|
Loading…
Reference in New Issue
Block a user