diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 51cbf25..8effed6 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -16,12 +16,9 @@ import asyncio from contextlib import asynccontextmanager import importlib -import inspect -import json import logging import os from pathlib import Path -import signal import sys import time import traceback @@ -55,11 +52,9 @@ from starlette.types import Lifespan from typing_extensions import override from ..agents import RunConfig -from ..agents.base_agent import BaseAgent from ..agents.live_request_queue import LiveRequest from ..agents.live_request_queue import LiveRequestQueue from ..agents.llm_agent import Agent -from ..agents.llm_agent import LlmAgent from ..agents.run_config import StreamingMode from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..evaluation.eval_case import EvalCase @@ -75,12 +70,12 @@ from ..sessions.session import Session from ..sessions.vertex_ai_session_service import VertexAiSessionService from ..tools.base_toolset import BaseToolset from .cli_eval import EVAL_SESSION_ID_PREFIX -from .cli_eval import EvalCaseResult from .cli_eval import EvalMetric from .cli_eval import EvalMetricResult from .cli_eval import EvalMetricResultPerInvocation from .cli_eval import EvalSetResult from .cli_eval import EvalStatus +from .utils import cleanup from .utils import common from .utils import create_empty_state from .utils import envs @@ -230,27 +225,8 @@ def get_fast_api_app( trace.set_tracer_provider(provider) - toolsets_to_close: set[BaseToolset] = set() - @asynccontextmanager async def internal_lifespan(app: FastAPI): - # Set up signal handlers for graceful shutdown - original_sigterm = signal.getsignal(signal.SIGTERM) - original_sigint = signal.getsignal(signal.SIGINT) - - def cleanup_handler(sig, frame): - # Log the signal - logger.info("Received signal %s, performing pre-shutdown cleanup", sig) - # Do synchronous cleanup if needed - # Then call original handler if it exists - if sig == signal.SIGTERM and callable(original_sigterm): - original_sigterm(sig, frame) - elif sig == signal.SIGINT and callable(original_sigint): - original_sigint(sig, frame) - - # Install cleanup handlers - signal.signal(signal.SIGTERM, cleanup_handler) - signal.signal(signal.SIGINT, cleanup_handler) try: if lifespan: @@ -259,46 +235,8 @@ def get_fast_api_app( else: yield finally: - # During shutdown, properly clean up all toolsets - logger.info( - "Server shutdown initiated, cleaning up %s toolsets", - len(toolsets_to_close), - ) - - # Create tasks for all toolset closures to run concurrently - cleanup_tasks = [] - for toolset in toolsets_to_close: - task = asyncio.create_task(close_toolset_safely(toolset)) - cleanup_tasks.append(task) - - if cleanup_tasks: - # Wait for all cleanup tasks with timeout - done, pending = await asyncio.wait( - cleanup_tasks, - timeout=10.0, # 10 second timeout for cleanup - return_when=asyncio.ALL_COMPLETED, - ) - - # If any tasks are still pending, log it - if pending: - logger.warning( - f"{len(pending)} toolset cleanup tasks didn't complete in time" - ) - for task in pending: - task.cancel() - - # Restore original signal handlers - signal.signal(signal.SIGTERM, original_sigterm) - signal.signal(signal.SIGINT, original_sigint) - - async def close_toolset_safely(toolset): - """Safely close a toolset with error handling.""" - try: - logger.info(f"Closing toolset: {type(toolset).__name__}") - await toolset.close() - logger.info(f"Successfully closed toolset: {type(toolset).__name__}") - except Exception as e: - logger.error(f"Error closing toolset {type(toolset).__name__}: {e}") + # Create tasks for all runner closures to run concurrently + await cleanup.close_runners(list(runner_dict.values())) # Run the FastAPI server. app = FastAPI(lifespan=internal_lifespan) @@ -903,16 +841,6 @@ def get_fast_api_app( for task in pending: task.cancel() - def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]: - toolsets = set() - if isinstance(agent, LlmAgent): - for tool_union in agent.tools: - if isinstance(tool_union, BaseToolset): - toolsets.add(tool_union) - for sub_agent in agent.sub_agents: - toolsets.update(_get_all_toolsets(sub_agent)) - return toolsets - async def _get_root_agent_async(app_name: str) -> Agent: """Returns the root agent for the given app.""" if app_name in root_agent_dict: @@ -924,7 +852,6 @@ def get_fast_api_app( raise ValueError(f'Unable to find "root_agent" from {app_name}.') root_agent_dict[app_name] = root_agent - toolsets_to_close.update(_get_all_toolsets(root_agent)) return root_agent async def _get_runner_async(app_name: str) -> Runner: diff --git a/src/google/adk/cli/utils/cleanup.py b/src/google/adk/cli/utils/cleanup.py new file mode 100644 index 0000000..137c52c --- /dev/null +++ b/src/google/adk/cli/utils/cleanup.py @@ -0,0 +1,40 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +from typing import List + +from ...runners import Runner + +logger = logging.getLogger("google_adk." + __name__) + + +async def close_runners(runners: List[Runner]) -> None: + cleanup_tasks = [asyncio.create_task(runner.close()) for runner in runners] + if cleanup_tasks: + # Wait for all cleanup tasks with timeout + done, pending = await asyncio.wait( + cleanup_tasks, + timeout=30.0, # 30 second timeout for cleanup + return_when=asyncio.ALL_COMPLETED, + ) + + # If any tasks are still pending, log it + if pending: + logger.warning( + "%s runner close tasks didn't complete in time", len(pending) + ) + for task in pending: + task.cancel() diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 2ebe1e8..220c5d2 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -42,6 +42,7 @@ from .sessions.base_session_service import BaseSessionService from .sessions.in_memory_session_service import InMemorySessionService from .sessions.session import Session from .telemetry import tracer +from .tools.base_toolset import BaseToolset logger = logging.getLogger('google_adk.' + __name__) @@ -457,6 +458,37 @@ class Runner: run_config=run_config, ) + def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]: + toolsets = set() + if isinstance(agent, LlmAgent): + for tool_union in agent.tools: + if isinstance(tool_union, BaseToolset): + toolsets.add(tool_union) + for sub_agent in agent.sub_agents: + toolsets.update(self._collect_toolset(sub_agent)) + return toolsets + + async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]): + """Clean up toolsets with proper task context management.""" + if not toolsets_to_close: + return + + # This maintains the same task context throughout cleanup + for toolset in toolsets_to_close: + try: + logger.info('Closing toolset: %s', type(toolset).__name__) + # Use asyncio.wait_for to add timeout protection + await asyncio.wait_for(toolset.close(), timeout=10.0) + logger.info('Successfully closed toolset: %s', type(toolset).__name__) + except asyncio.TimeoutError: + logger.warning('Toolset %s cleanup timed out', type(toolset).__name__) + except Exception as e: + logger.error('Error closing toolset %s: %s', type(toolset).__name__, e) + + async def close(self): + """Closes the runner.""" + await self._cleanup_toolsets(self._collect_toolset(self.agent)) + class InMemoryRunner(Runner): """An in-memory Runner for testing and development. diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 6cbae96..cb14a90 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -from contextlib import asynccontextmanager + from contextlib import AsyncExitStack import functools import logging @@ -71,29 +70,27 @@ def retry_on_closed_resource(async_reinit_func_name: str): Usage: class MCPTool: - ... - async def create_session(self): - self.session = ... + ... + async def create_session(self): + self.session = ... - @retry_on_closed_resource('create_session') - async def use_session(self): - await self.session.call_tool() + @retry_on_closed_resource('create_session') + async def use_session(self): + await self.session.call_tool() Args: - async_reinit_func_name: The name of the async function to recreate session. + async_reinit_func_name: The name of the async function to recreate session. Returns: - The decorated function. + The decorated function. """ def decorator(func): - @functools.wraps( - func - ) # Preserves original function metadata (name, docstring) + @functools.wraps(func) # Preserves original function metadata async def wrapper(self, *args, **kwargs): try: return await func(self, *args, **kwargs) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError as close_err: try: if hasattr(self, async_reinit_func_name) and callable( getattr(self, async_reinit_func_name) @@ -105,7 +102,7 @@ def retry_on_closed_resource(async_reinit_func_name: str): f'Function {async_reinit_func_name} does not exist in decorated' ' class. Please check the function name in' ' retry_on_closed_resource decorator.' - ) + ) from close_err except Exception as reinit_err: raise RuntimeError( f'Error reinitializing: {reinit_err}' @@ -117,45 +114,6 @@ def retry_on_closed_resource(async_reinit_func_name: str): return decorator -@asynccontextmanager -async def tracked_stdio_client(server, errlog, process=None): - """A wrapper around stdio_client that ensures proper process tracking and cleanup.""" - our_process = process - - # If no process was provided, create one - if our_process is None: - our_process = await asyncio.create_subprocess_exec( - server.command, - *server.args, - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=errlog, - ) - - # Use the original stdio_client, but ensure process cleanup - try: - async with stdio_client(server=server, errlog=errlog) as client: - yield client, our_process - finally: - # Ensure the process is properly terminated if it still exists - if our_process and our_process.returncode is None: - try: - logger.info( - f'Terminating process {our_process.pid} from tracked_stdio_client' - ) - our_process.terminate() - try: - await asyncio.wait_for(our_process.wait(), timeout=3.0) - except asyncio.TimeoutError: - # Force kill if it doesn't terminate quickly - if our_process.returncode is None: - logger.warning(f'Forcing kill of process {our_process.pid}') - our_process.kill() - except ProcessLookupError: - # Process already gone, that's fine - logger.info(f'Process {our_process.pid} already terminated') - - class MCPSessionManager: """Manages MCP client sessions. @@ -166,162 +124,78 @@ class MCPSessionManager: def __init__( self, connection_params: StdioServerParameters | SseServerParams, - exit_stack: AsyncExitStack, errlog: TextIO = sys.stderr, ): """Initializes the MCP session manager. - Example usage: - ``` - mcp_session_manager = MCPSessionManager( - connection_params=connection_params, - exit_stack=exit_stack, - ) - session = await mcp_session_manager.create_session() - ``` - Args: connection_params: Parameters for the MCP connection (Stdio or SSE). - exit_stack: AsyncExitStack to manage the session lifecycle. errlog: (Optional) TextIO stream for error logging. Use only for initializing a local stdio MCP session. """ - self._connection_params = connection_params - self._exit_stack = exit_stack self._errlog = errlog - self._process = None # Track the subprocess - self._active_processes = set() # Track all processes created - self._active_file_handles = set() # Track file handles + # Each session manager maintains its own exit stack for proper cleanup + self._exit_stack: Optional[AsyncExitStack] = None + self._session: Optional[ClientSession] = None - async def create_session( - self, - ) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]: - """Creates a new MCP session and tracks the associated process.""" - session, process = await self._initialize_session( - connection_params=self._connection_params, - exit_stack=self._exit_stack, - errlog=self._errlog, - ) - self._process = process # Store reference to process - - # Track the process - if process: - self._active_processes.add(process) - - return session, process - - @classmethod - async def _initialize_session( - cls, - *, - connection_params: StdioServerParameters | SseServerParams, - exit_stack: AsyncExitStack, - errlog: TextIO = sys.stderr, - ) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]: - """Initializes an MCP client session. - - Args: - connection_params: Parameters for the MCP connection (Stdio or SSE). - exit_stack: AsyncExitStack to manage the session lifecycle. - errlog: (Optional) TextIO stream for error logging. Use only for - initializing a local stdio MCP session. + async def create_session(self) -> ClientSession: + """Creates and initializes an MCP client session. Returns: ClientSession: The initialized MCP client session. """ - process = None + if self._session is not None: + return self._session - if isinstance(connection_params, StdioServerParameters): - # For stdio connections, we need to track the subprocess - client, process = await cls._create_stdio_client( - server=connection_params, - errlog=errlog, - exit_stack=exit_stack, - ) - elif isinstance(connection_params, SseServerParams): - # For SSE connections, create the client without a subprocess - client = sse_client( - url=connection_params.url, - headers=connection_params.headers, - timeout=connection_params.timeout, - sse_read_timeout=connection_params.sse_read_timeout, - ) - else: - raise ValueError( - 'Unable to initialize connection. Connection should be' - ' StdioServerParameters or SseServerParams, but got' - f' {connection_params}' - ) + # Create a new exit stack for this session + self._exit_stack = AsyncExitStack() - # Create the session with the client - transports = await exit_stack.enter_async_context(client) - session = await exit_stack.enter_async_context(ClientSession(*transports)) - await session.initialize() - - return session, process - - @staticmethod - async def _create_stdio_client( - server: StdioServerParameters, - errlog: TextIO, - exit_stack: AsyncExitStack, - ) -> tuple[Any, asyncio.subprocess.Process]: - """Create stdio client and return both the client and process. - - This implementation adapts to how the MCP stdio_client is created. - The actual implementation may need to be adjusted based on the MCP library - structure. - """ - # Create the subprocess directly so we can track it - process = await asyncio.create_subprocess_exec( - server.command, - *server.args, - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=errlog, - ) - - # Create the stdio client using the MCP library try: - # Method 1: Try using the existing process if stdio_client supports it - client = stdio_client(server=server, errlog=errlog, process=process) - except TypeError: - # Method 2: If the above doesn't work, let stdio_client create its own process - # and we'll need to terminate both processes later - logger.warning( - 'Using stdio_client with its own process - may lead to duplicate' - ' processes' + if isinstance(self._connection_params, StdioServerParameters): + client = stdio_client( + server=self._connection_params, errlog=self._errlog + ) + elif isinstance(self._connection_params, SseServerParams): + client = sse_client( + url=self._connection_params.url, + headers=self._connection_params.headers, + timeout=self._connection_params.timeout, + sse_read_timeout=self._connection_params.sse_read_timeout, + ) + else: + raise ValueError( + 'Unable to initialize connection. Connection should be' + ' StdioServerParameters or SseServerParams, but got' + f' {self._connection_params}' + ) + + transports = await self._exit_stack.enter_async_context(client) + session = await self._exit_stack.enter_async_context( + ClientSession(*transports) ) - client = stdio_client(server=server, errlog=errlog) + await session.initialize() - return client, process + self._session = session + return session - async def _emergency_cleanup(self): - """Perform emergency cleanup of resources when normal cleanup fails.""" - logger.info('Performing emergency cleanup of MCPSessionManager resources') + except Exception: + # If session creation fails, clean up the exit stack + if self._exit_stack: + await self._exit_stack.aclose() + self._exit_stack = None + raise - # Clean up any tracked processes - for proc in list(self._active_processes): + async def close(self): + """Closes the session and cleans up resources.""" + if self._exit_stack: try: - if proc and proc.returncode is None: - logger.info(f'Emergency termination of process {proc.pid}') - proc.terminate() - try: - await asyncio.wait_for(proc.wait(), timeout=1.0) - except asyncio.TimeoutError: - logger.warning(f"Process {proc.pid} didn't terminate, forcing kill") - proc.kill() - self._active_processes.remove(proc) + await self._exit_stack.aclose() except Exception as e: - logger.error(f'Error during process cleanup: {e}') - - # Clean up any tracked file handles - for handle in list(self._active_file_handles): - try: - if not handle.closed: - logger.info('Closing file handle') - handle.close() - self._active_file_handles.remove(handle) - except Exception as e: - logger.error(f'Error closing file handle: {e}') + # Log the error but don't re-raise to avoid blocking shutdown + print( + f'Warning: Error during MCP session cleanup: {e}', file=self._errlog + ) + finally: + self._exit_stack = None + self._session = None diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 9aa184e..0ad6d04 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import logging from typing import Optional from google.genai.types import FunctionDeclaration @@ -23,7 +25,6 @@ from .mcp_session_manager import retry_on_closed_resource # Attempt to import MCP Tool from the MCP library, and hints user to upgrade # their Python version to 3.10 if it fails. try: - from mcp import ClientSession from mcp.types import Tool as McpBaseTool except ImportError as e: import sys @@ -43,6 +44,8 @@ from ..base_tool import BaseTool from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema from ..tool_context import ToolContext +logger = logging.getLogger("google_adk." + __name__) + class MCPTool(BaseTool): """Turns a MCP Tool into a Vertex Agent Framework Tool. @@ -53,44 +56,40 @@ class MCPTool(BaseTool): def __init__( self, + *, mcp_tool: McpBaseTool, - mcp_session: ClientSession, mcp_session_manager: MCPSessionManager, auth_scheme: Optional[AuthScheme] = None, - auth_credential: Optional[AuthCredential] | None = None, + auth_credential: Optional[AuthCredential] = None, ): """Initializes a MCPTool. - This tool wraps a MCP Tool interface and an active MCP Session. It invokes - the MCP Tool through executing the tool from remote MCP Session. - - Example: - tool = MCPTool(mcp_tool=mcp_tool, mcp_session=mcp_session) + This tool wraps a MCP Tool interface and uses a session manager to + communicate with the MCP server. Args: mcp_tool: The MCP tool to wrap. - mcp_session: The MCP session to use to call the tool. + mcp_session_manager: The MCP session manager to use for communication. auth_scheme: The authentication scheme to use. auth_credential: The authentication credential to use. Raises: - ValueError: If mcp_tool or mcp_session is None. + ValueError: If mcp_tool or mcp_session_manager is None. """ if mcp_tool is None: raise ValueError("mcp_tool cannot be None") - if mcp_session is None: - raise ValueError("mcp_session cannot be None") - super().__init__(name=mcp_tool.name, description=mcp_tool.description or "") + if mcp_session_manager is None: + raise ValueError("mcp_session_manager cannot be None") + super().__init__( + name=mcp_tool.name, + description=mcp_tool.description if mcp_tool.description else "", + ) self._mcp_tool = mcp_tool - self._mcp_session = mcp_session self._mcp_session_manager = mcp_session_manager # TODO(cheliu): Support passing auth to MCP Server. self._auth_scheme = auth_scheme self._auth_credential = auth_credential - async def _reinitialize_session(self): - self._mcp_session = await self._mcp_session_manager.create_session() - @override def _get_declaration(self) -> FunctionDeclaration: """Gets the function declaration for the tool. @@ -105,7 +104,6 @@ class MCPTool(BaseTool): ) return function_decl - @override @retry_on_closed_resource("_reinitialize_session") async def run_async(self, *, args, tool_context: ToolContext): """Runs the tool asynchronously. @@ -117,10 +115,15 @@ class MCPTool(BaseTool): Returns: Any: The response from the tool. """ + # Get the session from the session manager + session = await self._mcp_session_manager.create_session() + # TODO(cheliu): Support passing tool context to MCP Server. - try: - response = await self._mcp_session.call_tool(self.name, arguments=args) - return response - except Exception as e: - print(e) - raise e + response = await session.call_tool(self.name, arguments=args) + return response + + async def _reinitialize_session(self): + """Reinitializes the session when connection is lost.""" + # Close the old session and create a new one + await self._mcp_session_manager.close() + await self._mcp_session_manager.create_session() diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 994f6b9..01e586c 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -12,19 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -from contextlib import AsyncExitStack import logging -import os -import signal import sys from typing import List from typing import Optional from typing import TextIO from typing import Union -from typing_extensions import override - from ...agents.readonly_context import ReadonlyContext from ..base_tool import BaseTool from ..base_toolset import BaseToolset @@ -36,7 +30,6 @@ from .mcp_session_manager import SseServerParams # Attempt to import MCP Tool from the MCP library, and hints user to upgrade # their Python version to 3.10 if it fails. try: - from mcp import ClientSession from mcp import StdioServerParameters from mcp.types import ListToolsResult except ImportError as e: @@ -58,16 +51,31 @@ logger = logging.getLogger("google_adk." + __name__) class MCPToolset(BaseToolset): """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools. + This toolset manages the connection to an MCP server and provides tools + that can be used by an agent. It properly implements the BaseToolset + interface for easy integration with the agent framework. + Usage: - ``` - root_agent = LlmAgent( - tools=MCPToolset( - connection_params=StdioServerParameters( - command='npx', - args=["-y", "@modelcontextprotocol/server-filesystem"], - ) - ) + ```python + toolset = MCPToolset( + connection_params=StdioServerParameters( + command='npx', + args=["-y", "@modelcontextprotocol/server-filesystem"], + ), + tool_filter=['read_file', 'list_directory'] # Optional: filter specific tools ) + + # Use in an agent + agent = LlmAgent( + model='gemini-2.0-flash', + name='enterprise_assistant', + instruction='Help user accessing their file systems', + tools=[toolset], + ) + + # Cleanup is handled automatically by the agent framework + # But you can also manually close if needed: + # await toolset.close() ``` """ @@ -75,140 +83,89 @@ class MCPToolset(BaseToolset): self, *, connection_params: StdioServerParameters | SseServerParams, - errlog: TextIO = sys.stderr, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + errlog: TextIO = sys.stderr, ): """Initializes the MCPToolset. Args: - connection_params: The connection parameters to the MCP server. Can be: - `StdioServerParameters` for using local mcp server (e.g. using `npx` or - `python3`); or `SseServerParams` for a local/remote SSE server. - errlog: (Optional) TextIO stream for error logging. Use only for - initializing a local stdio MCP session. + connection_params: The connection parameters to the MCP server. Can be: + `StdioServerParameters` for using local mcp server (e.g. using `npx` or + `python3`); or `SseServerParams` for a local/remote SSE server. + tool_filter: Optional filter to select specific tools. Can be either: + - A list of tool names to include + - A ToolPredicate function for custom filtering logic + errlog: TextIO stream for error logging. """ + super().__init__(tool_filter=tool_filter) if not connection_params: raise ValueError("Missing connection params in MCPToolset.") - super().__init__(tool_filter=tool_filter) + self._connection_params = connection_params self._errlog = errlog - self._exit_stack = AsyncExitStack() - self._creator_task_id = None - self._process_pid = None # Store the subprocess PID - self._session_manager = MCPSessionManager( + # Create the session manager that will handle the MCP connection + self._mcp_session_manager = MCPSessionManager( connection_params=self._connection_params, - exit_stack=self._exit_stack, errlog=self._errlog, ) self._session = None - self._initialized = False - async def _initialize(self) -> ClientSession: - """Connects to the MCP Server and initializes the ClientSession.""" - # Store the current task ID when initializing - self._creator_task_id = id(asyncio.current_task()) - self._session, process = await self._session_manager.create_session() - # Store the process PID if available - if process and hasattr(process, "pid"): - self._process_pid = process.pid - self._initialized = True - return self._session - - @override - async def close(self): - """Safely closes the connection to MCP Server with guaranteed resource cleanup.""" - if not self._initialized: - return # Nothing to close - - logger.info("Closing MCP Toolset") - - # Step 1: Try graceful shutdown of the session if it exists - if self._session: - try: - logger.info("Attempting graceful session shutdown") - await self._session.shutdown() - except Exception as e: - logger.warning(f"Session shutdown error (continuing cleanup): {e}") - - # Step 2: Try to close the exit stack - try: - logger.info("Closing AsyncExitStack") - await self._exit_stack.aclose() - # If we get here, the exit stack closed successfully - logger.info("AsyncExitStack closed successfully") - return - except RuntimeError as e: - if "Attempted to exit cancel scope in a different task" in str(e): - logger.warning("Task mismatch during shutdown - using fallback cleanup") - # Continue to manual cleanup - else: - logger.error(f"Unexpected RuntimeError: {e}") - # Continue to manual cleanup - except Exception as e: - logger.error(f"Error during exit stack closure: {e}") - # Continue to manual cleanup - - # Step 3: Manual cleanup of the subprocess if we have its PID - if self._process_pid: - await self._ensure_process_terminated(self._process_pid) - - # Step 4: Ask the session manager to do any additional cleanup it can - await self._session_manager._emergency_cleanup() - - async def _ensure_process_terminated(self, pid): - """Ensure a process is terminated using its PID.""" - try: - # Check if process exists - os.kill(pid, 0) # This just checks if the process exists - - logger.info(f"Terminating process with PID {pid}") - # First try SIGTERM for graceful shutdown - os.kill(pid, signal.SIGTERM) - - # Give it a moment to terminate - for _ in range(30): # wait up to 3 seconds - await asyncio.sleep(0.1) - try: - os.kill(pid, 0) # Process still exists - except ProcessLookupError: - logger.info(f"Process {pid} terminated successfully") - return - - # If we get here, process didn't terminate gracefully - logger.warning( - f"Process {pid} didn't terminate gracefully, using SIGKILL" - ) - os.kill(pid, signal.SIGKILL) - - except ProcessLookupError: - logger.info(f"Process {pid} already terminated") - except Exception as e: - logger.error(f"Error terminating process {pid}: {e}") - - @retry_on_closed_resource("_initialize") - @override + @retry_on_closed_resource("_reinitialize_session") async def get_tools( self, readonly_context: Optional[ReadonlyContext] = None, - ) -> List[MCPTool]: - """Loads all tools from the MCP Server. + ) -> List[BaseTool]: + """Return all tools in the toolset based on the provided context. + + Args: + readonly_context: Context used to filter tools available to the agent. + If None, all tools in the toolset are returned. Returns: - A list of MCPTools imported from the MCP Server. + List[BaseTool]: A list of tools available under the specified context. """ + # Get session from session manager if not self._session: - await self._initialize() + self._session = await self._mcp_session_manager.create_session() + + # Fetch available tools from the MCP server tools_response: ListToolsResult = await self._session.list_tools() + + # Apply filtering based on context and tool_filter tools = [] for tool in tools_response.tools: mcp_tool = MCPTool( mcp_tool=tool, - mcp_session=self._session, - mcp_session_manager=self._session_manager, + mcp_session_manager=self._mcp_session_manager, ) if self._is_tool_selected(mcp_tool, readonly_context): tools.append(mcp_tool) return tools + + async def _reinitialize_session(self): + """Reinitializes the session when connection is lost.""" + # Close the old session and clear cache + await self._mcp_session_manager.close() + self._session = await self._mcp_session_manager.create_session() + + # Tools will be reloaded on next get_tools call + + async def close(self) -> None: + """Performs cleanup and releases resources held by the toolset. + + This method closes the MCP session and cleans up all associated resources. + It's designed to be safe to call multiple times and handles cleanup errors + gracefully to avoid blocking application shutdown. + """ + try: + await self._mcp_session_manager.close() + except Exception as e: + # Log the error but don't re-raise to avoid blocking shutdown + print(f"Warning: Error during MCPToolset cleanup: {e}", file=self._errlog) + finally: + # Clear cached tools + self._tools_cache = None + self._tools_loaded = False