refactor: simplify toolset cleanup codes and extract common cleanup codes to utils which could be utilized by cli or client codes that directly call runners

PiperOrigin-RevId: 762463028
This commit is contained in:
Xiang (Sean) Zhou
2025-05-23 09:48:48 -07:00
committed by Copybara-Service
parent b9b2c3fb54
commit 92c37496d3
6 changed files with 239 additions and 406 deletions
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
from contextlib import AsyncExitStack
import functools
import logging
@@ -71,29 +70,27 @@ def retry_on_closed_resource(async_reinit_func_name: str):
Usage:
class MCPTool:
...
async def create_session(self):
self.session = ...
...
async def create_session(self):
self.session = ...
@retry_on_closed_resource('create_session')
async def use_session(self):
await self.session.call_tool()
@retry_on_closed_resource('create_session')
async def use_session(self):
await self.session.call_tool()
Args:
async_reinit_func_name: The name of the async function to recreate session.
async_reinit_func_name: The name of the async function to recreate session.
Returns:
The decorated function.
The decorated function.
"""
def decorator(func):
@functools.wraps(
func
) # Preserves original function metadata (name, docstring)
@functools.wraps(func) # Preserves original function metadata
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except anyio.ClosedResourceError:
except anyio.ClosedResourceError as close_err:
try:
if hasattr(self, async_reinit_func_name) and callable(
getattr(self, async_reinit_func_name)
@@ -105,7 +102,7 @@ def retry_on_closed_resource(async_reinit_func_name: str):
f'Function {async_reinit_func_name} does not exist in decorated'
' class. Please check the function name in'
' retry_on_closed_resource decorator.'
)
) from close_err
except Exception as reinit_err:
raise RuntimeError(
f'Error reinitializing: {reinit_err}'
@@ -117,45 +114,6 @@ def retry_on_closed_resource(async_reinit_func_name: str):
return decorator
@asynccontextmanager
async def tracked_stdio_client(server, errlog, process=None):
"""A wrapper around stdio_client that ensures proper process tracking and cleanup."""
our_process = process
# If no process was provided, create one
if our_process is None:
our_process = await asyncio.create_subprocess_exec(
server.command,
*server.args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=errlog,
)
# Use the original stdio_client, but ensure process cleanup
try:
async with stdio_client(server=server, errlog=errlog) as client:
yield client, our_process
finally:
# Ensure the process is properly terminated if it still exists
if our_process and our_process.returncode is None:
try:
logger.info(
f'Terminating process {our_process.pid} from tracked_stdio_client'
)
our_process.terminate()
try:
await asyncio.wait_for(our_process.wait(), timeout=3.0)
except asyncio.TimeoutError:
# Force kill if it doesn't terminate quickly
if our_process.returncode is None:
logger.warning(f'Forcing kill of process {our_process.pid}')
our_process.kill()
except ProcessLookupError:
# Process already gone, that's fine
logger.info(f'Process {our_process.pid} already terminated')
class MCPSessionManager:
"""Manages MCP client sessions.
@@ -166,162 +124,78 @@ class MCPSessionManager:
def __init__(
self,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
):
"""Initializes the MCP session manager.
Example usage:
```
mcp_session_manager = MCPSessionManager(
connection_params=connection_params,
exit_stack=exit_stack,
)
session = await mcp_session_manager.create_session()
```
Args:
connection_params: Parameters for the MCP connection (Stdio or SSE).
exit_stack: AsyncExitStack to manage the session lifecycle.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
"""
self._connection_params = connection_params
self._exit_stack = exit_stack
self._errlog = errlog
self._process = None # Track the subprocess
self._active_processes = set() # Track all processes created
self._active_file_handles = set() # Track file handles
# Each session manager maintains its own exit stack for proper cleanup
self._exit_stack: Optional[AsyncExitStack] = None
self._session: Optional[ClientSession] = None
async def create_session(
self,
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
"""Creates a new MCP session and tracks the associated process."""
session, process = await self._initialize_session(
connection_params=self._connection_params,
exit_stack=self._exit_stack,
errlog=self._errlog,
)
self._process = process # Store reference to process
# Track the process
if process:
self._active_processes.add(process)
return session, process
@classmethod
async def _initialize_session(
cls,
*,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
"""Initializes an MCP client session.
Args:
connection_params: Parameters for the MCP connection (Stdio or SSE).
exit_stack: AsyncExitStack to manage the session lifecycle.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
async def create_session(self) -> ClientSession:
"""Creates and initializes an MCP client session.
Returns:
ClientSession: The initialized MCP client session.
"""
process = None
if self._session is not None:
return self._session
if isinstance(connection_params, StdioServerParameters):
# For stdio connections, we need to track the subprocess
client, process = await cls._create_stdio_client(
server=connection_params,
errlog=errlog,
exit_stack=exit_stack,
)
elif isinstance(connection_params, SseServerParams):
# For SSE connections, create the client without a subprocess
client = sse_client(
url=connection_params.url,
headers=connection_params.headers,
timeout=connection_params.timeout,
sse_read_timeout=connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {connection_params}'
)
# Create a new exit stack for this session
self._exit_stack = AsyncExitStack()
# Create the session with the client
transports = await exit_stack.enter_async_context(client)
session = await exit_stack.enter_async_context(ClientSession(*transports))
await session.initialize()
return session, process
@staticmethod
async def _create_stdio_client(
server: StdioServerParameters,
errlog: TextIO,
exit_stack: AsyncExitStack,
) -> tuple[Any, asyncio.subprocess.Process]:
"""Create stdio client and return both the client and process.
This implementation adapts to how the MCP stdio_client is created.
The actual implementation may need to be adjusted based on the MCP library
structure.
"""
# Create the subprocess directly so we can track it
process = await asyncio.create_subprocess_exec(
server.command,
*server.args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=errlog,
)
# Create the stdio client using the MCP library
try:
# Method 1: Try using the existing process if stdio_client supports it
client = stdio_client(server=server, errlog=errlog, process=process)
except TypeError:
# Method 2: If the above doesn't work, let stdio_client create its own process
# and we'll need to terminate both processes later
logger.warning(
'Using stdio_client with its own process - may lead to duplicate'
' processes'
if isinstance(self._connection_params, StdioServerParameters):
client = stdio_client(
server=self._connection_params, errlog=self._errlog
)
elif isinstance(self._connection_params, SseServerParams):
client = sse_client(
url=self._connection_params.url,
headers=self._connection_params.headers,
timeout=self._connection_params.timeout,
sse_read_timeout=self._connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {self._connection_params}'
)
transports = await self._exit_stack.enter_async_context(client)
session = await self._exit_stack.enter_async_context(
ClientSession(*transports)
)
client = stdio_client(server=server, errlog=errlog)
await session.initialize()
return client, process
self._session = session
return session
async def _emergency_cleanup(self):
"""Perform emergency cleanup of resources when normal cleanup fails."""
logger.info('Performing emergency cleanup of MCPSessionManager resources')
except Exception:
# If session creation fails, clean up the exit stack
if self._exit_stack:
await self._exit_stack.aclose()
self._exit_stack = None
raise
# Clean up any tracked processes
for proc in list(self._active_processes):
async def close(self):
"""Closes the session and cleans up resources."""
if self._exit_stack:
try:
if proc and proc.returncode is None:
logger.info(f'Emergency termination of process {proc.pid}')
proc.terminate()
try:
await asyncio.wait_for(proc.wait(), timeout=1.0)
except asyncio.TimeoutError:
logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
proc.kill()
self._active_processes.remove(proc)
await self._exit_stack.aclose()
except Exception as e:
logger.error(f'Error during process cleanup: {e}')
# Clean up any tracked file handles
for handle in list(self._active_file_handles):
try:
if not handle.closed:
logger.info('Closing file handle')
handle.close()
self._active_file_handles.remove(handle)
except Exception as e:
logger.error(f'Error closing file handle: {e}')
# Log the error but don't re-raise to avoid blocking shutdown
print(
f'Warning: Error during MCP session cleanup: {e}', file=self._errlog
)
finally:
self._exit_stack = None
self._session = None