mirror of https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-16 12:12:56 -06:00
refactor: simplify toolset cleanup code and extract the common cleanup code into utils so it can be reused by CLI or client code that calls runners directly
PiperOrigin-RevId: 762463028
This commit is contained in:
parent b9b2c3fb54
commit 92c37496d3
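The stated motivation is that CLI or client code which drives a `Runner` directly can now reuse the same cleanup path. A sketch of what such a caller might look like after this change; the agent import, identifiers, and the exact session-service call shape are illustrative assumptions, not part of this commit:

```python
# Hypothetical client script; identifiers and call shapes are assumptions.
import asyncio

from google.adk.runners import InMemoryRunner
from google.genai import types

from my_agents import root_agent  # placeholder: any LlmAgent


async def main() -> None:
  runner = InMemoryRunner(agent=root_agent, app_name="demo_app")
  session = await runner.session_service.create_session(
      app_name="demo_app", user_id="user_1"
  )
  try:
    message = types.Content(role="user", parts=[types.Part(text="hello")])
    async for event in runner.run_async(
        user_id="user_1", session_id=session.id, new_message=message
    ):
      print(event)
  finally:
    # New in this commit: Runner.close() walks the agent tree and closes
    # every BaseToolset it finds.
    await runner.close()


if __name__ == "__main__":
  asyncio.run(main())
```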
src/google/adk/cli/fast_api.py:

@@ -16,12 +16,9 @@
 import asyncio
 from contextlib import asynccontextmanager
 import importlib
-import inspect
-import json
 import logging
 import os
 from pathlib import Path
-import signal
 import sys
 import time
 import traceback
@@ -55,11 +52,9 @@ from starlette.types import Lifespan
 from typing_extensions import override

 from ..agents import RunConfig
-from ..agents.base_agent import BaseAgent
 from ..agents.live_request_queue import LiveRequest
 from ..agents.live_request_queue import LiveRequestQueue
 from ..agents.llm_agent import Agent
-from ..agents.llm_agent import LlmAgent
 from ..agents.run_config import StreamingMode
 from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
 from ..evaluation.eval_case import EvalCase
@@ -75,12 +70,12 @@ from ..sessions.session import Session
 from ..sessions.vertex_ai_session_service import VertexAiSessionService
 from ..tools.base_toolset import BaseToolset
 from .cli_eval import EVAL_SESSION_ID_PREFIX
-from .cli_eval import EvalCaseResult
 from .cli_eval import EvalMetric
 from .cli_eval import EvalMetricResult
 from .cli_eval import EvalMetricResultPerInvocation
 from .cli_eval import EvalSetResult
 from .cli_eval import EvalStatus
+from .utils import cleanup
 from .utils import common
 from .utils import create_empty_state
 from .utils import envs
@@ -230,27 +225,8 @@ def get_fast_api_app(

   trace.set_tracer_provider(provider)

-  toolsets_to_close: set[BaseToolset] = set()
-
   @asynccontextmanager
   async def internal_lifespan(app: FastAPI):
-    # Set up signal handlers for graceful shutdown
-    original_sigterm = signal.getsignal(signal.SIGTERM)
-    original_sigint = signal.getsignal(signal.SIGINT)
-
-    def cleanup_handler(sig, frame):
-      # Log the signal
-      logger.info("Received signal %s, performing pre-shutdown cleanup", sig)
-      # Do synchronous cleanup if needed
-      # Then call original handler if it exists
-      if sig == signal.SIGTERM and callable(original_sigterm):
-        original_sigterm(sig, frame)
-      elif sig == signal.SIGINT and callable(original_sigint):
-        original_sigint(sig, frame)
-
-    # Install cleanup handlers
-    signal.signal(signal.SIGTERM, cleanup_handler)
-    signal.signal(signal.SIGINT, cleanup_handler)
-
     try:
       if lifespan:
@@ -259,46 +235,8 @@ def get_fast_api_app(
       else:
         yield
     finally:
-      # During shutdown, properly clean up all toolsets
-      logger.info(
-          "Server shutdown initiated, cleaning up %s toolsets",
-          len(toolsets_to_close),
-      )
-
-      # Create tasks for all toolset closures to run concurrently
-      cleanup_tasks = []
-      for toolset in toolsets_to_close:
-        task = asyncio.create_task(close_toolset_safely(toolset))
-        cleanup_tasks.append(task)
-
-      if cleanup_tasks:
-        # Wait for all cleanup tasks with timeout
-        done, pending = await asyncio.wait(
-            cleanup_tasks,
-            timeout=10.0,  # 10 second timeout for cleanup
-            return_when=asyncio.ALL_COMPLETED,
-        )
-
-        # If any tasks are still pending, log it
-        if pending:
-          logger.warning(
-              f"{len(pending)} toolset cleanup tasks didn't complete in time"
-          )
-          for task in pending:
-            task.cancel()
-
-      # Restore original signal handlers
-      signal.signal(signal.SIGTERM, original_sigterm)
-      signal.signal(signal.SIGINT, original_sigint)
-
-  async def close_toolset_safely(toolset):
-    """Safely close a toolset with error handling."""
-    try:
-      logger.info(f"Closing toolset: {type(toolset).__name__}")
-      await toolset.close()
-      logger.info(f"Successfully closed toolset: {type(toolset).__name__}")
-    except Exception as e:
-      logger.error(f"Error closing toolset {type(toolset).__name__}: {e}")
+      # Create tasks for all runner closures to run concurrently
+      await cleanup.close_runners(list(runner_dict.values()))

   # Run the FastAPI server.
   app = FastAPI(lifespan=internal_lifespan)
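For readers unfamiliar with the pattern, the shutdown hook above works because FastAPI treats the decorated generator as a lifespan context manager: code before `yield` runs at startup and the `finally` block runs when the server stops. A minimal, self-contained sketch of the same shape, with `FakeRunner` and `fake_runners` standing in for the real runner cache:

```python
# Minimal sketch of the lifespan pattern used above; names are stand-ins.
from contextlib import asynccontextmanager

from fastapi import FastAPI


class FakeRunner:
  async def close(self) -> None:
    print("runner closed")


fake_runners = {"demo_app": FakeRunner()}


@asynccontextmanager
async def internal_lifespan(app: FastAPI):
  # Startup work would go here.
  try:
    yield  # The server handles requests while suspended here.
  finally:
    # Shutdown: close every live runner, mirroring cleanup.close_runners().
    for runner in fake_runners.values():
      await runner.close()


app = FastAPI(lifespan=internal_lifespan)
```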
@@ -903,16 +841,6 @@ def get_fast_api_app(
         for task in pending:
           task.cancel()

-  def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
-    toolsets = set()
-    if isinstance(agent, LlmAgent):
-      for tool_union in agent.tools:
-        if isinstance(tool_union, BaseToolset):
-          toolsets.add(tool_union)
-    for sub_agent in agent.sub_agents:
-      toolsets.update(_get_all_toolsets(sub_agent))
-    return toolsets
-
   async def _get_root_agent_async(app_name: str) -> Agent:
     """Returns the root agent for the given app."""
     if app_name in root_agent_dict:
@@ -924,7 +852,6 @@ def get_fast_api_app(
       raise ValueError(f'Unable to find "root_agent" from {app_name}.')

     root_agent_dict[app_name] = root_agent
-    toolsets_to_close.update(_get_all_toolsets(root_agent))
     return root_agent

   async def _get_runner_async(app_name: str) -> Runner:
src/google/adk/cli/utils/cleanup.py (new file, 40 lines):

@@ -0,0 +1,40 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import logging
+from typing import List
+
+from ...runners import Runner
+
+logger = logging.getLogger("google_adk." + __name__)
+
+
+async def close_runners(runners: List[Runner]) -> None:
+  cleanup_tasks = [asyncio.create_task(runner.close()) for runner in runners]
+  if cleanup_tasks:
+    # Wait for all cleanup tasks with timeout
+    done, pending = await asyncio.wait(
+        cleanup_tasks,
+        timeout=30.0,  # 30 second timeout for cleanup
+        return_when=asyncio.ALL_COMPLETED,
+    )
+
+    # If any tasks are still pending, log it
+    if pending:
+      logger.warning(
+          "%s runner close tasks didn't complete in time", len(pending)
+      )
+      for task in pending:
+        task.cancel()
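A sketch of how CLI or client code holding several live runners could call the new helper at shutdown (how the runner list is built is left out and assumed to exist elsewhere):

```python
# Hypothetical shutdown hook for code that owns several runners.
from google.adk.cli.utils import cleanup
from google.adk.runners import Runner


async def shutdown(active_runners: list[Runner]) -> None:
  # Closes all runners concurrently; stragglers are cancelled after the
  # helper's 30-second timeout.
  await cleanup.close_runners(active_runners)
```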
src/google/adk/runners.py:

@@ -42,6 +42,7 @@ from .sessions.base_session_service import BaseSessionService
 from .sessions.in_memory_session_service import InMemorySessionService
 from .sessions.session import Session
 from .telemetry import tracer
+from .tools.base_toolset import BaseToolset

 logger = logging.getLogger('google_adk.' + __name__)

@@ -457,6 +458,37 @@ class Runner:
         run_config=run_config,
     )

+  def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]:
+    toolsets = set()
+    if isinstance(agent, LlmAgent):
+      for tool_union in agent.tools:
+        if isinstance(tool_union, BaseToolset):
+          toolsets.add(tool_union)
+    for sub_agent in agent.sub_agents:
+      toolsets.update(self._collect_toolset(sub_agent))
+    return toolsets
+
+  async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]):
+    """Clean up toolsets with proper task context management."""
+    if not toolsets_to_close:
+      return
+
+    # This maintains the same task context throughout cleanup
+    for toolset in toolsets_to_close:
+      try:
+        logger.info('Closing toolset: %s', type(toolset).__name__)
+        # Use asyncio.wait_for to add timeout protection
+        await asyncio.wait_for(toolset.close(), timeout=10.0)
+        logger.info('Successfully closed toolset: %s', type(toolset).__name__)
+      except asyncio.TimeoutError:
+        logger.warning('Toolset %s cleanup timed out', type(toolset).__name__)
+      except Exception as e:
+        logger.error('Error closing toolset %s: %s', type(toolset).__name__, e)
+
+  async def close(self):
+    """Closes the runner."""
+    await self._cleanup_toolsets(self._collect_toolset(self.agent))
+
+
 class InMemoryRunner(Runner):
   """An in-memory Runner for testing and development.
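Unlike the old FastAPI-side cleanup, `Runner._cleanup_toolsets` closes toolsets one at a time in the same task, but wraps each `close()` in `asyncio.wait_for` so a single hung toolset cannot stall shutdown. A self-contained demo of that timeout behavior, with a deliberately slow stand-in toolset:

```python
# Demonstrates the per-toolset timeout used by Runner._cleanup_toolsets.
import asyncio


class SlowToolset:
  async def close(self) -> None:
    await asyncio.sleep(60)  # simulates a hung cleanup


async def main() -> None:
  toolset = SlowToolset()
  try:
    await asyncio.wait_for(toolset.close(), timeout=1.0)
  except asyncio.TimeoutError:
    print("Toolset cleanup timed out; continuing shutdown")


asyncio.run(main())
```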
src/google/adk/tools/mcp_tool/mcp_session_manager.py:

@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import asyncio
-from contextlib import asynccontextmanager
 from contextlib import AsyncExitStack
 import functools
 import logging
@@ -87,13 +86,11 @@ def retry_on_closed_resource(async_reinit_func_name: str):
   """

   def decorator(func):
-    @functools.wraps(
-        func
-    )  # Preserves original function metadata (name, docstring)
+    @functools.wraps(func)  # Preserves original function metadata
     async def wrapper(self, *args, **kwargs):
       try:
         return await func(self, *args, **kwargs)
-      except anyio.ClosedResourceError:
+      except anyio.ClosedResourceError as close_err:
         try:
           if hasattr(self, async_reinit_func_name) and callable(
               getattr(self, async_reinit_func_name)
@@ -105,7 +102,7 @@ def retry_on_closed_resource(async_reinit_func_name: str):
                 f'Function {async_reinit_func_name} does not exist in decorated'
                 ' class. Please check the function name in'
                 ' retry_on_closed_resource decorator.'
-            )
+            ) from close_err
         except Exception as reinit_err:
           raise RuntimeError(
               f'Error reinitializing: {reinit_err}'
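Binding the caught error as `close_err` and re-raising with `raise ... from close_err` chains the original `ClosedResourceError` onto the new error as `__cause__`, so both failures show up in the traceback. A small standalone illustration (the exception class here is a local stand-in, not the anyio one):

```python
# Shows the effect of `raise ... from err`: the original error is preserved
# as __cause__ and reported as the direct cause of the new one.
class ClosedResourceError(Exception):
  pass


def reconnect_or_fail() -> None:
  try:
    raise ClosedResourceError("stream closed")
  except ClosedResourceError as close_err:
    raise ValueError("no reinit function available") from close_err


try:
  reconnect_or_fail()
except ValueError as e:
  print(repr(e.__cause__))  # ClosedResourceError('stream closed')
```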
@@ -117,45 +114,6 @@ def retry_on_closed_resource(async_reinit_func_name: str):
   return decorator


-@asynccontextmanager
-async def tracked_stdio_client(server, errlog, process=None):
-  """A wrapper around stdio_client that ensures proper process tracking and cleanup."""
-  our_process = process
-
-  # If no process was provided, create one
-  if our_process is None:
-    our_process = await asyncio.create_subprocess_exec(
-        server.command,
-        *server.args,
-        stdin=asyncio.subprocess.PIPE,
-        stdout=asyncio.subprocess.PIPE,
-        stderr=errlog,
-    )
-
-  # Use the original stdio_client, but ensure process cleanup
-  try:
-    async with stdio_client(server=server, errlog=errlog) as client:
-      yield client, our_process
-  finally:
-    # Ensure the process is properly terminated if it still exists
-    if our_process and our_process.returncode is None:
-      try:
-        logger.info(
-            f'Terminating process {our_process.pid} from tracked_stdio_client'
-        )
-        our_process.terminate()
-        try:
-          await asyncio.wait_for(our_process.wait(), timeout=3.0)
-        except asyncio.TimeoutError:
-          # Force kill if it doesn't terminate quickly
-          if our_process.returncode is None:
-            logger.warning(f'Forcing kill of process {our_process.pid}')
-            our_process.kill()
-      except ProcessLookupError:
-        # Process already gone, that's fine
-        logger.info(f'Process {our_process.pid} already terminated')
-
-
 class MCPSessionManager:
   """Manages MCP client sessions.

@@ -166,162 +124,78 @@ class MCPSessionManager:
   def __init__(
       self,
       connection_params: StdioServerParameters | SseServerParams,
-      exit_stack: AsyncExitStack,
       errlog: TextIO = sys.stderr,
   ):
     """Initializes the MCP session manager.

-    Example usage:
-    ```
-    mcp_session_manager = MCPSessionManager(
-        connection_params=connection_params,
-        exit_stack=exit_stack,
-    )
-    session = await mcp_session_manager.create_session()
-    ```
-
     Args:
       connection_params: Parameters for the MCP connection (Stdio or SSE).
-      exit_stack: AsyncExitStack to manage the session lifecycle.
       errlog: (Optional) TextIO stream for error logging. Use only for
         initializing a local stdio MCP session.
     """

     self._connection_params = connection_params
-    self._exit_stack = exit_stack
     self._errlog = errlog
-    self._process = None  # Track the subprocess
-    self._active_processes = set()  # Track all processes created
-    self._active_file_handles = set()  # Track file handles
+    # Each session manager maintains its own exit stack for proper cleanup
+    self._exit_stack: Optional[AsyncExitStack] = None
+    self._session: Optional[ClientSession] = None

-  async def create_session(
-      self,
-  ) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
-    """Creates a new MCP session and tracks the associated process."""
-    session, process = await self._initialize_session(
-        connection_params=self._connection_params,
-        exit_stack=self._exit_stack,
-        errlog=self._errlog,
-    )
-    self._process = process  # Store reference to process
-
-    # Track the process
-    if process:
-      self._active_processes.add(process)
-
-    return session, process
-
-  @classmethod
-  async def _initialize_session(
-      cls,
-      *,
-      connection_params: StdioServerParameters | SseServerParams,
-      exit_stack: AsyncExitStack,
-      errlog: TextIO = sys.stderr,
-  ) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
-    """Initializes an MCP client session.
-
-    Args:
-      connection_params: Parameters for the MCP connection (Stdio or SSE).
-      exit_stack: AsyncExitStack to manage the session lifecycle.
-      errlog: (Optional) TextIO stream for error logging. Use only for
-        initializing a local stdio MCP session.
+  async def create_session(self) -> ClientSession:
+    """Creates and initializes an MCP client session.

     Returns:
       ClientSession: The initialized MCP client session.
     """
-    process = None
-
-    if isinstance(connection_params, StdioServerParameters):
-      # For stdio connections, we need to track the subprocess
-      client, process = await cls._create_stdio_client(
-          server=connection_params,
-          errlog=errlog,
-          exit_stack=exit_stack,
-      )
-    elif isinstance(connection_params, SseServerParams):
-      # For SSE connections, create the client without a subprocess
-      client = sse_client(
-          url=connection_params.url,
-          headers=connection_params.headers,
-          timeout=connection_params.timeout,
-          sse_read_timeout=connection_params.sse_read_timeout,
-      )
-    else:
-      raise ValueError(
-          'Unable to initialize connection. Connection should be'
-          ' StdioServerParameters or SseServerParams, but got'
-          f' {connection_params}'
-      )
-
-    # Create the session with the client
-    transports = await exit_stack.enter_async_context(client)
-    session = await exit_stack.enter_async_context(ClientSession(*transports))
-    await session.initialize()
-
-    return session, process
-
-  @staticmethod
-  async def _create_stdio_client(
-      server: StdioServerParameters,
-      errlog: TextIO,
-      exit_stack: AsyncExitStack,
-  ) -> tuple[Any, asyncio.subprocess.Process]:
-    """Create stdio client and return both the client and process.
-
-    This implementation adapts to how the MCP stdio_client is created.
-    The actual implementation may need to be adjusted based on the MCP library
-    structure.
-    """
-    # Create the subprocess directly so we can track it
-    process = await asyncio.create_subprocess_exec(
-        server.command,
-        *server.args,
-        stdin=asyncio.subprocess.PIPE,
-        stdout=asyncio.subprocess.PIPE,
-        stderr=errlog,
-    )
-
-    # Create the stdio client using the MCP library
-    try:
-      # Method 1: Try using the existing process if stdio_client supports it
-      client = stdio_client(server=server, errlog=errlog, process=process)
-    except TypeError:
-      # Method 2: If the above doesn't work, let stdio_client create its own process
-      # and we'll need to terminate both processes later
-      logger.warning(
-          'Using stdio_client with its own process - may lead to duplicate'
-          ' processes'
-      )
-      client = stdio_client(server=server, errlog=errlog)
-
-    return client, process
-
-  async def _emergency_cleanup(self):
-    """Perform emergency cleanup of resources when normal cleanup fails."""
-    logger.info('Performing emergency cleanup of MCPSessionManager resources')
-
-    # Clean up any tracked processes
-    for proc in list(self._active_processes):
-      try:
-        if proc and proc.returncode is None:
-          logger.info(f'Emergency termination of process {proc.pid}')
-          proc.terminate()
-          try:
-            await asyncio.wait_for(proc.wait(), timeout=1.0)
-          except asyncio.TimeoutError:
-            logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
-            proc.kill()
-        self._active_processes.remove(proc)
-      except Exception as e:
-        logger.error(f'Error during process cleanup: {e}')
-
-    # Clean up any tracked file handles
-    for handle in list(self._active_file_handles):
-      try:
-        if not handle.closed:
-          logger.info('Closing file handle')
-          handle.close()
-        self._active_file_handles.remove(handle)
-      except Exception as e:
-        logger.error(f'Error closing file handle: {e}')
+    if self._session is not None:
+      return self._session
+
+    # Create a new exit stack for this session
+    self._exit_stack = AsyncExitStack()
+
+    try:
+      if isinstance(self._connection_params, StdioServerParameters):
+        client = stdio_client(
+            server=self._connection_params, errlog=self._errlog
+        )
+      elif isinstance(self._connection_params, SseServerParams):
+        client = sse_client(
+            url=self._connection_params.url,
+            headers=self._connection_params.headers,
+            timeout=self._connection_params.timeout,
+            sse_read_timeout=self._connection_params.sse_read_timeout,
+        )
+      else:
+        raise ValueError(
+            'Unable to initialize connection. Connection should be'
+            ' StdioServerParameters or SseServerParams, but got'
+            f' {self._connection_params}'
+        )
+
+      transports = await self._exit_stack.enter_async_context(client)
+      session = await self._exit_stack.enter_async_context(
+          ClientSession(*transports)
+      )
+      await session.initialize()
+
+      self._session = session
+      return session
+
+    except Exception:
+      # If session creation fails, clean up the exit stack
+      if self._exit_stack:
+        await self._exit_stack.aclose()
+        self._exit_stack = None
+      raise
+
+  async def close(self):
+    """Closes the session and cleans up resources."""
+    if self._exit_stack:
+      try:
+        await self._exit_stack.aclose()
+      except Exception as e:
+        # Log the error but don't re-raise to avoid blocking shutdown
+        print(
+            f'Warning: Error during MCP session cleanup: {e}', file=self._errlog
+        )
+      finally:
+        self._exit_stack = None
+        self._session = None
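Since the exit stack is now owned by the manager, a caller no longer has to thread an `AsyncExitStack` through. A usage sketch, assuming a local stdio MCP server launched via `npx` (the server package and path are illustrative) and the module path implied by the imports in this diff:

```python
# Hypothetical direct use of the reworked MCPSessionManager.
import asyncio

from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
from mcp import StdioServerParameters


async def main() -> None:
  manager = MCPSessionManager(
      connection_params=StdioServerParameters(
          command="npx",
          args=["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
      )
  )
  try:
    session = await manager.create_session()  # cached on repeat calls
    tools = await session.list_tools()
    print([tool.name for tool in tools.tools])
  finally:
    await manager.close()  # tears down the manager-owned exit stack


asyncio.run(main())
```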
src/google/adk/tools/mcp_tool/mcp_tool.py:

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 from typing import Optional

 from google.genai.types import FunctionDeclaration
@@ -23,7 +25,6 @@ from .mcp_session_manager import retry_on_closed_resource
 # Attempt to import MCP Tool from the MCP library, and hints user to upgrade
 # their Python version to 3.10 if it fails.
 try:
-  from mcp import ClientSession
   from mcp.types import Tool as McpBaseTool
 except ImportError as e:
   import sys
@@ -43,6 +44,8 @@ from ..base_tool import BaseTool
 from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
 from ..tool_context import ToolContext

+logger = logging.getLogger("google_adk." + __name__)
+

 class MCPTool(BaseTool):
   """Turns a MCP Tool into a Vertex Agent Framework Tool.
@@ -53,44 +56,40 @@ class MCPTool(BaseTool):

   def __init__(
       self,
+      *,
       mcp_tool: McpBaseTool,
-      mcp_session: ClientSession,
       mcp_session_manager: MCPSessionManager,
       auth_scheme: Optional[AuthScheme] = None,
-      auth_credential: Optional[AuthCredential] | None = None,
+      auth_credential: Optional[AuthCredential] = None,
   ):
     """Initializes a MCPTool.

-    This tool wraps a MCP Tool interface and an active MCP Session. It invokes
-    the MCP Tool through executing the tool from remote MCP Session.
-
-    Example:
-      tool = MCPTool(mcp_tool=mcp_tool, mcp_session=mcp_session)
+    This tool wraps a MCP Tool interface and uses a session manager to
+    communicate with the MCP server.

     Args:
       mcp_tool: The MCP tool to wrap.
-      mcp_session: The MCP session to use to call the tool.
+      mcp_session_manager: The MCP session manager to use for communication.
      auth_scheme: The authentication scheme to use.
      auth_credential: The authentication credential to use.

     Raises:
-      ValueError: If mcp_tool or mcp_session is None.
+      ValueError: If mcp_tool or mcp_session_manager is None.
     """
     if mcp_tool is None:
       raise ValueError("mcp_tool cannot be None")
-    if mcp_session is None:
-      raise ValueError("mcp_session cannot be None")
-    super().__init__(name=mcp_tool.name, description=mcp_tool.description or "")
+    if mcp_session_manager is None:
+      raise ValueError("mcp_session_manager cannot be None")
+    super().__init__(
+        name=mcp_tool.name,
+        description=mcp_tool.description if mcp_tool.description else "",
+    )
     self._mcp_tool = mcp_tool
-    self._mcp_session = mcp_session
     self._mcp_session_manager = mcp_session_manager
     # TODO(cheliu): Support passing auth to MCP Server.
     self._auth_scheme = auth_scheme
     self._auth_credential = auth_credential

-  async def _reinitialize_session(self):
-    self._mcp_session = await self._mcp_session_manager.create_session()
-
   @override
   def _get_declaration(self) -> FunctionDeclaration:
     """Gets the function declaration for the tool.
@@ -105,7 +104,6 @@ class MCPTool(BaseTool):
     )
     return function_decl

-  @override
   @retry_on_closed_resource("_reinitialize_session")
   async def run_async(self, *, args, tool_context: ToolContext):
     """Runs the tool asynchronously.
@@ -117,10 +115,15 @@ class MCPTool(BaseTool):
     Returns:
       Any: The response from the tool.
     """
+    # Get the session from the session manager
+    session = await self._mcp_session_manager.create_session()
+
     # TODO(cheliu): Support passing tool context to MCP Server.
-    try:
-      response = await self._mcp_session.call_tool(self.name, arguments=args)
-      return response
-    except Exception as e:
-      print(e)
-      raise e
+    response = await session.call_tool(self.name, arguments=args)
+    return response
+
+  async def _reinitialize_session(self):
+    """Reinitializes the session when connection is lost."""
+    # Close the old session and create a new one
+    await self._mcp_session_manager.close()
+    await self._mcp_session_manager.create_session()
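The combination of `@retry_on_closed_resource("_reinitialize_session")` on `run_async` and the new `_reinitialize_session` method means a tool call that fails because the stream was closed triggers one reconnect and one retry. A stripped-down model of that control flow (this is not the ADK decorator itself, just the shape of it):

```python
# Simplified model of the retry-on-closed-resource flow around MCPTool.run_async.
import asyncio


class ClosedResourceError(Exception):
  pass


class FlakyTool:
  def __init__(self) -> None:
    self.connected = False

  async def _reinitialize_session(self) -> None:
    self.connected = True  # stands in for close() + create_session()

  async def run_async(self, *, args: dict) -> str:
    async def call() -> str:
      if not self.connected:
        raise ClosedResourceError()
      return f"ok: {args}"

    try:
      return await call()
    except ClosedResourceError:
      await self._reinitialize_session()
      return await call()


print(asyncio.run(FlakyTool().run_async(args={"path": "/tmp"})))
```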
src/google/adk/tools/mcp_tool/mcp_toolset.py:

@@ -12,19 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import asyncio
-from contextlib import AsyncExitStack
 import logging
-import os
-import signal
 import sys
 from typing import List
 from typing import Optional
 from typing import TextIO
 from typing import Union

-from typing_extensions import override
-
 from ...agents.readonly_context import ReadonlyContext
 from ..base_tool import BaseTool
 from ..base_toolset import BaseToolset
@@ -36,7 +30,6 @@ from .mcp_session_manager import SseServerParams
 # Attempt to import MCP Tool from the MCP library, and hints user to upgrade
 # their Python version to 3.10 if it fails.
 try:
-  from mcp import ClientSession
   from mcp import StdioServerParameters
   from mcp.types import ListToolsResult
 except ImportError as e:
@@ -58,16 +51,31 @@ logger = logging.getLogger("google_adk." + __name__)
 class MCPToolset(BaseToolset):
   """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.

+  This toolset manages the connection to an MCP server and provides tools
+  that can be used by an agent. It properly implements the BaseToolset
+  interface for easy integration with the agent framework.
+
   Usage:
-  ```
-  root_agent = LlmAgent(
-      tools=MCPToolset(
-          connection_params=StdioServerParameters(
-              command='npx',
-              args=["-y", "@modelcontextprotocol/server-filesystem"],
-          )
-      )
-  )
+  ```python
+  toolset = MCPToolset(
+      connection_params=StdioServerParameters(
+          command='npx',
+          args=["-y", "@modelcontextprotocol/server-filesystem"],
+      ),
+      tool_filter=['read_file', 'list_directory']  # Optional: filter specific tools
+  )
+
+  # Use in an agent
+  agent = LlmAgent(
+      model='gemini-2.0-flash',
+      name='enterprise_assistant',
+      instruction='Help user accessing their file systems',
+      tools=[toolset],
+  )
+
+  # Cleanup is handled automatically by the agent framework
+  # But you can also manually close if needed:
+  # await toolset.close()
   ```
   """

@@ -75,8 +83,8 @@ class MCPToolset(BaseToolset):
       self,
       *,
       connection_params: StdioServerParameters | SseServerParams,
-      errlog: TextIO = sys.stderr,
       tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
+      errlog: TextIO = sys.stderr,
   ):
     """Initializes the MCPToolset.

@@ -84,131 +92,80 @@ class MCPToolset(BaseToolset):
       connection_params: The connection parameters to the MCP server. Can be:
         `StdioServerParameters` for using local mcp server (e.g. using `npx` or
         `python3`); or `SseServerParams` for a local/remote SSE server.
-      errlog: (Optional) TextIO stream for error logging. Use only for
-        initializing a local stdio MCP session.
+      tool_filter: Optional filter to select specific tools. Can be either:
+        - A list of tool names to include
+        - A ToolPredicate function for custom filtering logic
+      errlog: TextIO stream for error logging.
     """
+    super().__init__(tool_filter=tool_filter)

     if not connection_params:
       raise ValueError("Missing connection params in MCPToolset.")
-    super().__init__(tool_filter=tool_filter)
     self._connection_params = connection_params
     self._errlog = errlog
-    self._exit_stack = AsyncExitStack()
-    self._creator_task_id = None
-    self._process_pid = None  # Store the subprocess PID

-    self._session_manager = MCPSessionManager(
+    # Create the session manager that will handle the MCP connection
+    self._mcp_session_manager = MCPSessionManager(
         connection_params=self._connection_params,
-        exit_stack=self._exit_stack,
         errlog=self._errlog,
     )
     self._session = None
-    self._initialized = False

-  async def _initialize(self) -> ClientSession:
-    """Connects to the MCP Server and initializes the ClientSession."""
-    # Store the current task ID when initializing
-    self._creator_task_id = id(asyncio.current_task())
-    self._session, process = await self._session_manager.create_session()
-    # Store the process PID if available
-    if process and hasattr(process, "pid"):
-      self._process_pid = process.pid
-    self._initialized = True
-    return self._session
-
-  @override
-  async def close(self):
-    """Safely closes the connection to MCP Server with guaranteed resource cleanup."""
-    if not self._initialized:
-      return  # Nothing to close
-
-    logger.info("Closing MCP Toolset")
-
-    # Step 1: Try graceful shutdown of the session if it exists
-    if self._session:
-      try:
-        logger.info("Attempting graceful session shutdown")
-        await self._session.shutdown()
-      except Exception as e:
-        logger.warning(f"Session shutdown error (continuing cleanup): {e}")
-
-    # Step 2: Try to close the exit stack
-    try:
-      logger.info("Closing AsyncExitStack")
-      await self._exit_stack.aclose()
-      # If we get here, the exit stack closed successfully
-      logger.info("AsyncExitStack closed successfully")
-      return
-    except RuntimeError as e:
-      if "Attempted to exit cancel scope in a different task" in str(e):
-        logger.warning("Task mismatch during shutdown - using fallback cleanup")
-        # Continue to manual cleanup
-      else:
-        logger.error(f"Unexpected RuntimeError: {e}")
-        # Continue to manual cleanup
-    except Exception as e:
-      logger.error(f"Error during exit stack closure: {e}")
-      # Continue to manual cleanup
-
-    # Step 3: Manual cleanup of the subprocess if we have its PID
-    if self._process_pid:
-      await self._ensure_process_terminated(self._process_pid)
-
-    # Step 4: Ask the session manager to do any additional cleanup it can
-    await self._session_manager._emergency_cleanup()
-
-  async def _ensure_process_terminated(self, pid):
-    """Ensure a process is terminated using its PID."""
-    try:
-      # Check if process exists
-      os.kill(pid, 0)  # This just checks if the process exists
-
-      logger.info(f"Terminating process with PID {pid}")
-      # First try SIGTERM for graceful shutdown
-      os.kill(pid, signal.SIGTERM)
-
-      # Give it a moment to terminate
-      for _ in range(30):  # wait up to 3 seconds
-        await asyncio.sleep(0.1)
-        try:
-          os.kill(pid, 0)  # Process still exists
-        except ProcessLookupError:
-          logger.info(f"Process {pid} terminated successfully")
-          return
-
-      # If we get here, process didn't terminate gracefully
-      logger.warning(
-          f"Process {pid} didn't terminate gracefully, using SIGKILL"
-      )
-      os.kill(pid, signal.SIGKILL)
-
-    except ProcessLookupError:
-      logger.info(f"Process {pid} already terminated")
-    except Exception as e:
-      logger.error(f"Error terminating process {pid}: {e}")
-
-  @retry_on_closed_resource("_initialize")
-  @override
+  @retry_on_closed_resource("_reinitialize_session")
   async def get_tools(
       self,
       readonly_context: Optional[ReadonlyContext] = None,
-  ) -> List[MCPTool]:
-    """Loads all tools from the MCP Server.
+  ) -> List[BaseTool]:
+    """Return all tools in the toolset based on the provided context.
+
+    Args:
+      readonly_context: Context used to filter tools available to the agent.
+        If None, all tools in the toolset are returned.

     Returns:
-      A list of MCPTools imported from the MCP Server.
+      List[BaseTool]: A list of tools available under the specified context.
     """
+    # Get session from session manager
     if not self._session:
-      await self._initialize()
+      self._session = await self._mcp_session_manager.create_session()

+    # Fetch available tools from the MCP server
     tools_response: ListToolsResult = await self._session.list_tools()

+    # Apply filtering based on context and tool_filter
     tools = []
     for tool in tools_response.tools:
       mcp_tool = MCPTool(
           mcp_tool=tool,
-          mcp_session=self._session,
-          mcp_session_manager=self._session_manager,
+          mcp_session_manager=self._mcp_session_manager,
       )

       if self._is_tool_selected(mcp_tool, readonly_context):
         tools.append(mcp_tool)
     return tools
+
+  async def _reinitialize_session(self):
+    """Reinitializes the session when connection is lost."""
+    # Close the old session and clear cache
+    await self._mcp_session_manager.close()
+    self._session = await self._mcp_session_manager.create_session()
+
+    # Tools will be reloaded on next get_tools call
+
+  async def close(self) -> None:
+    """Performs cleanup and releases resources held by the toolset.
+
+    This method closes the MCP session and cleans up all associated resources.
+    It's designed to be safe to call multiple times and handles cleanup errors
+    gracefully to avoid blocking application shutdown.
+    """
+    try:
+      await self._mcp_session_manager.close()
+    except Exception as e:
+      # Log the error but don't re-raise to avoid blocking shutdown
+      print(f"Warning: Error during MCPToolset cleanup: {e}", file=self._errlog)
+    finally:
+      # Clear cached tools
+      self._tools_cache = None
+      self._tools_loaded = False