From 08ac9a117e889af6611d64999b7b827ac81a2fcd Mon Sep 17 00:00:00 2001 From: Alankrit Verma Date: Mon, 28 Apr 2025 23:16:43 -0400 Subject: [PATCH] Refactor function callback handling and update type signatures - Simplify variable names in `functions.py`: always use `function_response` and `altered_function_response` - Update LlmAgent callback type aliases to support async: - Import `Awaitable` - Change `BeforeToolCallback` and `AfterToolCallback` signatures to return `Awaitable[Optional[dict]]` - Ensure `after_tool_callback` uses `await` when necessary --- src/google/adk/agents/llm_agent.py | 11 ++------ src/google/adk/flows/llm_flows/functions.py | 31 ++++++++++----------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index a140997..e6c7941 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -15,12 +15,7 @@ from __future__ import annotations import logging -from typing import Any -from typing import AsyncGenerator -from typing import Callable -from typing import Literal -from typing import Optional -from typing import Union +from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union from google.genai import types from pydantic import BaseModel @@ -62,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[ ] BeforeToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext], - Optional[dict], + Awaitable[Optional[dict]], ] AfterToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext, dict], - Optional[dict], + Awaitable[Optional[dict]], ] InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str] diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 25f9db3..6d42c54 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -237,12 +237,11 @@ async def handle_function_calls_live( # tool, function_args, tool_context # ) if agent.before_tool_callback: - _maybe = agent.before_tool_callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(_maybe): - _maybe = await _maybe - function_response = _maybe + function_response = agent.before_tool_callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response if not function_response: function_response = await _process_function_live_helper( @@ -260,16 +259,16 @@ async def handle_function_calls_live( # if new_response: # function_response = new_response if agent.after_tool_callback: - _maybe2 = agent.after_tool_callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(_maybe2): - _maybe2 = await _maybe2 - if _maybe2 is not None: - function_response = _maybe2 + altered_function_response = agent.after_tool_callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response is not None: + function_response = altered_function_response if tool.is_long_running: # Allow async function to return None to not provide function response.