fix: fix mcp toolset close issue

PiperOrigin-RevId: 759636772
Xiang (Sean) Zhou 2025-05-16 09:05:18 -07:00 committed by Copybara-Service
parent 12507dc6cc
commit 05a853bc91
3 changed files with 294 additions and 23 deletions


@@ -21,6 +21,7 @@ import json
import logging
import os
from pathlib import Path
import signal
import sys
import time
import traceback
@@ -221,7 +222,7 @@ def get_fast_api_app(
)
provider.add_span_processor(processor)
else:
logging.warning(
logger.warning(
"GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
" not be enabled."
)
@@ -232,14 +233,71 @@ def get_fast_api_app(
@asynccontextmanager
async def internal_lifespan(app: FastAPI):
if lifespan:
async with lifespan(app) as lifespan_context:
yield
# Set up signal handlers for graceful shutdown
original_sigterm = signal.getsignal(signal.SIGTERM)
original_sigint = signal.getsignal(signal.SIGINT)
for toolset in toolsets_to_close:
await toolset.close()
else:
yield
def cleanup_handler(sig, frame):
# Log the signal
logger.info("Received signal %s, performing pre-shutdown cleanup", sig)
# Do synchronous cleanup if needed
# Then call original handler if it exists
if sig == signal.SIGTERM and callable(original_sigterm):
original_sigterm(sig, frame)
elif sig == signal.SIGINT and callable(original_sigint):
original_sigint(sig, frame)
# Install cleanup handlers
signal.signal(signal.SIGTERM, cleanup_handler)
signal.signal(signal.SIGINT, cleanup_handler)
try:
if lifespan:
async with lifespan(app) as lifespan_context:
yield lifespan_context
else:
yield
finally:
# During shutdown, properly clean up all toolsets
logger.info(
"Server shutdown initiated, cleaning up %s toolsets",
len(toolsets_to_close),
)
# Create tasks for all toolset closures to run concurrently
cleanup_tasks = []
for toolset in toolsets_to_close:
task = asyncio.create_task(close_toolset_safely(toolset))
cleanup_tasks.append(task)
if cleanup_tasks:
# Wait for all cleanup tasks with timeout
done, pending = await asyncio.wait(
cleanup_tasks,
timeout=10.0, # 10 second timeout for cleanup
return_when=asyncio.ALL_COMPLETED,
)
# If any tasks are still pending, log it
if pending:
logger.warning(
f"{len(pending)} toolset cleanup tasks didn't complete in time"
)
for task in pending:
task.cancel()
# Restore original signal handlers
signal.signal(signal.SIGTERM, original_sigterm)
signal.signal(signal.SIGINT, original_sigint)
async def close_toolset_safely(toolset):
"""Safely close a toolset with error handling."""
try:
logger.info(f"Closing toolset: {type(toolset).__name__}")
await toolset.close()
logger.info(f"Successfully closed toolset: {type(toolset).__name__}")
except Exception as e:
logger.error(f"Error closing toolset {type(toolset).__name__}: {e}")
# Run the FastAPI server.
app = FastAPI(lifespan=internal_lifespan)
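
For orientation, a minimal sketch (not part of this commit) of how a caller-supplied lifespan composes with the internal_lifespan above, assuming lifespan is a parameter of get_fast_api_app as the hunk suggests; my_lifespan and the commented call site are illustrative placeholders.

from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def my_lifespan(app: FastAPI):
  # Caller-owned startup work (e.g. opening a connection pool) runs here.
  yield
  # Caller-owned shutdown work runs next; the wrapping internal_lifespan then
  # closes every toolset in toolsets_to_close (with a 10 second timeout) and
  # restores the original SIGTERM/SIGINT handlers.


# Hypothetical call site; other get_fast_api_app arguments omitted.
# app = get_fast_api_app(..., lifespan=my_lifespan)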


@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import AsyncExitStack
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager
import functools
import logging
import sys
from typing import Any, TextIO
from typing import Any, Optional, TextIO
import anyio
from pydantic import BaseModel
@@ -34,6 +36,8 @@ except ImportError as e:
else:
raise e
logger = logging.getLogger(__name__)
class SseServerParams(BaseModel):
"""Parameters for the MCP SSE connection.
@@ -108,6 +112,45 @@ def retry_on_closed_resource(async_reinit_func_name: str):
return decorator
@asynccontextmanager
async def tracked_stdio_client(server, errlog, process=None):
"""A wrapper around stdio_client that ensures proper process tracking and cleanup."""
our_process = process
# If no process was provided, create one
if our_process is None:
our_process = await asyncio.create_subprocess_exec(
server.command,
*server.args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=errlog,
)
# Use the original stdio_client, but ensure process cleanup
try:
async with stdio_client(server=server, errlog=errlog) as client:
yield client, our_process
finally:
# Ensure the process is properly terminated if it still exists
if our_process and our_process.returncode is None:
try:
logger.info(
f'Terminating process {our_process.pid} from tracked_stdio_client'
)
our_process.terminate()
try:
await asyncio.wait_for(our_process.wait(), timeout=3.0)
except asyncio.TimeoutError:
# Force kill if it doesn't terminate quickly
if our_process.returncode is None:
logger.warning(f'Forcing kill of process {our_process.pid}')
our_process.kill()
except ProcessLookupError:
# Process already gone, that's fine
logger.info(f'Process {our_process.pid} already terminated')
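
A hedged usage sketch for the tracked_stdio_client wrapper added above, assuming the standard mcp imports used elsewhere in this file; the server command and args are placeholders. Note that when no process is supplied, the wrapper spawns its own subprocess for tracking (in addition to the one stdio_client manages) and terminates whatever is still running on exit.

import sys

from mcp import ClientSession, StdioServerParameters


async def run_once() -> None:
  server = StdioServerParameters(command='npx', args=['some-mcp-server'])
  async with tracked_stdio_client(server, sys.stderr) as (streams, process):
    # stdio_client yields a (read_stream, write_stream) pair.
    read_stream, write_stream = streams
    async with ClientSession(read_stream, write_stream) as session:
      await session.initialize()
      print(await session.list_tools())
  # On exit, the wrapper terminates the tracked process if it is still alive.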
class MCPSessionManager:
"""Manages MCP client sessions.
@@ -138,25 +181,39 @@ class MCPSessionManager:
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
"""
self._connection_params = connection_params
self._exit_stack = exit_stack
self._errlog = errlog
self._process = None # Track the subprocess
self._active_processes = set() # Track all processes created
self._active_file_handles = set() # Track file handles
async def create_session(self) -> ClientSession:
return await MCPSessionManager.initialize_session(
async def create_session(
self,
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
"""Creates a new MCP session and tracks the associated process."""
session, process = await self._initialize_session(
connection_params=self._connection_params,
exit_stack=self._exit_stack,
errlog=self._errlog,
)
self._process = process # Store reference to process
# Track the process
if process:
self._active_processes.add(process)
return session, process
@classmethod
async def initialize_session(
async def _initialize_session(
cls,
*,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
) -> ClientSession:
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
"""Initializes an MCP client session.
Args:
@@ -168,9 +225,17 @@ class MCPSessionManager:
Returns:
tuple: The initialized ClientSession and the tracked subprocess (None for SSE connections).
"""
process = None
if isinstance(connection_params, StdioServerParameters):
client = stdio_client(server=connection_params, errlog=errlog)
# For stdio connections, we need to track the subprocess
client, process = await cls._create_stdio_client(
server=connection_params,
errlog=errlog,
exit_stack=exit_stack,
)
elif isinstance(connection_params, SseServerParams):
# For SSE connections, create the client without a subprocess
client = sse_client(
url=connection_params.url,
headers=connection_params.headers,
@@ -184,7 +249,74 @@ class MCPSessionManager:
f' {connection_params}'
)
# Create the session with the client
transports = await exit_stack.enter_async_context(client)
session = await exit_stack.enter_async_context(ClientSession(*transports))
await session.initialize()
return session
return session, process
@staticmethod
async def _create_stdio_client(
server: StdioServerParameters,
errlog: TextIO,
exit_stack: AsyncExitStack,
) -> tuple[Any, asyncio.subprocess.Process]:
"""Create stdio client and return both the client and process.
This implementation adapts to how the MCP stdio_client is created.
The actual implementation may need to be adjusted based on the MCP library
structure.
"""
# Create the subprocess directly so we can track it
process = await asyncio.create_subprocess_exec(
server.command,
*server.args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=errlog,
)
# Create the stdio client using the MCP library
try:
# Method 1: Try using the existing process if stdio_client supports it
client = stdio_client(server=server, errlog=errlog, process=process)
except TypeError:
# Method 2: If the above doesn't work, let stdio_client create its own process
# and we'll need to terminate both processes later
logger.warning(
'Using stdio_client with its own process - may lead to duplicate'
' processes'
)
client = stdio_client(server=server, errlog=errlog)
return client, process
async def _emergency_cleanup(self):
"""Perform emergency cleanup of resources when normal cleanup fails."""
logger.info('Performing emergency cleanup of MCPSessionManager resources')
# Clean up any tracked processes
for proc in list(self._active_processes):
try:
if proc and proc.returncode is None:
logger.info(f'Emergency termination of process {proc.pid}')
proc.terminate()
try:
await asyncio.wait_for(proc.wait(), timeout=1.0)
except asyncio.TimeoutError:
logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
proc.kill()
self._active_processes.remove(proc)
except Exception as e:
logger.error(f'Error during process cleanup: {e}')
# Clean up any tracked file handles
for handle in list(self._active_file_handles):
try:
if not handle.closed:
logger.info('Closing file handle')
handle.close()
self._active_file_handles.remove(handle)
except Exception as e:
logger.error(f'Error closing file handle: {e}')
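
To illustrate the new two-value contract of create_session, a hedged sketch of driving MCPSessionManager directly; the constructor keywords mirror the attributes set above, and list_tools_once is an illustrative helper, not part of this commit.

import sys
from contextlib import AsyncExitStack

from mcp import StdioServerParameters


async def list_tools_once(params: StdioServerParameters):
  stack = AsyncExitStack()
  manager = MCPSessionManager(
      connection_params=params,
      exit_stack=stack,
      errlog=sys.stderr,
  )
  try:
    # create_session now returns the session plus the tracked subprocess
    # (None for SSE connections).
    session, process = await manager.create_session()
    return await session.list_tools()
  finally:
    await stack.aclose()  # unwinds the session/transport contexts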


@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import AsyncExitStack
import logging
import os
import signal
import sys
from typing import List, Union
from typing import Optional
@@ -39,14 +43,16 @@ except ImportError as e:
if sys.version_info < (3, 10):
raise ImportError(
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
' version.'
"MCP Tool requires Python 3.10 or above. Please upgrade your Python"
" version."
) from e
else:
raise e
from .mcp_tool import MCPTool
logger = logging.getLogger(__name__)
class MCPToolset(BaseToolset):
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
@@ -82,10 +88,12 @@ class MCPToolset(BaseToolset):
"""
if not connection_params:
raise ValueError('Missing connection params in MCPToolset.')
raise ValueError("Missing connection params in MCPToolset.")
self._connection_params = connection_params
self._errlog = errlog
self._exit_stack = AsyncExitStack()
self._creator_task_id = None
self._process_pid = None # Store the subprocess PID
self._session_manager = MCPSessionManager(
connection_params=self._connection_params,
@@ -94,10 +102,17 @@
)
self._session = None
self.tool_filter = tool_filter
self._initialized = False
async def _initialize(self) -> ClientSession:
"""Connects to the MCP Server and initializes the ClientSession."""
self._session = await self._session_manager.create_session()
# Store the current task ID when initializing
self._creator_task_id = id(asyncio.current_task())
self._session, process = await self._session_manager.create_session()
# Store the process PID if available
if process and hasattr(process, "pid"):
self._process_pid = process.pid
self._initialized = True
return self._session
def _is_selected(
@@ -114,10 +129,76 @@
@override
async def close(self):
"""Closes the connection to MCP Server."""
await self._exit_stack.aclose()
"""Safely closes the connection to MCP Server with guaranteed resource cleanup."""
if not self._initialized:
return # Nothing to close
@retry_on_closed_resource('_initialize')
logger.info("Closing MCP Toolset")
# Step 1: Try graceful shutdown of the session if it exists
if self._session:
try:
logger.info("Attempting graceful session shutdown")
await self._session.shutdown()
except Exception as e:
logger.warning(f"Session shutdown error (continuing cleanup): {e}")
# Step 2: Try to close the exit stack
try:
logger.info("Closing AsyncExitStack")
await self._exit_stack.aclose()
# If we get here, the exit stack closed successfully
logger.info("AsyncExitStack closed successfully")
return
except RuntimeError as e:
if "Attempted to exit cancel scope in a different task" in str(e):
logger.warning("Task mismatch during shutdown - using fallback cleanup")
# Continue to manual cleanup
else:
logger.error(f"Unexpected RuntimeError: {e}")
# Continue to manual cleanup
except Exception as e:
logger.error(f"Error during exit stack closure: {e}")
# Continue to manual cleanup
# Step 3: Manual cleanup of the subprocess if we have its PID
if self._process_pid:
await self._ensure_process_terminated(self._process_pid)
# Step 4: Ask the session manager to do any additional cleanup it can
await self._session_manager._emergency_cleanup()
async def _ensure_process_terminated(self, pid):
"""Ensure a process is terminated using its PID."""
try:
# Check if process exists
os.kill(pid, 0) # This just checks if the process exists
logger.info(f"Terminating process with PID {pid}")
# First try SIGTERM for graceful shutdown
os.kill(pid, signal.SIGTERM)
# Give it a moment to terminate
for _ in range(30): # wait up to 3 seconds
await asyncio.sleep(0.1)
try:
os.kill(pid, 0) # Process still exists
except ProcessLookupError:
logger.info(f"Process {pid} terminated successfully")
return
# If we get here, process didn't terminate gracefully
logger.warning(
f"Process {pid} didn't terminate gracefully, using SIGKILL"
)
os.kill(pid, signal.SIGKILL)
except ProcessLookupError:
logger.info(f"Process {pid} already terminated")
except Exception as e:
logger.error(f"Error terminating process {pid}: {e}")
@retry_on_closed_resource("_initialize")
@override
async def get_tools(
self,