refactor: simplify toolset cleanup codes and extract common cleanup codes to utils which could be utilized by cli or client codes that directly call runners

PiperOrigin-RevId: 762463028
This commit is contained in:
Xiang (Sean) Zhou 2025-05-23 09:48:48 -07:00 committed by Copybara-Service
parent b9b2c3fb54
commit 92c37496d3
6 changed files with 239 additions and 406 deletions

View File

@ -16,12 +16,9 @@
import asyncio
from contextlib import asynccontextmanager
import importlib
import inspect
import json
import logging
import os
from pathlib import Path
import signal
import sys
import time
import traceback
@ -55,11 +52,9 @@ from starlette.types import Lifespan
from typing_extensions import override
from ..agents import RunConfig
from ..agents.base_agent import BaseAgent
from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent
from ..agents.llm_agent import LlmAgent
from ..agents.run_config import StreamingMode
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..evaluation.eval_case import EvalCase
@ -75,12 +70,12 @@ from ..sessions.session import Session
from ..sessions.vertex_ai_session_service import VertexAiSessionService
from ..tools.base_toolset import BaseToolset
from .cli_eval import EVAL_SESSION_ID_PREFIX
from .cli_eval import EvalCaseResult
from .cli_eval import EvalMetric
from .cli_eval import EvalMetricResult
from .cli_eval import EvalMetricResultPerInvocation
from .cli_eval import EvalSetResult
from .cli_eval import EvalStatus
from .utils import cleanup
from .utils import common
from .utils import create_empty_state
from .utils import envs
@ -230,27 +225,8 @@ def get_fast_api_app(
trace.set_tracer_provider(provider)
toolsets_to_close: set[BaseToolset] = set()
@asynccontextmanager
async def internal_lifespan(app: FastAPI):
# Set up signal handlers for graceful shutdown
original_sigterm = signal.getsignal(signal.SIGTERM)
original_sigint = signal.getsignal(signal.SIGINT)
def cleanup_handler(sig, frame):
# Log the signal
logger.info("Received signal %s, performing pre-shutdown cleanup", sig)
# Do synchronous cleanup if needed
# Then call original handler if it exists
if sig == signal.SIGTERM and callable(original_sigterm):
original_sigterm(sig, frame)
elif sig == signal.SIGINT and callable(original_sigint):
original_sigint(sig, frame)
# Install cleanup handlers
signal.signal(signal.SIGTERM, cleanup_handler)
signal.signal(signal.SIGINT, cleanup_handler)
try:
if lifespan:
@ -259,46 +235,8 @@ def get_fast_api_app(
else:
yield
finally:
# During shutdown, properly clean up all toolsets
logger.info(
"Server shutdown initiated, cleaning up %s toolsets",
len(toolsets_to_close),
)
# Create tasks for all toolset closures to run concurrently
cleanup_tasks = []
for toolset in toolsets_to_close:
task = asyncio.create_task(close_toolset_safely(toolset))
cleanup_tasks.append(task)
if cleanup_tasks:
# Wait for all cleanup tasks with timeout
done, pending = await asyncio.wait(
cleanup_tasks,
timeout=10.0, # 10 second timeout for cleanup
return_when=asyncio.ALL_COMPLETED,
)
# If any tasks are still pending, log it
if pending:
logger.warning(
f"{len(pending)} toolset cleanup tasks didn't complete in time"
)
for task in pending:
task.cancel()
# Restore original signal handlers
signal.signal(signal.SIGTERM, original_sigterm)
signal.signal(signal.SIGINT, original_sigint)
async def close_toolset_safely(toolset):
"""Safely close a toolset with error handling."""
try:
logger.info(f"Closing toolset: {type(toolset).__name__}")
await toolset.close()
logger.info(f"Successfully closed toolset: {type(toolset).__name__}")
except Exception as e:
logger.error(f"Error closing toolset {type(toolset).__name__}: {e}")
# Create tasks for all runner closures to run concurrently
await cleanup.close_runners(list(runner_dict.values()))
# Run the FastAPI server.
app = FastAPI(lifespan=internal_lifespan)
@ -903,16 +841,6 @@ def get_fast_api_app(
for task in pending:
task.cancel()
def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
toolsets = set()
if isinstance(agent, LlmAgent):
for tool_union in agent.tools:
if isinstance(tool_union, BaseToolset):
toolsets.add(tool_union)
for sub_agent in agent.sub_agents:
toolsets.update(_get_all_toolsets(sub_agent))
return toolsets
async def _get_root_agent_async(app_name: str) -> Agent:
"""Returns the root agent for the given app."""
if app_name in root_agent_dict:
@ -924,7 +852,6 @@ def get_fast_api_app(
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
root_agent_dict[app_name] = root_agent
toolsets_to_close.update(_get_all_toolsets(root_agent))
return root_agent
async def _get_runner_async(app_name: str) -> Runner:

View File

@ -0,0 +1,40 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from typing import List
from ...runners import Runner
logger = logging.getLogger("google_adk." + __name__)
async def close_runners(runners: List[Runner]) -> None:
cleanup_tasks = [asyncio.create_task(runner.close()) for runner in runners]
if cleanup_tasks:
# Wait for all cleanup tasks with timeout
done, pending = await asyncio.wait(
cleanup_tasks,
timeout=30.0, # 30 second timeout for cleanup
return_when=asyncio.ALL_COMPLETED,
)
# If any tasks are still pending, log it
if pending:
logger.warning(
"%s runner close tasks didn't complete in time", len(pending)
)
for task in pending:
task.cancel()

View File

@ -42,6 +42,7 @@ from .sessions.base_session_service import BaseSessionService
from .sessions.in_memory_session_service import InMemorySessionService
from .sessions.session import Session
from .telemetry import tracer
from .tools.base_toolset import BaseToolset
logger = logging.getLogger('google_adk.' + __name__)
@ -457,6 +458,37 @@ class Runner:
run_config=run_config,
)
def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]:
toolsets = set()
if isinstance(agent, LlmAgent):
for tool_union in agent.tools:
if isinstance(tool_union, BaseToolset):
toolsets.add(tool_union)
for sub_agent in agent.sub_agents:
toolsets.update(self._collect_toolset(sub_agent))
return toolsets
async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]):
"""Clean up toolsets with proper task context management."""
if not toolsets_to_close:
return
# This maintains the same task context throughout cleanup
for toolset in toolsets_to_close:
try:
logger.info('Closing toolset: %s', type(toolset).__name__)
# Use asyncio.wait_for to add timeout protection
await asyncio.wait_for(toolset.close(), timeout=10.0)
logger.info('Successfully closed toolset: %s', type(toolset).__name__)
except asyncio.TimeoutError:
logger.warning('Toolset %s cleanup timed out', type(toolset).__name__)
except Exception as e:
logger.error('Error closing toolset %s: %s', type(toolset).__name__, e)
async def close(self):
"""Closes the runner."""
await self._cleanup_toolsets(self._collect_toolset(self.agent))
class InMemoryRunner(Runner):
"""An in-memory Runner for testing and development.

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
from contextlib import AsyncExitStack
import functools
import logging
@ -71,29 +70,27 @@ def retry_on_closed_resource(async_reinit_func_name: str):
Usage:
class MCPTool:
...
async def create_session(self):
self.session = ...
...
async def create_session(self):
self.session = ...
@retry_on_closed_resource('create_session')
async def use_session(self):
await self.session.call_tool()
@retry_on_closed_resource('create_session')
async def use_session(self):
await self.session.call_tool()
Args:
async_reinit_func_name: The name of the async function to recreate session.
async_reinit_func_name: The name of the async function to recreate session.
Returns:
The decorated function.
The decorated function.
"""
def decorator(func):
@functools.wraps(
func
) # Preserves original function metadata (name, docstring)
@functools.wraps(func) # Preserves original function metadata
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except anyio.ClosedResourceError:
except anyio.ClosedResourceError as close_err:
try:
if hasattr(self, async_reinit_func_name) and callable(
getattr(self, async_reinit_func_name)
@ -105,7 +102,7 @@ def retry_on_closed_resource(async_reinit_func_name: str):
f'Function {async_reinit_func_name} does not exist in decorated'
' class. Please check the function name in'
' retry_on_closed_resource decorator.'
)
) from close_err
except Exception as reinit_err:
raise RuntimeError(
f'Error reinitializing: {reinit_err}'
@ -117,45 +114,6 @@ def retry_on_closed_resource(async_reinit_func_name: str):
return decorator
@asynccontextmanager
async def tracked_stdio_client(server, errlog, process=None):
"""A wrapper around stdio_client that ensures proper process tracking and cleanup."""
our_process = process
# If no process was provided, create one
if our_process is None:
our_process = await asyncio.create_subprocess_exec(
server.command,
*server.args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=errlog,
)
# Use the original stdio_client, but ensure process cleanup
try:
async with stdio_client(server=server, errlog=errlog) as client:
yield client, our_process
finally:
# Ensure the process is properly terminated if it still exists
if our_process and our_process.returncode is None:
try:
logger.info(
f'Terminating process {our_process.pid} from tracked_stdio_client'
)
our_process.terminate()
try:
await asyncio.wait_for(our_process.wait(), timeout=3.0)
except asyncio.TimeoutError:
# Force kill if it doesn't terminate quickly
if our_process.returncode is None:
logger.warning(f'Forcing kill of process {our_process.pid}')
our_process.kill()
except ProcessLookupError:
# Process already gone, that's fine
logger.info(f'Process {our_process.pid} already terminated')
class MCPSessionManager:
"""Manages MCP client sessions.
@ -166,162 +124,78 @@ class MCPSessionManager:
def __init__(
self,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
):
"""Initializes the MCP session manager.
Example usage:
```
mcp_session_manager = MCPSessionManager(
connection_params=connection_params,
exit_stack=exit_stack,
)
session = await mcp_session_manager.create_session()
```
Args:
connection_params: Parameters for the MCP connection (Stdio or SSE).
exit_stack: AsyncExitStack to manage the session lifecycle.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
"""
self._connection_params = connection_params
self._exit_stack = exit_stack
self._errlog = errlog
self._process = None # Track the subprocess
self._active_processes = set() # Track all processes created
self._active_file_handles = set() # Track file handles
# Each session manager maintains its own exit stack for proper cleanup
self._exit_stack: Optional[AsyncExitStack] = None
self._session: Optional[ClientSession] = None
async def create_session(
self,
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
"""Creates a new MCP session and tracks the associated process."""
session, process = await self._initialize_session(
connection_params=self._connection_params,
exit_stack=self._exit_stack,
errlog=self._errlog,
)
self._process = process # Store reference to process
# Track the process
if process:
self._active_processes.add(process)
return session, process
@classmethod
async def _initialize_session(
cls,
*,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
"""Initializes an MCP client session.
Args:
connection_params: Parameters for the MCP connection (Stdio or SSE).
exit_stack: AsyncExitStack to manage the session lifecycle.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
async def create_session(self) -> ClientSession:
"""Creates and initializes an MCP client session.
Returns:
ClientSession: The initialized MCP client session.
"""
process = None
if self._session is not None:
return self._session
if isinstance(connection_params, StdioServerParameters):
# For stdio connections, we need to track the subprocess
client, process = await cls._create_stdio_client(
server=connection_params,
errlog=errlog,
exit_stack=exit_stack,
)
elif isinstance(connection_params, SseServerParams):
# For SSE connections, create the client without a subprocess
client = sse_client(
url=connection_params.url,
headers=connection_params.headers,
timeout=connection_params.timeout,
sse_read_timeout=connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {connection_params}'
)
# Create a new exit stack for this session
self._exit_stack = AsyncExitStack()
# Create the session with the client
transports = await exit_stack.enter_async_context(client)
session = await exit_stack.enter_async_context(ClientSession(*transports))
await session.initialize()
return session, process
@staticmethod
async def _create_stdio_client(
server: StdioServerParameters,
errlog: TextIO,
exit_stack: AsyncExitStack,
) -> tuple[Any, asyncio.subprocess.Process]:
"""Create stdio client and return both the client and process.
This implementation adapts to how the MCP stdio_client is created.
The actual implementation may need to be adjusted based on the MCP library
structure.
"""
# Create the subprocess directly so we can track it
process = await asyncio.create_subprocess_exec(
server.command,
*server.args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=errlog,
)
# Create the stdio client using the MCP library
try:
# Method 1: Try using the existing process if stdio_client supports it
client = stdio_client(server=server, errlog=errlog, process=process)
except TypeError:
# Method 2: If the above doesn't work, let stdio_client create its own process
# and we'll need to terminate both processes later
logger.warning(
'Using stdio_client with its own process - may lead to duplicate'
' processes'
if isinstance(self._connection_params, StdioServerParameters):
client = stdio_client(
server=self._connection_params, errlog=self._errlog
)
elif isinstance(self._connection_params, SseServerParams):
client = sse_client(
url=self._connection_params.url,
headers=self._connection_params.headers,
timeout=self._connection_params.timeout,
sse_read_timeout=self._connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {self._connection_params}'
)
transports = await self._exit_stack.enter_async_context(client)
session = await self._exit_stack.enter_async_context(
ClientSession(*transports)
)
client = stdio_client(server=server, errlog=errlog)
await session.initialize()
return client, process
self._session = session
return session
async def _emergency_cleanup(self):
"""Perform emergency cleanup of resources when normal cleanup fails."""
logger.info('Performing emergency cleanup of MCPSessionManager resources')
except Exception:
# If session creation fails, clean up the exit stack
if self._exit_stack:
await self._exit_stack.aclose()
self._exit_stack = None
raise
# Clean up any tracked processes
for proc in list(self._active_processes):
async def close(self):
"""Closes the session and cleans up resources."""
if self._exit_stack:
try:
if proc and proc.returncode is None:
logger.info(f'Emergency termination of process {proc.pid}')
proc.terminate()
try:
await asyncio.wait_for(proc.wait(), timeout=1.0)
except asyncio.TimeoutError:
logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
proc.kill()
self._active_processes.remove(proc)
await self._exit_stack.aclose()
except Exception as e:
logger.error(f'Error during process cleanup: {e}')
# Clean up any tracked file handles
for handle in list(self._active_file_handles):
try:
if not handle.closed:
logger.info('Closing file handle')
handle.close()
self._active_file_handles.remove(handle)
except Exception as e:
logger.error(f'Error closing file handle: {e}')
# Log the error but don't re-raise to avoid blocking shutdown
print(
f'Warning: Error during MCP session cleanup: {e}', file=self._errlog
)
finally:
self._exit_stack = None
self._session = None

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from google.genai.types import FunctionDeclaration
@ -23,7 +25,6 @@ from .mcp_session_manager import retry_on_closed_resource
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
from mcp import ClientSession
from mcp.types import Tool as McpBaseTool
except ImportError as e:
import sys
@ -43,6 +44,8 @@ from ..base_tool import BaseTool
from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from ..tool_context import ToolContext
logger = logging.getLogger("google_adk." + __name__)
class MCPTool(BaseTool):
"""Turns a MCP Tool into a Vertex Agent Framework Tool.
@ -53,44 +56,40 @@ class MCPTool(BaseTool):
def __init__(
self,
*,
mcp_tool: McpBaseTool,
mcp_session: ClientSession,
mcp_session_manager: MCPSessionManager,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] | None = None,
auth_credential: Optional[AuthCredential] = None,
):
"""Initializes a MCPTool.
This tool wraps a MCP Tool interface and an active MCP Session. It invokes
the MCP Tool through executing the tool from remote MCP Session.
Example:
tool = MCPTool(mcp_tool=mcp_tool, mcp_session=mcp_session)
This tool wraps a MCP Tool interface and uses a session manager to
communicate with the MCP server.
Args:
mcp_tool: The MCP tool to wrap.
mcp_session: The MCP session to use to call the tool.
mcp_session_manager: The MCP session manager to use for communication.
auth_scheme: The authentication scheme to use.
auth_credential: The authentication credential to use.
Raises:
ValueError: If mcp_tool or mcp_session is None.
ValueError: If mcp_tool or mcp_session_manager is None.
"""
if mcp_tool is None:
raise ValueError("mcp_tool cannot be None")
if mcp_session is None:
raise ValueError("mcp_session cannot be None")
super().__init__(name=mcp_tool.name, description=mcp_tool.description or "")
if mcp_session_manager is None:
raise ValueError("mcp_session_manager cannot be None")
super().__init__(
name=mcp_tool.name,
description=mcp_tool.description if mcp_tool.description else "",
)
self._mcp_tool = mcp_tool
self._mcp_session = mcp_session
self._mcp_session_manager = mcp_session_manager
# TODO(cheliu): Support passing auth to MCP Server.
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
async def _reinitialize_session(self):
self._mcp_session = await self._mcp_session_manager.create_session()
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Gets the function declaration for the tool.
@ -105,7 +104,6 @@ class MCPTool(BaseTool):
)
return function_decl
@override
@retry_on_closed_resource("_reinitialize_session")
async def run_async(self, *, args, tool_context: ToolContext):
"""Runs the tool asynchronously.
@ -117,10 +115,15 @@ class MCPTool(BaseTool):
Returns:
Any: The response from the tool.
"""
# Get the session from the session manager
session = await self._mcp_session_manager.create_session()
# TODO(cheliu): Support passing tool context to MCP Server.
try:
response = await self._mcp_session.call_tool(self.name, arguments=args)
return response
except Exception as e:
print(e)
raise e
response = await session.call_tool(self.name, arguments=args)
return response
async def _reinitialize_session(self):
"""Reinitializes the session when connection is lost."""
# Close the old session and create a new one
await self._mcp_session_manager.close()
await self._mcp_session_manager.create_session()

View File

@ -12,19 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import AsyncExitStack
import logging
import os
import signal
import sys
from typing import List
from typing import Optional
from typing import TextIO
from typing import Union
from typing_extensions import override
from ...agents.readonly_context import ReadonlyContext
from ..base_tool import BaseTool
from ..base_toolset import BaseToolset
@ -36,7 +30,6 @@ from .mcp_session_manager import SseServerParams
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
from mcp import ClientSession
from mcp import StdioServerParameters
from mcp.types import ListToolsResult
except ImportError as e:
@ -58,16 +51,31 @@ logger = logging.getLogger("google_adk." + __name__)
class MCPToolset(BaseToolset):
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
This toolset manages the connection to an MCP server and provides tools
that can be used by an agent. It properly implements the BaseToolset
interface for easy integration with the agent framework.
Usage:
```
root_agent = LlmAgent(
tools=MCPToolset(
connection_params=StdioServerParameters(
command='npx',
args=["-y", "@modelcontextprotocol/server-filesystem"],
)
)
```python
toolset = MCPToolset(
connection_params=StdioServerParameters(
command='npx',
args=["-y", "@modelcontextprotocol/server-filesystem"],
),
tool_filter=['read_file', 'list_directory'] # Optional: filter specific tools
)
# Use in an agent
agent = LlmAgent(
model='gemini-2.0-flash',
name='enterprise_assistant',
instruction='Help user accessing their file systems',
tools=[toolset],
)
# Cleanup is handled automatically by the agent framework
# But you can also manually close if needed:
# await toolset.close()
```
"""
@ -75,140 +83,89 @@ class MCPToolset(BaseToolset):
self,
*,
connection_params: StdioServerParameters | SseServerParams,
errlog: TextIO = sys.stderr,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
errlog: TextIO = sys.stderr,
):
"""Initializes the MCPToolset.
Args:
connection_params: The connection parameters to the MCP server. Can be:
`StdioServerParameters` for using local mcp server (e.g. using `npx` or
`python3`); or `SseServerParams` for a local/remote SSE server.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
connection_params: The connection parameters to the MCP server. Can be:
`StdioServerParameters` for using local mcp server (e.g. using `npx` or
`python3`); or `SseServerParams` for a local/remote SSE server.
tool_filter: Optional filter to select specific tools. Can be either:
- A list of tool names to include
- A ToolPredicate function for custom filtering logic
errlog: TextIO stream for error logging.
"""
super().__init__(tool_filter=tool_filter)
if not connection_params:
raise ValueError("Missing connection params in MCPToolset.")
super().__init__(tool_filter=tool_filter)
self._connection_params = connection_params
self._errlog = errlog
self._exit_stack = AsyncExitStack()
self._creator_task_id = None
self._process_pid = None # Store the subprocess PID
self._session_manager = MCPSessionManager(
# Create the session manager that will handle the MCP connection
self._mcp_session_manager = MCPSessionManager(
connection_params=self._connection_params,
exit_stack=self._exit_stack,
errlog=self._errlog,
)
self._session = None
self._initialized = False
async def _initialize(self) -> ClientSession:
"""Connects to the MCP Server and initializes the ClientSession."""
# Store the current task ID when initializing
self._creator_task_id = id(asyncio.current_task())
self._session, process = await self._session_manager.create_session()
# Store the process PID if available
if process and hasattr(process, "pid"):
self._process_pid = process.pid
self._initialized = True
return self._session
@override
async def close(self):
"""Safely closes the connection to MCP Server with guaranteed resource cleanup."""
if not self._initialized:
return # Nothing to close
logger.info("Closing MCP Toolset")
# Step 1: Try graceful shutdown of the session if it exists
if self._session:
try:
logger.info("Attempting graceful session shutdown")
await self._session.shutdown()
except Exception as e:
logger.warning(f"Session shutdown error (continuing cleanup): {e}")
# Step 2: Try to close the exit stack
try:
logger.info("Closing AsyncExitStack")
await self._exit_stack.aclose()
# If we get here, the exit stack closed successfully
logger.info("AsyncExitStack closed successfully")
return
except RuntimeError as e:
if "Attempted to exit cancel scope in a different task" in str(e):
logger.warning("Task mismatch during shutdown - using fallback cleanup")
# Continue to manual cleanup
else:
logger.error(f"Unexpected RuntimeError: {e}")
# Continue to manual cleanup
except Exception as e:
logger.error(f"Error during exit stack closure: {e}")
# Continue to manual cleanup
# Step 3: Manual cleanup of the subprocess if we have its PID
if self._process_pid:
await self._ensure_process_terminated(self._process_pid)
# Step 4: Ask the session manager to do any additional cleanup it can
await self._session_manager._emergency_cleanup()
async def _ensure_process_terminated(self, pid):
"""Ensure a process is terminated using its PID."""
try:
# Check if process exists
os.kill(pid, 0) # This just checks if the process exists
logger.info(f"Terminating process with PID {pid}")
# First try SIGTERM for graceful shutdown
os.kill(pid, signal.SIGTERM)
# Give it a moment to terminate
for _ in range(30): # wait up to 3 seconds
await asyncio.sleep(0.1)
try:
os.kill(pid, 0) # Process still exists
except ProcessLookupError:
logger.info(f"Process {pid} terminated successfully")
return
# If we get here, process didn't terminate gracefully
logger.warning(
f"Process {pid} didn't terminate gracefully, using SIGKILL"
)
os.kill(pid, signal.SIGKILL)
except ProcessLookupError:
logger.info(f"Process {pid} already terminated")
except Exception as e:
logger.error(f"Error terminating process {pid}: {e}")
@retry_on_closed_resource("_initialize")
@override
@retry_on_closed_resource("_reinitialize_session")
async def get_tools(
self,
readonly_context: Optional[ReadonlyContext] = None,
) -> List[MCPTool]:
"""Loads all tools from the MCP Server.
) -> List[BaseTool]:
"""Return all tools in the toolset based on the provided context.
Args:
readonly_context: Context used to filter tools available to the agent.
If None, all tools in the toolset are returned.
Returns:
A list of MCPTools imported from the MCP Server.
List[BaseTool]: A list of tools available under the specified context.
"""
# Get session from session manager
if not self._session:
await self._initialize()
self._session = await self._mcp_session_manager.create_session()
# Fetch available tools from the MCP server
tools_response: ListToolsResult = await self._session.list_tools()
# Apply filtering based on context and tool_filter
tools = []
for tool in tools_response.tools:
mcp_tool = MCPTool(
mcp_tool=tool,
mcp_session=self._session,
mcp_session_manager=self._session_manager,
mcp_session_manager=self._mcp_session_manager,
)
if self._is_tool_selected(mcp_tool, readonly_context):
tools.append(mcp_tool)
return tools
async def _reinitialize_session(self):
"""Reinitializes the session when connection is lost."""
# Close the old session and clear cache
await self._mcp_session_manager.close()
self._session = await self._mcp_session_manager.create_session()
# Tools will be reloaded on next get_tools call
async def close(self) -> None:
"""Performs cleanup and releases resources held by the toolset.
This method closes the MCP session and cleans up all associated resources.
It's designed to be safe to call multiple times and handles cleanup errors
gracefully to avoid blocking application shutdown.
"""
try:
await self._mcp_session_manager.close()
except Exception as e:
# Log the error but don't re-raise to avoid blocking shutdown
print(f"Warning: Error during MCPToolset cleanup: {e}", file=self._errlog)
finally:
# Clear cached tools
self._tools_cache = None
self._tools_loaded = False