Refactor function callback handling and update type signatures

- Simplify variable names in `functions.py`: always use `function_response` and `altered_function_response`
- Update LlmAgent callback type aliases to support async:
  - Import `Awaitable`
  - Change `BeforeToolCallback` and `AfterToolCallback` signatures to return `Awaitable[Optional[dict]]`
- Ensure `after_tool_callback` uses `await` when necessary
This commit is contained in:
Alankrit Verma 2025-04-28 23:16:43 -04:00
parent 21736067f9
commit 08ac9a117e
2 changed files with 18 additions and 24 deletions

View File

@ -15,12 +15,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union
from typing import AsyncGenerator
from typing import Callable
from typing import Literal
from typing import Optional
from typing import Union
from google.genai import types from google.genai import types
from pydantic import BaseModel from pydantic import BaseModel
@ -62,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[
] ]
BeforeToolCallback: TypeAlias = Callable[ BeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext], [BaseTool, dict[str, Any], ToolContext],
Optional[dict], Awaitable[Optional[dict]],
] ]
AfterToolCallback: TypeAlias = Callable[ AfterToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext, dict], [BaseTool, dict[str, Any], ToolContext, dict],
Optional[dict], Awaitable[Optional[dict]],
] ]
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str] InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]

View File

@ -237,12 +237,11 @@ async def handle_function_calls_live(
# tool, function_args, tool_context # tool, function_args, tool_context
# ) # )
if agent.before_tool_callback: if agent.before_tool_callback:
_maybe = agent.before_tool_callback( function_response = agent.before_tool_callback(
tool=tool, args=function_args, tool_context=tool_context tool=tool, args=function_args, tool_context=tool_context
) )
if inspect.isawaitable(_maybe): if inspect.isawaitable(function_response):
_maybe = await _maybe function_response = await function_response
function_response = _maybe
if not function_response: if not function_response:
function_response = await _process_function_live_helper( function_response = await _process_function_live_helper(
@ -260,16 +259,16 @@ async def handle_function_calls_live(
# if new_response: # if new_response:
# function_response = new_response # function_response = new_response
if agent.after_tool_callback: if agent.after_tool_callback:
_maybe2 = agent.after_tool_callback( altered_function_response = agent.after_tool_callback(
tool=tool, tool=tool,
args=function_args, args=function_args,
tool_context=tool_context, tool_context=tool_context,
tool_response=function_response, tool_response=function_response,
) )
if inspect.isawaitable(_maybe2): if inspect.isawaitable(altered_function_response):
_maybe2 = await _maybe2 altered_function_response = await altered_function_response
if _maybe2 is not None: if altered_function_response is not None:
function_response = _maybe2 function_response = altered_function_response
if tool.is_long_running: if tool.is_long_running:
# Allow async function to return None to not provide function response. # Allow async function to return None to not provide function response.