fix:fix mcp toolset close issue

PiperOrigin-RevId: 759636772
This commit is contained in:
Xiang (Sean) Zhou
2025-05-16 09:05:18 -07:00
committed by Copybara-Service
parent 12507dc6cc
commit 05a853bc91
3 changed files with 294 additions and 23 deletions
+66 -8
View File
@@ -21,6 +21,7 @@ import json
import logging
import os
from pathlib import Path
import signal
import sys
import time
import traceback
@@ -221,7 +222,7 @@ def get_fast_api_app(
)
provider.add_span_processor(processor)
else:
logging.warning(
logger.warning(
"GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
" not be enabled."
)
@@ -232,14 +233,71 @@ def get_fast_api_app(
@asynccontextmanager
async def internal_lifespan(app: FastAPI):
if lifespan:
async with lifespan(app) as lifespan_context:
yield
# Set up signal handlers for graceful shutdown
original_sigterm = signal.getsignal(signal.SIGTERM)
original_sigint = signal.getsignal(signal.SIGINT)
for toolset in toolsets_to_close:
await toolset.close()
else:
yield
def cleanup_handler(sig, frame):
# Log the signal
logger.info("Received signal %s, performing pre-shutdown cleanup", sig)
# Do synchronous cleanup if needed
# Then call original handler if it exists
if sig == signal.SIGTERM and callable(original_sigterm):
original_sigterm(sig, frame)
elif sig == signal.SIGINT and callable(original_sigint):
original_sigint(sig, frame)
# Install cleanup handlers
signal.signal(signal.SIGTERM, cleanup_handler)
signal.signal(signal.SIGINT, cleanup_handler)
try:
if lifespan:
async with lifespan(app) as lifespan_context:
yield lifespan_context
else:
yield
finally:
# During shutdown, properly clean up all toolsets
logger.info(
"Server shutdown initiated, cleaning up %s toolsets",
len(toolsets_to_close),
)
# Create tasks for all toolset closures to run concurrently
cleanup_tasks = []
for toolset in toolsets_to_close:
task = asyncio.create_task(close_toolset_safely(toolset))
cleanup_tasks.append(task)
if cleanup_tasks:
# Wait for all cleanup tasks with timeout
done, pending = await asyncio.wait(
cleanup_tasks,
timeout=10.0, # 10 second timeout for cleanup
return_when=asyncio.ALL_COMPLETED,
)
# If any tasks are still pending, log it
if pending:
logger.warn(
f"{len(pending)} toolset cleanup tasks didn't complete in time"
)
for task in pending:
task.cancel()
# Restore original signal handlers
signal.signal(signal.SIGTERM, original_sigterm)
signal.signal(signal.SIGINT, original_sigint)
async def close_toolset_safely(toolset):
"""Safely close a toolset with error handling."""
try:
logger.info(f"Closing toolset: {type(toolset).__name__}")
await toolset.close()
logger.info(f"Successfully closed toolset: {type(toolset).__name__}")
except Exception as e:
logger.error(f"Error closing toolset {type(toolset).__name__}: {e}")
# Run the FastAPI server.
app = FastAPI(lifespan=internal_lifespan)