diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py
index 0b6d608..87ab4e8 100644
--- a/src/google/adk/cli/fast_api.py
+++ b/src/google/adk/cli/fast_api.py
@@ -21,6 +21,7 @@ import json
 import logging
 import os
 from pathlib import Path
+import signal
 import sys
 import time
 import traceback
@@ -221,7 +222,7 @@ def get_fast_api_app(
       )
       provider.add_span_processor(processor)
     else:
-      logging.warning(
+      logger.warning(
           "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
           " not be enabled."
       )
@@ -232,14 +233,71 @@ def get_fast_api_app(
 
   @asynccontextmanager
   async def internal_lifespan(app: FastAPI):
-    if lifespan:
-      async with lifespan(app) as lifespan_context:
-        yield
+    # Set up signal handlers for graceful shutdown
+    original_sigterm = signal.getsignal(signal.SIGTERM)
+    original_sigint = signal.getsignal(signal.SIGINT)
 
-        for toolset in toolsets_to_close:
-          await toolset.close()
-    else:
-      yield
+    def cleanup_handler(sig, frame):
+      # Log the signal
+      logger.info("Received signal %s, performing pre-shutdown cleanup", sig)
+      # Do synchronous cleanup if needed
+      # Then call original handler if it exists
+      if sig == signal.SIGTERM and callable(original_sigterm):
+        original_sigterm(sig, frame)
+      elif sig == signal.SIGINT and callable(original_sigint):
+        original_sigint(sig, frame)
+
+    # Install cleanup handlers
+    signal.signal(signal.SIGTERM, cleanup_handler)
+    signal.signal(signal.SIGINT, cleanup_handler)
+
+    try:
+      if lifespan:
+        async with lifespan(app) as lifespan_context:
+          yield lifespan_context
+      else:
+        yield
+    finally:
+      # During shutdown, properly clean up all toolsets
+      logger.info(
+          "Server shutdown initiated, cleaning up %s toolsets",
+          len(toolsets_to_close),
+      )
+
+      # Create tasks for all toolset closures to run concurrently
+      cleanup_tasks = []
+      for toolset in toolsets_to_close:
+        task = asyncio.create_task(close_toolset_safely(toolset))
+        cleanup_tasks.append(task)
+
+      if cleanup_tasks:
+        # Wait for all cleanup tasks with timeout
+        done, pending = await asyncio.wait(
+            cleanup_tasks,
+            timeout=10.0,  # 10 second timeout for cleanup
+            return_when=asyncio.ALL_COMPLETED,
+        )
+
+        # If any tasks are still pending, log it
+        if pending:
+          logger.warning(
+              f"{len(pending)} toolset cleanup tasks didn't complete in time"
+          )
+          for task in pending:
+            task.cancel()
+
+      # Restore original signal handlers
+      signal.signal(signal.SIGTERM, original_sigterm)
+      signal.signal(signal.SIGINT, original_sigint)
+
+  async def close_toolset_safely(toolset):
+    """Safely close a toolset with error handling."""
+    try:
+      logger.info(f"Closing toolset: {type(toolset).__name__}")
+      await toolset.close()
+      logger.info(f"Successfully closed toolset: {type(toolset).__name__}")
+    except Exception as e:
+      logger.error(f"Error closing toolset {type(toolset).__name__}: {e}")
 
   # Run the FastAPI server.
   app = FastAPI(lifespan=internal_lifespan)
diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py
index abc49da..2c6729c 100644
--- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py
+++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import AsyncExitStack
+import asyncio
+from contextlib import AsyncExitStack, asynccontextmanager
 import functools
+import logging
 import sys
-from typing import Any, TextIO
+from typing import Any, Optional, TextIO
 
 import anyio
 from pydantic import BaseModel
@@ -34,6 +36,8 @@ except ImportError as e:
   else:
     raise e
 
+logger = logging.getLogger(__name__)
+
 
 class SseServerParams(BaseModel):
   """Parameters for the MCP SSE connection.
@@ -108,6 +112,45 @@ def retry_on_closed_resource(async_reinit_func_name: str):
   return decorator
 
+@asynccontextmanager
+async def tracked_stdio_client(server, errlog, process=None):
+  """A wrapper around stdio_client that ensures proper process tracking and cleanup."""
+  our_process = process
+
+  # If no process was provided, create one
+  if our_process is None:
+    our_process = await asyncio.create_subprocess_exec(
+        server.command,
+        *server.args,
+        stdin=asyncio.subprocess.PIPE,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=errlog,
+    )
+
+  # Use the original stdio_client, but ensure process cleanup
+  try:
+    async with stdio_client(server=server, errlog=errlog) as client:
+      yield client, our_process
+  finally:
+    # Ensure the process is properly terminated if it still exists
+    if our_process and our_process.returncode is None:
+      try:
+        logger.info(
+            f'Terminating process {our_process.pid} from tracked_stdio_client'
+        )
+        our_process.terminate()
+        try:
+          await asyncio.wait_for(our_process.wait(), timeout=3.0)
+        except asyncio.TimeoutError:
+          # Force kill if it doesn't terminate quickly
+          if our_process.returncode is None:
+            logger.warning(f'Forcing kill of process {our_process.pid}')
+            our_process.kill()
+      except ProcessLookupError:
+        # Process already gone, that's fine
+        logger.info(f'Process {our_process.pid} already terminated')
+
+
 
 class MCPSessionManager:
   """Manages MCP client sessions.
@@ -138,25 +181,39 @@ class MCPSessionManager:
       errlog: (Optional) TextIO stream for error logging. Use only for
         initializing a local stdio MCP session.
     """
+    self._connection_params = connection_params
     self._exit_stack = exit_stack
     self._errlog = errlog
+    self._process = None  # Track the subprocess
+    self._active_processes = set()  # Track all processes created
+    self._active_file_handles = set()  # Track file handles
 
-  async def create_session(self) -> ClientSession:
-    return await MCPSessionManager.initialize_session(
+  async def create_session(
+      self,
+  ) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
+    """Creates a new MCP session and tracks the associated process."""
+    session, process = await self._initialize_session(
        connection_params=self._connection_params,
        exit_stack=self._exit_stack,
        errlog=self._errlog,
     )
+    self._process = process  # Store reference to process
+
+    # Track the process
+    if process:
+      self._active_processes.add(process)
+
+    return session, process
 
   @classmethod
-  async def initialize_session(
+  async def _initialize_session(
      cls,
      *,
      connection_params: StdioServerParameters | SseServerParams,
      exit_stack: AsyncExitStack,
      errlog: TextIO = sys.stderr,
-  ) -> ClientSession:
+  ) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
     """Initializes an MCP client session.
 
     Args:
@@ -168,9 +225,17 @@ class MCPSessionManager:
     Returns:
       ClientSession: The initialized MCP client session.
     """
+    process = None
+
     if isinstance(connection_params, StdioServerParameters):
-      client = stdio_client(server=connection_params, errlog=errlog)
+      # For stdio connections, we need to track the subprocess
+      client, process = await cls._create_stdio_client(
+          server=connection_params,
+          errlog=errlog,
+          exit_stack=exit_stack,
+      )
     elif isinstance(connection_params, SseServerParams):
+      # For SSE connections, create the client without a subprocess
       client = sse_client(
          url=connection_params.url,
          headers=connection_params.headers,
@@ -184,7 +249,74 @@
          f' {connection_params}'
      )
 
+    # Create the session with the client
     transports = await exit_stack.enter_async_context(client)
     session = await exit_stack.enter_async_context(ClientSession(*transports))
     await session.initialize()
-    return session
+
+    return session, process
+
+  @staticmethod
+  async def _create_stdio_client(
+      server: StdioServerParameters,
+      errlog: TextIO,
+      exit_stack: AsyncExitStack,
+  ) -> tuple[Any, asyncio.subprocess.Process]:
+    """Create stdio client and return both the client and process.
+
+    This implementation adapts to how the MCP stdio_client is created.
+    The actual implementation may need to be adjusted based on the MCP library
+    structure.
+    """
+    # Create the subprocess directly so we can track it
+    process = await asyncio.create_subprocess_exec(
+        server.command,
+        *server.args,
+        stdin=asyncio.subprocess.PIPE,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=errlog,
+    )
+
+    # Create the stdio client using the MCP library
+    try:
+      # Method 1: Try using the existing process if stdio_client supports it
+      client = stdio_client(server=server, errlog=errlog, process=process)
+    except TypeError:
+      # Method 2: If the above doesn't work, let stdio_client create its own process
+      # and we'll need to terminate both processes later
+      logger.warning(
+          'Using stdio_client with its own process - may lead to duplicate'
+          ' processes'
+      )
+      client = stdio_client(server=server, errlog=errlog)
+
+    return client, process
+
+  async def _emergency_cleanup(self):
+    """Perform emergency cleanup of resources when normal cleanup fails."""
+    logger.info('Performing emergency cleanup of MCPSessionManager resources')
+
+    # Clean up any tracked processes
+    for proc in list(self._active_processes):
+      try:
+        if proc and proc.returncode is None:
+          logger.info(f'Emergency termination of process {proc.pid}')
+          proc.terminate()
+          try:
+            await asyncio.wait_for(proc.wait(), timeout=1.0)
+          except asyncio.TimeoutError:
+            logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
+            proc.kill()
+        self._active_processes.remove(proc)
+      except Exception as e:
+        logger.error(f'Error during process cleanup: {e}')
+
+    # Clean up any tracked file handles
+    for handle in list(self._active_file_handles):
+      try:
+        if not handle.closed:
+          logger.info('Closing file handle')
+          handle.close()
+        self._active_file_handles.remove(handle)
+      except Exception as e:
+        logger.error(f'Error closing file handle: {e}')
diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py
index e4793cc..e93fbdf 100644
--- a/src/google/adk/tools/mcp_tool/mcp_toolset.py
+++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py
@@ -12,7 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 from contextlib import AsyncExitStack
+import logging
+import os
+import signal
 import sys
 from typing import List, Union
 from typing import Optional
@@ -39,14 +43,16 @@
   if sys.version_info < (3, 10):
     raise ImportError(
-        'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
-        ' version.'
+        "MCP Tool requires Python 3.10 or above. Please upgrade your Python"
+        " version."
     ) from e
   else:
     raise e
 
 from .mcp_tool import MCPTool
 
+logger = logging.getLogger(__name__)
+
 
 class MCPToolset(BaseToolset):
   """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
 
@@ -82,10 +88,12 @@
     """
     if not connection_params:
-      raise ValueError('Missing connection params in MCPToolset.')
+      raise ValueError("Missing connection params in MCPToolset.")
     self._connection_params = connection_params
     self._errlog = errlog
     self._exit_stack = AsyncExitStack()
+    self._creator_task_id = None
+    self._process_pid = None  # Store the subprocess PID
 
     self._session_manager = MCPSessionManager(
         connection_params=self._connection_params,
@@ -94,10 +102,17 @@
     )
     self._session = None
     self.tool_filter = tool_filter
+    self._initialized = False
 
   async def _initialize(self) -> ClientSession:
     """Connects to the MCP Server and initializes the ClientSession."""
-    self._session = await self._session_manager.create_session()
+    # Store the current task ID when initializing
+    self._creator_task_id = id(asyncio.current_task())
+    self._session, process = await self._session_manager.create_session()
+    # Store the process PID if available
+    if process and hasattr(process, "pid"):
+      self._process_pid = process.pid
+    self._initialized = True
     return self._session
 
   def _is_selected(
@@ -114,10 +129,76 @@
 
   @override
   async def close(self):
-    """Closes the connection to MCP Server."""
-    await self._exit_stack.aclose()
+    """Safely closes the connection to MCP Server with guaranteed resource cleanup."""
+    if not self._initialized:
+      return  # Nothing to close
 
-  @retry_on_closed_resource('_initialize')
+    logger.info("Closing MCP Toolset")
+
+    # Step 1: Try graceful shutdown of the session if it exists
+    if self._session:
+      try:
+        logger.info("Attempting graceful session shutdown")
+        await self._session.shutdown()
+      except Exception as e:
+        logger.warning(f"Session shutdown error (continuing cleanup): {e}")
+
+    # Step 2: Try to close the exit stack
+    try:
+      logger.info("Closing AsyncExitStack")
+      await self._exit_stack.aclose()
+      # If we get here, the exit stack closed successfully
+      logger.info("AsyncExitStack closed successfully")
+      return
+    except RuntimeError as e:
+      if "Attempted to exit cancel scope in a different task" in str(e):
+        logger.warning("Task mismatch during shutdown - using fallback cleanup")
+        # Continue to manual cleanup
+      else:
+        logger.error(f"Unexpected RuntimeError: {e}")
+        # Continue to manual cleanup
+    except Exception as e:
+      logger.error(f"Error during exit stack closure: {e}")
+      # Continue to manual cleanup
+
+    # Step 3: Manual cleanup of the subprocess if we have its PID
+    if self._process_pid:
+      await self._ensure_process_terminated(self._process_pid)
+
+    # Step 4: Ask the session manager to do any additional cleanup it can
+    await self._session_manager._emergency_cleanup()
+
+  async def _ensure_process_terminated(self, pid):
+    """Ensure a process is terminated using its PID."""
+    try:
+      # Check if process exists
+      os.kill(pid, 0)  # This just checks if the process exists
+
+      logger.info(f"Terminating process with PID {pid}")
+      # First try SIGTERM for graceful shutdown
+      os.kill(pid, signal.SIGTERM)
+
+      # Give it a moment to terminate
+      for _ in range(30):  # wait up to 3 seconds
+        await asyncio.sleep(0.1)
+        try:
+          os.kill(pid, 0)  # Process still exists
+        except ProcessLookupError:
+          logger.info(f"Process {pid} terminated successfully")
+          return
+
+      # If we get here, process didn't terminate gracefully
+      logger.warning(
+          f"Process {pid} didn't terminate gracefully, using SIGKILL"
+      )
+      os.kill(pid, signal.SIGKILL)
+
+    except ProcessLookupError:
+      logger.info(f"Process {pid} already terminated")
+    except Exception as e:
+      logger.error(f"Error terminating process {pid}: {e}")
+
+  @retry_on_closed_resource("_initialize")
   @override
   async def get_tools(
       self,
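
Note on the shutdown pattern used in the internal_lifespan change above: all registered toolsets are closed concurrently, the whole batch shares a single time budget via asyncio.wait, and anything that misses the deadline is cancelled. The standalone sketch below illustrates that pattern in isolation; ToolsetStub, the helper names, and the timeout values are illustrative assumptions for the example, not part of the ADK or FastAPI APIs.

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ToolsetStub:
  """Hypothetical stand-in for an ADK toolset exposing an async close()."""

  def __init__(self, name: str, close_delay: float):
    self.name = name
    self.close_delay = close_delay

  async def close(self):
    # Simulate the (possibly slow) teardown an MCP-backed toolset performs.
    await asyncio.sleep(self.close_delay)


async def close_toolset_safely(toolset: ToolsetStub) -> None:
  # Mirror the wrapper in the diff: one failing close() must not abort the rest.
  try:
    await toolset.close()
    logger.info("Closed %s", toolset.name)
  except Exception as e:
    logger.error("Error closing %s: %s", toolset.name, e)


async def shutdown_toolsets(toolsets, timeout: float = 10.0) -> None:
  # Close everything concurrently and bound the whole batch with one deadline.
  tasks = [asyncio.create_task(close_toolset_safely(t)) for t in toolsets]
  done, pending = await asyncio.wait(tasks, timeout=timeout)
  for task in pending:
    task.cancel()
  # Consume the CancelledError of anything we had to cut off.
  await asyncio.gather(*pending, return_exceptions=True)
  logger.info("%d toolsets closed, %d cancelled", len(done), len(pending))


if __name__ == "__main__":
  stubs = [ToolsetStub("fast", 0.1), ToolsetStub("slow", 30.0)]
  asyncio.run(shutdown_toolsets(stubs, timeout=1.0))

Awaiting the cancelled tasks with gather(..., return_exceptions=True) keeps the cancellation from surfacing as an unhandled exception; the lifespan code above instead relies on the event loop shutting down immediately after the timeout.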