refactor: simplify toolset cleanup code and extract the common cleanup code into utils so it can be reused by the CLI or by client code that calls runners directly

PiperOrigin-RevId: 762463028
This commit is contained in:
Xiang (Sean) Zhou
2025-05-23 09:48:48 -07:00
committed by Copybara-Service
parent b9b2c3fb54
commit 92c37496d3
6 changed files with 239 additions and 406 deletions
+3 -76
View File
@@ -16,12 +16,9 @@
import asyncio
from contextlib import asynccontextmanager
import importlib
import inspect
import json
import logging
import os
from pathlib import Path
import signal
import sys
import time
import traceback
@@ -55,11 +52,9 @@ from starlette.types import Lifespan
from typing_extensions import override
from ..agents import RunConfig
from ..agents.base_agent import BaseAgent
from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent
from ..agents.llm_agent import LlmAgent
from ..agents.run_config import StreamingMode
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..evaluation.eval_case import EvalCase
@@ -75,12 +70,12 @@ from ..sessions.session import Session
from ..sessions.vertex_ai_session_service import VertexAiSessionService
from ..tools.base_toolset import BaseToolset
from .cli_eval import EVAL_SESSION_ID_PREFIX
from .cli_eval import EvalCaseResult
from .cli_eval import EvalMetric
from .cli_eval import EvalMetricResult
from .cli_eval import EvalMetricResultPerInvocation
from .cli_eval import EvalSetResult
from .cli_eval import EvalStatus
from .utils import cleanup
from .utils import common
from .utils import create_empty_state
from .utils import envs
@@ -230,27 +225,8 @@ def get_fast_api_app(
trace.set_tracer_provider(provider)
toolsets_to_close: set[BaseToolset] = set()
@asynccontextmanager
async def internal_lifespan(app: FastAPI):
# Set up signal handlers for graceful shutdown
original_sigterm = signal.getsignal(signal.SIGTERM)
original_sigint = signal.getsignal(signal.SIGINT)
def cleanup_handler(sig, frame):
# Log the signal
logger.info("Received signal %s, performing pre-shutdown cleanup", sig)
# Do synchronous cleanup if needed
# Then call original handler if it exists
if sig == signal.SIGTERM and callable(original_sigterm):
original_sigterm(sig, frame)
elif sig == signal.SIGINT and callable(original_sigint):
original_sigint(sig, frame)
# Install cleanup handlers
signal.signal(signal.SIGTERM, cleanup_handler)
signal.signal(signal.SIGINT, cleanup_handler)
try:
if lifespan:
@@ -259,46 +235,8 @@ def get_fast_api_app(
else:
yield
finally:
# During shutdown, properly clean up all toolsets
logger.info(
"Server shutdown initiated, cleaning up %s toolsets",
len(toolsets_to_close),
)
# Create tasks for all toolset closures to run concurrently
cleanup_tasks = []
for toolset in toolsets_to_close:
task = asyncio.create_task(close_toolset_safely(toolset))
cleanup_tasks.append(task)
if cleanup_tasks:
# Wait for all cleanup tasks with timeout
done, pending = await asyncio.wait(
cleanup_tasks,
timeout=10.0, # 10 second timeout for cleanup
return_when=asyncio.ALL_COMPLETED,
)
# If any tasks are still pending, log it
if pending:
logger.warning(
f"{len(pending)} toolset cleanup tasks didn't complete in time"
)
for task in pending:
task.cancel()
# Restore original signal handlers
signal.signal(signal.SIGTERM, original_sigterm)
signal.signal(signal.SIGINT, original_sigint)
async def close_toolset_safely(toolset):
  """Close a single toolset without letting an ``Exception`` escape.

  The attempt, the success and any failure are each reported through the
  module logger, identified by the toolset's class name.

  Args:
    toolset: Object exposing an awaitable ``close()`` method.
  """
  toolset_name = type(toolset).__name__
  try:
    logger.info(f"Closing toolset: {toolset_name}")
    await toolset.close()
    logger.info(f"Successfully closed toolset: {toolset_name}")
  except Exception as e:
    # Best-effort cleanup: report the failure and carry on.
    logger.error(f"Error closing toolset {toolset_name}: {e}")
# Create tasks for all runner closures to run concurrently
await cleanup.close_runners(list(runner_dict.values()))
# Run the FastAPI server.
app = FastAPI(lifespan=internal_lifespan)
@@ -903,16 +841,6 @@ def get_fast_api_app(
for task in pending:
task.cancel()
def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
  """Gather the toolsets reachable from ``agent``.

  Tools are only inspected on ``LlmAgent`` instances; entries of
  ``agent.tools`` that are ``BaseToolset`` instances are kept. The search then
  recurses into every entry of ``agent.sub_agents``.

  Args:
    agent: Root of the agent tree to walk.

  Returns:
    The set of toolsets found on ``agent`` and on all of its descendants.
  """
  if isinstance(agent, LlmAgent):
    collected = {
        candidate
        for candidate in agent.tools
        if isinstance(candidate, BaseToolset)
    }
  else:
    collected = set()

  # Merge in whatever each child subtree contributes.
  for child in agent.sub_agents:
    collected |= _get_all_toolsets(child)
  return collected
async def _get_root_agent_async(app_name: str) -> Agent:
"""Returns the root agent for the given app."""
if app_name in root_agent_dict:
@@ -924,7 +852,6 @@ def get_fast_api_app(
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
root_agent_dict[app_name] = root_agent
toolsets_to_close.update(_get_all_toolsets(root_agent))
return root_agent
async def _get_runner_async(app_name: str) -> Runner:
+40
View File
@@ -0,0 +1,40 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from typing import List
from ...runners import Runner
# Module-level logger; its name is "google_adk." followed by this module's
# dotted import path (__name__).
logger = logging.getLogger("google_adk." + __name__)
async def close_runners(
    runners: List[Runner], *, timeout: float = 30.0
) -> None:
  """Close all given runners concurrently, bounded by a shared timeout.

  One task is created per runner to await its ``close()`` coroutine. The
  function never propagates a failure from an individual ``close()``: errors
  raised by finished tasks are logged, and tasks that are still running when
  the timeout expires are logged and cancelled.

  Args:
    runners: Runners to close. An empty list is a no-op.
    timeout: Maximum number of seconds to wait for all close tasks together.
      Defaults to 30 seconds.
  """
  cleanup_tasks = [asyncio.create_task(runner.close()) for runner in runners]
  if not cleanup_tasks:
    # asyncio.wait() rejects an empty collection, so bail out early.
    return

  done, pending = await asyncio.wait(
      cleanup_tasks,
      timeout=timeout,
      return_when=asyncio.ALL_COMPLETED,
  )

  # Retrieve the outcome of every finished task. Without this, a failing
  # close() is silently dropped and only surfaces (if at all) as asyncio's
  # "Task exception was never retrieved" message at garbage collection.
  for task in done:
    if task.cancelled():
      # Task.exception() raises CancelledError for cancelled tasks.
      continue
    error = task.exception()
    if error is not None:
      logger.error("Error closing runner: %s", error, exc_info=error)

  # Anything still pending exceeded the timeout: report and cancel it.
  if pending:
    logger.warning(
        "%s runner close tasks didn't complete in time", len(pending)
    )
    for task in pending:
      task.cancel()