Mirror of https://github.com/EvolutionAPI/adk-python.git, synced 2025-07-16 04:02:55 -06:00
fix: fix mcp toolset close issue
PiperOrigin-RevId: 759636772
This commit is contained in:
parent 12507dc6cc
commit 05a853bc91
@@ -21,6 +21,7 @@ import json
 import logging
 import os
 from pathlib import Path
+import signal
 import sys
 import time
 import traceback
@@ -221,7 +222,7 @@ def get_fast_api_app(
      )
      provider.add_span_processor(processor)
    else:
-      logging.warning(
+      logger.warning(
          "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
          " not be enabled."
      )
@@ -232,14 +233,71 @@ def get_fast_api_app(

   @asynccontextmanager
   async def internal_lifespan(app: FastAPI):
-    if lifespan:
-      async with lifespan(app) as lifespan_context:
-        yield
-
-        for toolset in toolsets_to_close:
-          await toolset.close()
-    else:
-      yield
+    # Set up signal handlers for graceful shutdown
+    original_sigterm = signal.getsignal(signal.SIGTERM)
+    original_sigint = signal.getsignal(signal.SIGINT)
+
+    def cleanup_handler(sig, frame):
+      # Log the signal
+      logger.info("Received signal %s, performing pre-shutdown cleanup", sig)
+      # Do synchronous cleanup if needed
+      # Then call original handler if it exists
+      if sig == signal.SIGTERM and callable(original_sigterm):
+        original_sigterm(sig, frame)
+      elif sig == signal.SIGINT and callable(original_sigint):
+        original_sigint(sig, frame)
+
+    # Install cleanup handlers
+    signal.signal(signal.SIGTERM, cleanup_handler)
+    signal.signal(signal.SIGINT, cleanup_handler)
+
+    try:
+      if lifespan:
+        async with lifespan(app) as lifespan_context:
+          yield lifespan_context
+      else:
+        yield
+    finally:
+      # During shutdown, properly clean up all toolsets
+      logger.info(
+          "Server shutdown initiated, cleaning up %s toolsets",
+          len(toolsets_to_close),
+      )
+
+      # Create tasks for all toolset closures to run concurrently
+      cleanup_tasks = []
+      for toolset in toolsets_to_close:
+        task = asyncio.create_task(close_toolset_safely(toolset))
+        cleanup_tasks.append(task)
+
+      if cleanup_tasks:
+        # Wait for all cleanup tasks with timeout
+        done, pending = await asyncio.wait(
+            cleanup_tasks,
+            timeout=10.0,  # 10 second timeout for cleanup
+            return_when=asyncio.ALL_COMPLETED,
+        )
+
+        # If any tasks are still pending, log it
+        if pending:
+          logger.warn(
+              f"{len(pending)} toolset cleanup tasks didn't complete in time"
+          )
+          for task in pending:
+            task.cancel()
+
+      # Restore original signal handlers
+      signal.signal(signal.SIGTERM, original_sigterm)
+      signal.signal(signal.SIGINT, original_sigint)
+
+  async def close_toolset_safely(toolset):
+    """Safely close a toolset with error handling."""
+    try:
+      logger.info(f"Closing toolset: {type(toolset).__name__}")
+      await toolset.close()
+      logger.info(f"Successfully closed toolset: {type(toolset).__name__}")
+    except Exception as e:
+      logger.error(f"Error closing toolset {type(toolset).__name__}: {e}")

   # Run the FastAPI server.
   app = FastAPI(lifespan=internal_lifespan)
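
Note (editorial sketch, not part of the commit): the heart of the new internal_lifespan is the time-boxed, concurrent toolset cleanup in its finally block. A minimal standalone Python sketch of that asyncio.wait pattern; the DummyToolset class and close_all helper are illustrative names, not from the repository.

import asyncio


class DummyToolset:
  """Illustrative stand-in for a toolset with an async close()."""

  async def close(self):
    await asyncio.sleep(0.1)


async def close_all(toolsets, timeout=10.0):
  # Close every toolset concurrently, then cancel anything that misses the deadline.
  tasks = [asyncio.create_task(t.close()) for t in toolsets]
  if not tasks:
    return
  done, pending = await asyncio.wait(
      tasks, timeout=timeout, return_when=asyncio.ALL_COMPLETED
  )
  for task in pending:
    task.cancel()


asyncio.run(close_all([DummyToolset(), DummyToolset()]))
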
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from contextlib import AsyncExitStack
+import asyncio
+from contextlib import AsyncExitStack, asynccontextmanager
 import functools
+import logging
 import sys
-from typing import Any, TextIO
+from typing import Any, Optional, TextIO
 import anyio
 from pydantic import BaseModel

@@ -34,6 +36,8 @@ except ImportError as e:
   else:
     raise e

+logger = logging.getLogger(__name__)
+

 class SseServerParams(BaseModel):
   """Parameters for the MCP SSE connection.
@@ -108,6 +112,45 @@ def retry_on_closed_resource(async_reinit_func_name: str):
   return decorator


+@asynccontextmanager
+async def tracked_stdio_client(server, errlog, process=None):
+  """A wrapper around stdio_client that ensures proper process tracking and cleanup."""
+  our_process = process
+
+  # If no process was provided, create one
+  if our_process is None:
+    our_process = await asyncio.create_subprocess_exec(
+        server.command,
+        *server.args,
+        stdin=asyncio.subprocess.PIPE,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=errlog,
+    )
+
+  # Use the original stdio_client, but ensure process cleanup
+  try:
+    async with stdio_client(server=server, errlog=errlog) as client:
+      yield client, our_process
+  finally:
+    # Ensure the process is properly terminated if it still exists
+    if our_process and our_process.returncode is None:
+      try:
+        logger.info(
+            f'Terminating process {our_process.pid} from tracked_stdio_client'
+        )
+        our_process.terminate()
+        try:
+          await asyncio.wait_for(our_process.wait(), timeout=3.0)
+        except asyncio.TimeoutError:
+          # Force kill if it doesn't terminate quickly
+          if our_process.returncode is None:
+            logger.warning(f'Forcing kill of process {our_process.pid}')
+            our_process.kill()
+      except ProcessLookupError:
+        # Process already gone, that's fine
+        logger.info(f'Process {our_process.pid} already terminated')
+
+
 class MCPSessionManager:
   """Manages MCP client sessions.

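
Note (editorial sketch, not part of the commit): tracked_stdio_client pairs the MCP client with the subprocess it is expected to own and escalates from terminate() to kill() when the process outlives a short grace period. A minimal standalone Python sketch of that escalation on a plain asyncio subprocess; the terminate_subprocess helper and the sleep command are illustrative, and the example assumes a POSIX environment.

import asyncio


async def terminate_subprocess(proc: asyncio.subprocess.Process, timeout: float = 3.0):
  if proc.returncode is not None:
    return  # already exited
  proc.terminate()  # ask politely first (SIGTERM on POSIX)
  try:
    await asyncio.wait_for(proc.wait(), timeout=timeout)
  except asyncio.TimeoutError:
    if proc.returncode is None:
      proc.kill()  # escalate if it did not exit in time
      await proc.wait()


async def main():
  proc = await asyncio.create_subprocess_exec("sleep", "30")
  await terminate_subprocess(proc)


asyncio.run(main())
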
@@ -138,25 +181,39 @@ class MCPSessionManager:
       errlog: (Optional) TextIO stream for error logging. Use only for
         initializing a local stdio MCP session.
     """
     self._connection_params = connection_params
     self._exit_stack = exit_stack
     self._errlog = errlog
+    self._process = None  # Track the subprocess
+    self._active_processes = set()  # Track all processes created
+    self._active_file_handles = set()  # Track file handles

-  async def create_session(self) -> ClientSession:
-    return await MCPSessionManager.initialize_session(
+  async def create_session(
+      self,
+  ) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
+    """Creates a new MCP session and tracks the associated process."""
+    session, process = await self._initialize_session(
         connection_params=self._connection_params,
         exit_stack=self._exit_stack,
         errlog=self._errlog,
     )
+    self._process = process  # Store reference to process
+
+    # Track the process
+    if process:
+      self._active_processes.add(process)
+
+    return session, process

   @classmethod
-  async def initialize_session(
+  async def _initialize_session(
       cls,
       *,
       connection_params: StdioServerParameters | SseServerParams,
       exit_stack: AsyncExitStack,
       errlog: TextIO = sys.stderr,
-  ) -> ClientSession:
+  ) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
     """Initializes an MCP client session.

     Args:
@@ -168,9 +225,17 @@ class MCPSessionManager:
     Returns:
       ClientSession: The initialized MCP client session.
     """
+    process = None
+
     if isinstance(connection_params, StdioServerParameters):
-      client = stdio_client(server=connection_params, errlog=errlog)
+      # For stdio connections, we need to track the subprocess
+      client, process = await cls._create_stdio_client(
+          server=connection_params,
+          errlog=errlog,
+          exit_stack=exit_stack,
+      )
     elif isinstance(connection_params, SseServerParams):
+      # For SSE connections, create the client without a subprocess
       client = sse_client(
           url=connection_params.url,
           headers=connection_params.headers,
@@ -184,7 +249,74 @@ class MCPSessionManager:
           f' {connection_params}'
       )

+    # Create the session with the client
     transports = await exit_stack.enter_async_context(client)
     session = await exit_stack.enter_async_context(ClientSession(*transports))
     await session.initialize()
-    return session
+
+    return session, process
+
+  @staticmethod
+  async def _create_stdio_client(
+      server: StdioServerParameters,
+      errlog: TextIO,
+      exit_stack: AsyncExitStack,
+  ) -> tuple[Any, asyncio.subprocess.Process]:
+    """Create stdio client and return both the client and process.
+
+    This implementation adapts to how the MCP stdio_client is created.
+    The actual implementation may need to be adjusted based on the MCP library
+    structure.
+    """
+    # Create the subprocess directly so we can track it
+    process = await asyncio.create_subprocess_exec(
+        server.command,
+        *server.args,
+        stdin=asyncio.subprocess.PIPE,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=errlog,
+    )
+
+    # Create the stdio client using the MCP library
+    try:
+      # Method 1: Try using the existing process if stdio_client supports it
+      client = stdio_client(server=server, errlog=errlog, process=process)
+    except TypeError:
+      # Method 2: If the above doesn't work, let stdio_client create its own process
+      # and we'll need to terminate both processes later
+      logger.warning(
+          'Using stdio_client with its own process - may lead to duplicate'
+          ' processes'
+      )
+      client = stdio_client(server=server, errlog=errlog)
+
+    return client, process
+
+  async def _emergency_cleanup(self):
+    """Perform emergency cleanup of resources when normal cleanup fails."""
+    logger.info('Performing emergency cleanup of MCPSessionManager resources')
+
+    # Clean up any tracked processes
+    for proc in list(self._active_processes):
+      try:
+        if proc and proc.returncode is None:
+          logger.info(f'Emergency termination of process {proc.pid}')
+          proc.terminate()
+          try:
+            await asyncio.wait_for(proc.wait(), timeout=1.0)
+          except asyncio.TimeoutError:
+            logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
+            proc.kill()
+        self._active_processes.remove(proc)
+      except Exception as e:
+        logger.error(f'Error during process cleanup: {e}')
+
+    # Clean up any tracked file handles
+    for handle in list(self._active_file_handles):
+      try:
+        if not handle.closed:
+          logger.info('Closing file handle')
+          handle.close()
+        self._active_file_handles.remove(handle)
+      except Exception as e:
+        logger.error(f'Error closing file handle: {e}')
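
Note (editorial sketch, not part of the commit): _create_stdio_client probes whether stdio_client accepts a process= keyword and falls back to the plain call when it does not (the except TypeError branch). A minimal standalone Python sketch of that probe-and-fall-back pattern; call_with_optional_kwarg and old_factory are illustrative names, not part of the commit or of the MCP SDK.

def call_with_optional_kwarg(factory, optional_key, optional_value, **kwargs):
  """Try factory with an extra keyword argument; retry without it on TypeError."""
  try:
    return factory(**kwargs, **{optional_key: optional_value})
  except TypeError:
    # The factory does not know about the optional argument; call it plainly.
    return factory(**kwargs)


def old_factory(server, errlog):
  return (server, errlog)


client = call_with_optional_kwarg(
    old_factory, "process", object(), server="srv", errlog=None
)
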
@@ -12,7 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import asyncio
 from contextlib import AsyncExitStack
+import logging
+import os
+import signal
 import sys
 from typing import List, Union
 from typing import Optional
@@ -39,14 +43,16 @@ except ImportError as e:

   if sys.version_info < (3, 10):
     raise ImportError(
-        'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
-        ' version.'
+        "MCP Tool requires Python 3.10 or above. Please upgrade your Python"
+        " version."
     ) from e
   else:
     raise e

 from .mcp_tool import MCPTool

+logger = logging.getLogger(__name__)
+

 class MCPToolset(BaseToolset):
   """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
@@ -82,10 +88,12 @@ class MCPToolset(BaseToolset):
     """

     if not connection_params:
-      raise ValueError('Missing connection params in MCPToolset.')
+      raise ValueError("Missing connection params in MCPToolset.")
     self._connection_params = connection_params
     self._errlog = errlog
     self._exit_stack = AsyncExitStack()
+    self._creator_task_id = None
+    self._process_pid = None  # Store the subprocess PID

     self._session_manager = MCPSessionManager(
         connection_params=self._connection_params,
@@ -94,10 +102,17 @@ class MCPToolset(BaseToolset):
     )
     self._session = None
     self.tool_filter = tool_filter
+    self._initialized = False

   async def _initialize(self) -> ClientSession:
     """Connects to the MCP Server and initializes the ClientSession."""
-    self._session = await self._session_manager.create_session()
+    # Store the current task ID when initializing
+    self._creator_task_id = id(asyncio.current_task())
+    self._session, process = await self._session_manager.create_session()
+    # Store the process PID if available
+    if process and hasattr(process, "pid"):
+      self._process_pid = process.pid
+    self._initialized = True
     return self._session

   def _is_selected(
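
Note (editorial sketch, not part of the commit): _initialize now records id(asyncio.current_task()) in _creator_task_id, which lets teardown code notice when it runs in a different task than the one that built the AsyncExitStack, the situation behind the "Attempted to exit cancel scope in a different task" RuntimeError handled in close() below. A minimal standalone Python sketch of such a marker; the TaskBoundResource class is illustrative, not from the repository.

import asyncio


class TaskBoundResource:
  def __init__(self):
    # Remember which task created the resource.
    self._creator_task_id = id(asyncio.current_task())

  def created_in_current_task(self) -> bool:
    return id(asyncio.current_task()) == self._creator_task_id


async def main():
  res = TaskBoundResource()
  print(res.created_in_current_task())  # True: same task

  async def elsewhere():
    print(res.created_in_current_task())  # False: a different task

  await asyncio.create_task(elsewhere())


asyncio.run(main())
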
@@ -114,10 +129,76 @@ class MCPToolset(BaseToolset):

   @override
   async def close(self):
-    """Closes the connection to MCP Server."""
-    await self._exit_stack.aclose()
-
-  @retry_on_closed_resource('_initialize')
+    """Safely closes the connection to MCP Server with guaranteed resource cleanup."""
+    if not self._initialized:
+      return  # Nothing to close
+
+    logger.info("Closing MCP Toolset")
+
+    # Step 1: Try graceful shutdown of the session if it exists
+    if self._session:
+      try:
+        logger.info("Attempting graceful session shutdown")
+        await self._session.shutdown()
+      except Exception as e:
+        logger.warning(f"Session shutdown error (continuing cleanup): {e}")
+
+    # Step 2: Try to close the exit stack
+    try:
+      logger.info("Closing AsyncExitStack")
+      await self._exit_stack.aclose()
+      # If we get here, the exit stack closed successfully
+      logger.info("AsyncExitStack closed successfully")
+      return
+    except RuntimeError as e:
+      if "Attempted to exit cancel scope in a different task" in str(e):
+        logger.warning("Task mismatch during shutdown - using fallback cleanup")
+        # Continue to manual cleanup
+      else:
+        logger.error(f"Unexpected RuntimeError: {e}")
+        # Continue to manual cleanup
+    except Exception as e:
+      logger.error(f"Error during exit stack closure: {e}")
+      # Continue to manual cleanup
+
+    # Step 3: Manual cleanup of the subprocess if we have its PID
+    if self._process_pid:
+      await self._ensure_process_terminated(self._process_pid)
+
+    # Step 4: Ask the session manager to do any additional cleanup it can
+    await self._session_manager._emergency_cleanup()
+
+  async def _ensure_process_terminated(self, pid):
+    """Ensure a process is terminated using its PID."""
+    try:
+      # Check if process exists
+      os.kill(pid, 0)  # This just checks if the process exists
+
+      logger.info(f"Terminating process with PID {pid}")
+      # First try SIGTERM for graceful shutdown
+      os.kill(pid, signal.SIGTERM)
+
+      # Give it a moment to terminate
+      for _ in range(30):  # wait up to 3 seconds
+        await asyncio.sleep(0.1)
+        try:
+          os.kill(pid, 0)  # Process still exists
+        except ProcessLookupError:
+          logger.info(f"Process {pid} terminated successfully")
+          return
+
+      # If we get here, process didn't terminate gracefully
+      logger.warning(
+          f"Process {pid} didn't terminate gracefully, using SIGKILL"
+      )
+      os.kill(pid, signal.SIGKILL)
+
+    except ProcessLookupError:
+      logger.info(f"Process {pid} already terminated")
+    except Exception as e:
+      logger.error(f"Error terminating process {pid}: {e}")
+
+  @retry_on_closed_resource("_initialize")
   @override
   async def get_tools(
       self,
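
Note (editorial sketch, not part of the commit): when the AsyncExitStack cannot be unwound, close() falls back to _ensure_process_terminated, which works from a bare PID: os.kill(pid, 0) as a liveness probe, SIGTERM with a polling grace period, then SIGKILL. A minimal standalone Python sketch of that POSIX-only pattern; the ensure_terminated helper and the sleep command are illustrative names, not from the repository.

import asyncio
import os
import signal
import subprocess


async def ensure_terminated(pid: int, grace_seconds: float = 3.0):
  try:
    os.kill(pid, 0)  # signal 0 only checks that the PID exists
  except ProcessLookupError:
    return  # already gone

  os.kill(pid, signal.SIGTERM)  # ask for a graceful shutdown
  for _ in range(int(grace_seconds / 0.1)):
    await asyncio.sleep(0.1)
    try:
      os.kill(pid, 0)
    except ProcessLookupError:
      return  # exited within the grace period

  os.kill(pid, signal.SIGKILL)  # last resort


proc = subprocess.Popen(["sleep", "30"])
asyncio.run(ensure_terminated(proc.pid))
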