# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Handles function callings for LLM flow."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import inspect
|
|
import logging
|
|
from typing import Any
|
|
from typing import AsyncGenerator
|
|
from typing import cast
|
|
from typing import Optional
|
|
import uuid
|
|
|
|
from google.genai import types
|
|
|
|
from ...agents.active_streaming_tool import ActiveStreamingTool
|
|
from ...agents.invocation_context import InvocationContext
|
|
from ...auth.auth_tool import AuthToolArguments
|
|
from ...events.event import Event
|
|
from ...events.event_actions import EventActions
|
|
from ...telemetry import trace_tool_call
|
|
from ...telemetry import trace_tool_response
|
|
from ...telemetry import tracer
|
|
from ...tools.base_tool import BaseTool
|
|
from ...tools.tool_context import ToolContext
|
|
|
|
AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
|
|
REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|


def generate_client_function_call_id() -> str:
  return f'{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}'


def populate_client_function_call_id(model_response_event: Event) -> None:
  if not model_response_event.get_function_calls():
    return
  for function_call in model_response_event.get_function_calls():
    if not function_call.id:
      function_call.id = generate_client_function_call_id()


def remove_client_function_call_id(content: types.Content) -> None:
  if content and content.parts:
    for part in content.parts:
      if (
          part.function_call
          and part.function_call.id
          and part.function_call.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
      ):
        part.function_call.id = None
      if (
          part.function_response
          and part.function_response.id
          and part.function_response.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
      ):
        part.function_response.id = None
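
# Illustrative sketch (comments only, not executed): the intended round trip
# for client-generated ids, assuming a hypothetical model-response `event`.
#
#   populate_client_function_call_id(event)
#       # assigns an 'adk-<uuid4>' id to each function call that has none
#   remove_client_function_call_id(event.content)
#       # strips ids with the 'adk-' prefix before the content is sent back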


def get_long_running_function_calls(
    function_calls: list[types.FunctionCall],
    tools_dict: dict[str, BaseTool],
) -> set[str]:
  long_running_tool_ids = set()
  for function_call in function_calls:
    if (
        function_call.name in tools_dict
        and tools_dict[function_call.name].is_long_running
    ):
      long_running_tool_ids.add(function_call.id)

  return long_running_tool_ids
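
# Illustrative usage sketch (comments only): collecting the ids of
# long-running calls from a response event; `event` and `tools_dict` are
# hypothetical placeholders for the caller's values.
#
#   long_running_ids = get_long_running_function_calls(
#       event.get_function_calls(), tools_dict
#   )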


def generate_auth_event(
    invocation_context: InvocationContext,
    function_response_event: Event,
) -> Optional[Event]:
  if not function_response_event.actions.requested_auth_configs:
    return None
  parts = []
  long_running_tool_ids = set()
  for (
      function_call_id,
      auth_config,
  ) in function_response_event.actions.requested_auth_configs.items():

    request_euc_function_call = types.FunctionCall(
        name=REQUEST_EUC_FUNCTION_CALL_NAME,
        args=AuthToolArguments(
            function_call_id=function_call_id,
            auth_config=auth_config,
        ).model_dump(exclude_none=True),
    )
    request_euc_function_call.id = generate_client_function_call_id()
    long_running_tool_ids.add(request_euc_function_call.id)
    parts.append(types.Part(function_call=request_euc_function_call))

  return Event(
      invocation_id=invocation_context.invocation_id,
      author=invocation_context.agent.name,
      branch=invocation_context.branch,
      content=types.Content(
          parts=parts, role=function_response_event.content.role
      ),
      long_running_tool_ids=long_running_tool_ids,
  )
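
# Note (comments only): each requested_auth_configs entry above becomes one
# 'adk_request_credential' function call whose args are the serialized
# AuthToolArguments, and every such call id is recorded as long-running so the
# credential can be supplied asynchronously by the client.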


async def handle_function_calls_async(
    invocation_context: InvocationContext,
    function_call_event: Event,
    tools_dict: dict[str, BaseTool],
    filters: Optional[set[str]] = None,
) -> Optional[Event]:
  """Calls the functions and returns the function response event."""
  from ...agents.llm_agent import LlmAgent

  agent = invocation_context.agent
  if not isinstance(agent, LlmAgent):
    return

  function_calls = function_call_event.get_function_calls()

  function_response_events: list[Event] = []
  for function_call in function_calls:
    if filters and function_call.id not in filters:
      continue
    tool, tool_context = _get_tool_and_context(
        invocation_context,
        function_call_event,
        function_call,
        tools_dict,
    )
    # Do not use "args" as the variable name, because it is a reserved keyword
    # in the Python debugger.
    function_args = function_call.args or {}
    function_response: Optional[dict] = None

    # before_tool_callback (sync or async)
    if agent.before_tool_callback:
      function_response = agent.before_tool_callback(
          tool=tool, args=function_args, tool_context=tool_context
      )
      if inspect.isawaitable(function_response):
        function_response = await function_response

    if not function_response:
      function_response = await __call_tool_async(
          tool, args=function_args, tool_context=tool_context
      )

    # after_tool_callback (sync or async)
    if agent.after_tool_callback:
      altered_function_response = agent.after_tool_callback(
          tool=tool,
          args=function_args,
          tool_context=tool_context,
          tool_response=function_response,
      )
      if inspect.isawaitable(altered_function_response):
        altered_function_response = await altered_function_response
      if altered_function_response is not None:
        function_response = altered_function_response

    if tool.is_long_running:
      # Allow a long-running function to return None to not provide a
      # function response.
      if not function_response:
        continue

    # Builds the function response event.
    function_response_event = __build_response_event(
        tool, function_response, tool_context, invocation_context
    )
    function_response_events.append(function_response_event)

  if not function_response_events:
    return None
  merged_event = merge_parallel_function_response_events(
      function_response_events
  )
  if len(function_response_events) > 1:
    # This is needed for debug traces of parallel calls.
    # Individual responses with tool.name are traced in __build_response_event
    # (tool.name is dropped from the span name here because this is the merged
    # event).
    with tracer.start_as_current_span('tool_response'):
      trace_tool_response(
          invocation_context=invocation_context,
          event_id=merged_event.id,
          function_response_event=merged_event,
      )
  return merged_event
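
# Illustrative sketch (comments only) of the per-call flow implemented above,
# assuming a hypothetical agent with both callbacks configured (awaitable
# callback results are awaited; omitted here for brevity):
#
#   response = agent.before_tool_callback(
#       tool=tool, args=function_args, tool_context=tool_context
#   )
#   if not response:  # callback did not short-circuit the tool call
#     response = await tool.run_async(
#         args=function_args, tool_context=tool_context
#     )
#   altered = agent.after_tool_callback(
#       tool=tool,
#       args=function_args,
#       tool_context=tool_context,
#       tool_response=response,
#   )
#   response = altered if altered is not None else response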


async def handle_function_calls_live(
    invocation_context: InvocationContext,
    function_call_event: Event,
    tools_dict: dict[str, BaseTool],
) -> Optional[Event]:
  """Calls the functions and returns the function response event."""
  from ...agents.llm_agent import LlmAgent

  agent = cast(LlmAgent, invocation_context.agent)
  function_calls = function_call_event.get_function_calls()

  function_response_events: list[Event] = []
  for function_call in function_calls:
    tool, tool_context = _get_tool_and_context(
        invocation_context, function_call_event, function_call, tools_dict
    )
    # Do not use "args" as the variable name, because it is a reserved keyword
    # in the Python debugger.
    function_args = function_call.args or {}
    function_response = None
    # Calls the tool only if before_tool_callback does not exist or returns
    # None.
    if agent.before_tool_callback:
      function_response = agent.before_tool_callback(
          tool=tool, args=function_args, tool_context=tool_context
      )
      if inspect.isawaitable(function_response):
        function_response = await function_response

    if not function_response:
      function_response = await _process_function_live_helper(
          tool, tool_context, function_call, function_args, invocation_context
      )

    # Calls after_tool_callback if it exists.
    if agent.after_tool_callback:
      altered_function_response = agent.after_tool_callback(
          tool=tool,
          args=function_args,
          tool_context=tool_context,
          tool_response=function_response,
      )
      if inspect.isawaitable(altered_function_response):
        altered_function_response = await altered_function_response
      if altered_function_response is not None:
        function_response = altered_function_response

    if tool.is_long_running:
      # Allow an async function to return None to not provide a function
      # response.
      if not function_response:
        continue

    # Builds the function response event.
    function_response_event = __build_response_event(
        tool, function_response, tool_context, invocation_context
    )
    function_response_events.append(function_response_event)

  if not function_response_events:
    return None
  merged_event = merge_parallel_function_response_events(
      function_response_events
  )
  return merged_event
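
# Note (comments only): this live variant mirrors handle_function_calls_async,
# except that tool execution goes through _process_function_live_helper so
# streaming (async generator) tools can be launched as background tasks.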


async def _process_function_live_helper(
    tool, tool_context, function_call, function_args, invocation_context
):
  function_response = None
  # Check if this is a stop_streaming function call.
  if (
      function_call.name == 'stop_streaming'
      and 'function_name' in function_args
  ):
    function_name = function_args['function_name']
    active_tasks = invocation_context.active_streaming_tools
    if (
        function_name in active_tasks
        and active_tasks[function_name].task
        and not active_tasks[function_name].task.done()
    ):
      task = active_tasks[function_name].task
      task.cancel()
      try:
        # Wait for the task to be cancelled.
        await asyncio.wait_for(task, timeout=1.0)
      except (asyncio.CancelledError, asyncio.TimeoutError):
        # Log the specific condition.
        if task.cancelled():
          logger.info(f'Task {function_name} was cancelled successfully')
        elif task.done():
          logger.info(f'Task {function_name} completed during cancellation')
        else:
          logger.warning(
              f'Task {function_name} might still be running after'
              ' cancellation timeout'
          )
          function_response = {
              'status': f'The task is not cancelled yet for {function_name}.'
          }
      if not function_response:
        # Clean up the reference.
        active_tasks[function_name].task = None

        function_response = {
            'status': f'Successfully stopped streaming function {function_name}'
        }
    else:
      function_response = {
          'status': f'No active streaming function named {function_name} found'
      }
  elif hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func):
    # Streaming tool use case: the function is required to be an async
    # generator function.
    async def run_tool_and_update_queue(tool, function_args, tool_context):
      try:
        async for result in __call_tool_live(
            tool=tool,
            args=function_args,
            tool_context=tool_context,
            invocation_context=invocation_context,
        ):
          updated_content = types.Content(
              role='user',
              parts=[
                  types.Part.from_text(
                      text=f'Function {tool.name} returned: {result}'
                  )
              ],
          )
          invocation_context.live_request_queue.send_content(updated_content)
      except asyncio.CancelledError:
        raise  # Re-raise to properly propagate the cancellation.

    task = asyncio.create_task(
        run_tool_and_update_queue(tool, function_args, tool_context)
    )
    if invocation_context.active_streaming_tools is None:
      invocation_context.active_streaming_tools = {}
    if tool.name in invocation_context.active_streaming_tools:
      invocation_context.active_streaming_tools[tool.name].task = task
    else:
      invocation_context.active_streaming_tools[tool.name] = (
          ActiveStreamingTool(task=task)
      )
    # Immediately return a pending response.
    # This is required by the current live model.
    function_response = {
        'status': (
            'The function is running asynchronously and the results are'
            ' pending.'
        )
    }
  else:
    function_response = await __call_tool_async(
        tool, args=function_args, tool_context=tool_context
    )
  return function_response
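
# Illustrative sketch (comments only): a streaming tool can be stopped by a
# model-emitted call shaped like the hypothetical example below, which the
# helper above matches by name and by its 'function_name' argument.
#
#   types.FunctionCall(
#       name='stop_streaming',
#       args={'function_name': 'monitor_stock_price'},  # hypothetical tool
#   )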


def _get_tool_and_context(
    invocation_context: InvocationContext,
    function_call_event: Event,
    function_call: types.FunctionCall,
    tools_dict: dict[str, BaseTool],
):
  if function_call.name not in tools_dict:
    raise ValueError(
        f'Function {function_call.name} is not found in the tools_dict.'
    )

  tool_context = ToolContext(
      invocation_context=invocation_context,
      function_call_id=function_call.id,
  )

  tool = tools_dict[function_call.name]

  return (tool, tool_context)


async def __call_tool_live(
    tool: BaseTool,
    args: dict[str, object],
    tool_context: ToolContext,
    invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
  """Calls the tool in live (streaming) mode and yields its results."""
  with tracer.start_as_current_span(f'tool_call [{tool.name}]'):
    trace_tool_call(args=args)
    async for item in tool._call_live(
        args=args,
        tool_context=tool_context,
        invocation_context=invocation_context,
    ):
      yield item


async def __call_tool_async(
    tool: BaseTool,
    args: dict[str, Any],
    tool_context: ToolContext,
) -> Any:
  """Calls the tool."""
  with tracer.start_as_current_span(f'tool_call [{tool.name}]'):
    trace_tool_call(args=args)
    return await tool.run_async(args=args, tool_context=tool_context)


def __build_response_event(
    tool: BaseTool,
    function_result: dict[str, object],
    tool_context: ToolContext,
    invocation_context: InvocationContext,
) -> Event:
  with tracer.start_as_current_span(f'tool_response [{tool.name}]'):
    # The spec requires the result to be a dict.
    if not isinstance(function_result, dict):
      function_result = {'result': function_result}

    part_function_response = types.Part.from_function_response(
        name=tool.name, response=function_result
    )
    part_function_response.function_response.id = tool_context.function_call_id

    content = types.Content(
        role='user',
        parts=[part_function_response],
    )

    function_response_event = Event(
        invocation_id=invocation_context.invocation_id,
        author=invocation_context.agent.name,
        content=content,
        actions=tool_context.actions,
        branch=invocation_context.branch,
    )

    trace_tool_response(
        invocation_context=invocation_context,
        event_id=function_response_event.id,
        function_response_event=function_response_event,
    )
    return function_response_event
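
# Note (comments only): a tool that returns a bare value, e.g. 42, is wrapped
# as {'result': 42} above before being packed into the FunctionResponse part,
# since the response payload is required to be a dict.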


def merge_parallel_function_response_events(
    function_response_events: list['Event'],
) -> 'Event':
  if not function_response_events:
    raise ValueError('No function response events provided.')

  if len(function_response_events) == 1:
    return function_response_events[0]
  merged_parts = []
  for event in function_response_events:
    if event.content:
      for part in event.content.parts or []:
        merged_parts.append(part)

  # Use the first event as the "base" for common attributes.
  base_event = function_response_events[0]

  # Merge actions from all events.
  merged_actions = EventActions()
  merged_requested_auth_configs = {}
  for event in function_response_events:
    merged_requested_auth_configs.update(event.actions.requested_auth_configs)
    merged_actions = merged_actions.model_copy(
        update=event.actions.model_dump()
    )
  merged_actions.requested_auth_configs = merged_requested_auth_configs
  # Create the new merged event.
  merged_event = Event(
      invocation_id=Event.new_id(),
      author=base_event.author,
      branch=base_event.branch,
      content=types.Content(role='user', parts=merged_parts),
      actions=merged_actions,
  )

  # Use the base_event's timestamp for the merged event.
  merged_event.timestamp = base_event.timestamp
  return merged_event
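
# Illustrative note (comments only): merging two single-part response events
# produces one event whose content carries both FunctionResponse parts, whose
# actions are the model_copy-merged actions with requested_auth_configs
# combined across events, and whose timestamp comes from the first event.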