fix: fix mcp toolset close issue

PiperOrigin-RevId: 759636772
Author: Xiang (Sean) Zhou
Date: 2025-05-16 09:05:18 -07:00
Committed by: Copybara-Service
Parent: 12507dc6cc
Commit: 05a853bc91
3 changed files with 294 additions and 23 deletions
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from contextlib import AsyncExitStack
+import asyncio
+from contextlib import AsyncExitStack, asynccontextmanager
import functools
import logging
import sys
-from typing import Any, TextIO
+from typing import Any, Optional, TextIO
import anyio
from pydantic import BaseModel
@@ -34,6 +36,8 @@ except ImportError as e:
else:
raise e
logger = logging.getLogger(__name__)
class SseServerParams(BaseModel):
"""Parameters for the MCP SSE connection.
@@ -108,6 +112,45 @@ def retry_on_closed_resource(async_reinit_func_name: str):
return decorator
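For reference, a minimal sketch of how retry_on_closed_resource is typically applied: the decorator receives the name of a coroutine method that rebuilds the MCP session and retries the decorated call after the underlying stream is reported closed. The class and the _reinitialize_session method below are hypothetical, not part of this change.
# Hedged sketch: applying retry_on_closed_resource. The _ExampleToolLoader
# class and its _reinitialize_session method are hypothetical.
class _ExampleToolLoader:

  def __init__(self, session_manager: 'MCPSessionManager'):
    self._session_manager = session_manager
    self._session = None

  async def _reinitialize_session(self):
    # Rebuild the MCP session after the underlying stream was closed.
    self._session, _ = await self._session_manager.create_session()

  @retry_on_closed_resource('_reinitialize_session')
  async def load_tools(self):
    # A closed-resource error here triggers _reinitialize_session and a retry.
    return await self._session.list_tools()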
@asynccontextmanager
async def tracked_stdio_client(server, errlog, process=None):
"""A wrapper around stdio_client that ensures proper process tracking and cleanup."""
our_process = process
# If no process was provided, create one
if our_process is None:
our_process = await asyncio.create_subprocess_exec(
server.command,
*server.args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=errlog,
)
# Use the original stdio_client, but ensure process cleanup
try:
async with stdio_client(server=server, errlog=errlog) as client:
yield client, our_process
finally:
# Ensure the process is properly terminated if it still exists
if our_process and our_process.returncode is None:
try:
logger.info(
f'Terminating process {our_process.pid} from tracked_stdio_client'
)
our_process.terminate()
try:
await asyncio.wait_for(our_process.wait(), timeout=3.0)
except asyncio.TimeoutError:
# Force kill if it doesn't terminate quickly
if our_process.returncode is None:
logger.warning(f'Forcing kill of process {our_process.pid}')
our_process.kill()
except ProcessLookupError:
# Process already gone, that's fine
logger.info(f'Process {our_process.pid} already terminated')
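A minimal usage sketch for tracked_stdio_client, assuming a local stdio MCP server; the command and arguments are placeholders, and StdioServerParameters and ClientSession are the types already used in this module.
# Hedged sketch: entering tracked_stdio_client directly. The command and
# args below are placeholders, not taken from this change.
async def _demo_tracked_stdio():
  server = StdioServerParameters(command='npx', args=['some-mcp-server'])
  async with tracked_stdio_client(server, errlog=sys.stderr) as (
      transports,
      process,
  ):
    logger.info('MCP server subprocess pid: %s', process.pid)
    async with ClientSession(*transports) as session:
      await session.initialize()
  # Leaving the block terminates (or force-kills) the subprocess if needed.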
class MCPSessionManager:
"""Manages MCP client sessions.
@@ -138,25 +181,39 @@ class MCPSessionManager:
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
"""
self._connection_params = connection_params
self._exit_stack = exit_stack
self._errlog = errlog
self._process = None # Track the subprocess
self._active_processes = set() # Track all processes created
self._active_file_handles = set() # Track file handles
-async def create_session(self) -> ClientSession:
-return await MCPSessionManager.initialize_session(
+async def create_session(
+self,
+) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
+"""Creates a new MCP session and tracks the associated process."""
+session, process = await self._initialize_session(
connection_params=self._connection_params,
exit_stack=self._exit_stack,
errlog=self._errlog,
)
self._process = process # Store reference to process
# Track the process
if process:
self._active_processes.add(process)
return session, process
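A hedged sketch of how a caller might drive create_session; the constructor keyword names are inferred from the assignments above, and the caller owns the AsyncExitStack used for teardown.
# Sketch only: constructor keyword names inferred from the assignments above.
async def _demo_create_session(connection_params):
  exit_stack = AsyncExitStack()
  manager = MCPSessionManager(
      connection_params=connection_params,
      exit_stack=exit_stack,
      errlog=sys.stderr,
  )
  session, process = await manager.create_session()
  try:
    await session.list_tools()
  finally:
    # Closing the stack tears down the ClientSession and its transports;
    # any tracked subprocess is the manager's responsibility.
    await exit_stack.aclose()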
@classmethod
-async def initialize_session(
+async def _initialize_session(
cls,
*,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
-) -> ClientSession:
+) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
"""Initializes an MCP client session.
Args:
@@ -168,9 +225,17 @@ class MCPSessionManager:
Returns:
ClientSession: The initialized MCP client session.
"""
process = None
if isinstance(connection_params, StdioServerParameters):
-client = stdio_client(server=connection_params, errlog=errlog)
+# For stdio connections, we need to track the subprocess
+client, process = await cls._create_stdio_client(
+server=connection_params,
+errlog=errlog,
+exit_stack=exit_stack,
+)
elif isinstance(connection_params, SseServerParams):
# For SSE connections, create the client without a subprocess
client = sse_client(
url=connection_params.url,
headers=connection_params.headers,
@@ -184,7 +249,74 @@ class MCPSessionManager:
f' {connection_params}'
)
# Create the session with the client
transports = await exit_stack.enter_async_context(client)
session = await exit_stack.enter_async_context(ClientSession(*transports))
await session.initialize()
-return session
+return session, process
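The two branches above cover the two supported connection types; for reference, a brief sketch of constructing each (all values below are placeholders):
# Placeholder values; only the field names are taken from the code above.
stdio_params = StdioServerParameters(
    command='python',
    args=['-m', 'my_mcp_server'],  # hypothetical server module
)
sse_params = SseServerParams(
    url='https://example.com/mcp/sse',
    headers={'Authorization': 'Bearer <token>'},
)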
@staticmethod
async def _create_stdio_client(
server: StdioServerParameters,
errlog: TextIO,
exit_stack: AsyncExitStack,
) -> tuple[Any, asyncio.subprocess.Process]:
"""Create stdio client and return both the client and process.
This implementation adapts to how the MCP stdio_client is created.
The actual implementation may need to be adjusted based on the MCP library
structure.
"""
# Create the subprocess directly so we can track it
process = await asyncio.create_subprocess_exec(
server.command,
*server.args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=errlog,
)
# Create the stdio client using the MCP library
try:
# Method 1: Try using the existing process if stdio_client supports it
client = stdio_client(server=server, errlog=errlog, process=process)
except TypeError:
# Method 2: If the above doesn't work, let stdio_client create its own process
# and we'll need to terminate both processes later
logger.warning(
'Using stdio_client with its own process - may lead to duplicate'
' processes'
)
client = stdio_client(server=server, errlog=errlog)
return client, process
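Because the subprocess created here is separate from whatever stdio_client manages, one option is to register its termination on the caller's AsyncExitStack so it is cleaned up with everything else; a hedged sketch (this helper is not part of the change):
# Hypothetical helper: terminate a tracked subprocess when the exit stack
# closes, mirroring the timeout-then-kill logic used elsewhere in this file.
def _register_process_cleanup(
    exit_stack: AsyncExitStack, process: asyncio.subprocess.Process
) -> None:
  async def _terminate() -> None:
    if process.returncode is None:
      process.terminate()
      try:
        await asyncio.wait_for(process.wait(), timeout=3.0)
      except asyncio.TimeoutError:
        process.kill()

  exit_stack.push_async_callback(_terminate)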
async def _emergency_cleanup(self):
"""Perform emergency cleanup of resources when normal cleanup fails."""
logger.info('Performing emergency cleanup of MCPSessionManager resources')
# Clean up any tracked processes
for proc in list(self._active_processes):
try:
if proc and proc.returncode is None:
logger.info(f'Emergency termination of process {proc.pid}')
proc.terminate()
try:
await asyncio.wait_for(proc.wait(), timeout=1.0)
except asyncio.TimeoutError:
logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
proc.kill()
self._active_processes.remove(proc)
except Exception as e:
logger.error(f'Error during process cleanup: {e}')
# Clean up any tracked file handles
for handle in list(self._active_file_handles):
try:
if not handle.closed:
logger.info('Closing file handle')
handle.close()
self._active_file_handles.remove(handle)
except Exception as e:
logger.error(f'Error closing file handle: {e}')
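A hedged sketch of where the emergency path might be invoked: a toolset close routine that falls back to _emergency_cleanup when unwinding the AsyncExitStack fails. The close function below is illustrative only.
# Illustrative only: falling back to _emergency_cleanup when the normal
# AsyncExitStack teardown raises during toolset close.
async def _close_toolset(
    exit_stack: AsyncExitStack, session_manager: MCPSessionManager
) -> None:
  try:
    await exit_stack.aclose()
  except Exception as e:
    logger.error(f'Error during MCP toolset close: {e}')
    await session_manager._emergency_cleanup()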