mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2026-02-04 13:56:24 -06:00
refactor: simplify toolset cleanup codes and extract common cleanup codes to utils which could be utilized by cli or client codes that directly call runners
PiperOrigin-RevId: 762463028
This commit is contained in:
committed by
Copybara-Service
parent
b9b2c3fb54
commit
92c37496d3
@@ -16,12 +16,9 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
@@ -55,11 +52,9 @@ from starlette.types import Lifespan
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents import RunConfig
|
||||
from ..agents.base_agent import BaseAgent
|
||||
from ..agents.live_request_queue import LiveRequest
|
||||
from ..agents.live_request_queue import LiveRequestQueue
|
||||
from ..agents.llm_agent import Agent
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
from ..agents.run_config import StreamingMode
|
||||
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from ..evaluation.eval_case import EvalCase
|
||||
@@ -75,12 +70,12 @@ from ..sessions.session import Session
|
||||
from ..sessions.vertex_ai_session_service import VertexAiSessionService
|
||||
from ..tools.base_toolset import BaseToolset
|
||||
from .cli_eval import EVAL_SESSION_ID_PREFIX
|
||||
from .cli_eval import EvalCaseResult
|
||||
from .cli_eval import EvalMetric
|
||||
from .cli_eval import EvalMetricResult
|
||||
from .cli_eval import EvalMetricResultPerInvocation
|
||||
from .cli_eval import EvalSetResult
|
||||
from .cli_eval import EvalStatus
|
||||
from .utils import cleanup
|
||||
from .utils import common
|
||||
from .utils import create_empty_state
|
||||
from .utils import envs
|
||||
@@ -230,27 +225,8 @@ def get_fast_api_app(
|
||||
|
||||
trace.set_tracer_provider(provider)
|
||||
|
||||
toolsets_to_close: set[BaseToolset] = set()
|
||||
|
||||
@asynccontextmanager
|
||||
async def internal_lifespan(app: FastAPI):
|
||||
# Set up signal handlers for graceful shutdown
|
||||
original_sigterm = signal.getsignal(signal.SIGTERM)
|
||||
original_sigint = signal.getsignal(signal.SIGINT)
|
||||
|
||||
def cleanup_handler(sig, frame):
|
||||
# Log the signal
|
||||
logger.info("Received signal %s, performing pre-shutdown cleanup", sig)
|
||||
# Do synchronous cleanup if needed
|
||||
# Then call original handler if it exists
|
||||
if sig == signal.SIGTERM and callable(original_sigterm):
|
||||
original_sigterm(sig, frame)
|
||||
elif sig == signal.SIGINT and callable(original_sigint):
|
||||
original_sigint(sig, frame)
|
||||
|
||||
# Install cleanup handlers
|
||||
signal.signal(signal.SIGTERM, cleanup_handler)
|
||||
signal.signal(signal.SIGINT, cleanup_handler)
|
||||
|
||||
try:
|
||||
if lifespan:
|
||||
@@ -259,46 +235,8 @@ def get_fast_api_app(
|
||||
else:
|
||||
yield
|
||||
finally:
|
||||
# During shutdown, properly clean up all toolsets
|
||||
logger.info(
|
||||
"Server shutdown initiated, cleaning up %s toolsets",
|
||||
len(toolsets_to_close),
|
||||
)
|
||||
|
||||
# Create tasks for all toolset closures to run concurrently
|
||||
cleanup_tasks = []
|
||||
for toolset in toolsets_to_close:
|
||||
task = asyncio.create_task(close_toolset_safely(toolset))
|
||||
cleanup_tasks.append(task)
|
||||
|
||||
if cleanup_tasks:
|
||||
# Wait for all cleanup tasks with timeout
|
||||
done, pending = await asyncio.wait(
|
||||
cleanup_tasks,
|
||||
timeout=10.0, # 10 second timeout for cleanup
|
||||
return_when=asyncio.ALL_COMPLETED,
|
||||
)
|
||||
|
||||
# If any tasks are still pending, log it
|
||||
if pending:
|
||||
logger.warning(
|
||||
f"{len(pending)} toolset cleanup tasks didn't complete in time"
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
# Restore original signal handlers
|
||||
signal.signal(signal.SIGTERM, original_sigterm)
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
|
||||
async def close_toolset_safely(toolset):
|
||||
"""Safely close a toolset with error handling."""
|
||||
try:
|
||||
logger.info(f"Closing toolset: {type(toolset).__name__}")
|
||||
await toolset.close()
|
||||
logger.info(f"Successfully closed toolset: {type(toolset).__name__}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing toolset {type(toolset).__name__}: {e}")
|
||||
# Create tasks for all runner closures to run concurrently
|
||||
await cleanup.close_runners(list(runner_dict.values()))
|
||||
|
||||
# Run the FastAPI server.
|
||||
app = FastAPI(lifespan=internal_lifespan)
|
||||
@@ -903,16 +841,6 @@ def get_fast_api_app(
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
|
||||
toolsets = set()
|
||||
if isinstance(agent, LlmAgent):
|
||||
for tool_union in agent.tools:
|
||||
if isinstance(tool_union, BaseToolset):
|
||||
toolsets.add(tool_union)
|
||||
for sub_agent in agent.sub_agents:
|
||||
toolsets.update(_get_all_toolsets(sub_agent))
|
||||
return toolsets
|
||||
|
||||
async def _get_root_agent_async(app_name: str) -> Agent:
|
||||
"""Returns the root agent for the given app."""
|
||||
if app_name in root_agent_dict:
|
||||
@@ -924,7 +852,6 @@ def get_fast_api_app(
|
||||
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
|
||||
|
||||
root_agent_dict[app_name] = root_agent
|
||||
toolsets_to_close.update(_get_all_toolsets(root_agent))
|
||||
return root_agent
|
||||
|
||||
async def _get_runner_async(app_name: str) -> Runner:
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from ...runners import Runner
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
async def close_runners(runners: List[Runner]) -> None:
|
||||
cleanup_tasks = [asyncio.create_task(runner.close()) for runner in runners]
|
||||
if cleanup_tasks:
|
||||
# Wait for all cleanup tasks with timeout
|
||||
done, pending = await asyncio.wait(
|
||||
cleanup_tasks,
|
||||
timeout=30.0, # 30 second timeout for cleanup
|
||||
return_when=asyncio.ALL_COMPLETED,
|
||||
)
|
||||
|
||||
# If any tasks are still pending, log it
|
||||
if pending:
|
||||
logger.warning(
|
||||
"%s runner close tasks didn't complete in time", len(pending)
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
Reference in New Issue
Block a user