Refactor function callback handling and update type signatures

- Simplify variable names in `functions.py`: always use `function_response` and `altered_function_response`
- Update LlmAgent callback type aliases to support async:
  - Import `Awaitable`
  - Change `BeforeToolCallback` and `AfterToolCallback` signatures to return `Awaitable[Optional[dict]]`
- Ensure `after_tool_callback` uses `await` when necessary
This commit is contained in:
Alankrit Verma 2025-04-28 23:16:43 -04:00
parent 21736067f9
commit 08ac9a117e
2 changed files with 18 additions and 24 deletions

View File

@ -15,12 +15,7 @@
from __future__ import annotations
import logging
from typing import Any
from typing import AsyncGenerator
from typing import Callable
from typing import Literal
from typing import Optional
from typing import Union
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union
from google.genai import types
from pydantic import BaseModel
@ -62,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[
]
BeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext],
Optional[dict],
Awaitable[Optional[dict]],
]
AfterToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext, dict],
Optional[dict],
Awaitable[Optional[dict]],
]
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]

View File

@ -237,12 +237,11 @@ async def handle_function_calls_live(
# tool, function_args, tool_context
# )
if agent.before_tool_callback:
_maybe = agent.before_tool_callback(
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(_maybe):
_maybe = await _maybe
function_response = _maybe
function_response = agent.before_tool_callback(
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(function_response):
function_response = await function_response
if not function_response:
function_response = await _process_function_live_helper(
@ -260,16 +259,16 @@ async def handle_function_calls_live(
# if new_response:
# function_response = new_response
if agent.after_tool_callback:
_maybe2 = agent.after_tool_callback(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if inspect.isawaitable(_maybe2):
_maybe2 = await _maybe2
if _maybe2 is not None:
function_response = _maybe2
altered_function_response = agent.after_tool_callback(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
if tool.is_long_running:
# Allow async function to return None to not provide function response.