diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 97a3f23..f0c8a30 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -15,14 +15,24 @@ from contextlib import AsyncExitStack import sys from types import TracebackType -from typing import List, Optional, TextIO, Tuple, Type +from typing import List +from typing import Optional +from typing import override +from typing import TextIO +from typing import Type -from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource +from ...agents.readonly_context import ReadonlyContext +from ..base_toolset import BaseToolPredicate +from ..base_toolset import BaseToolset +from .mcp_session_manager import MCPSessionManager +from .mcp_session_manager import retry_on_closed_resource +from .mcp_session_manager import SseServerParams # Attempt to import MCP Tool from the MCP library, and hints user to upgrade # their Python version to 3.10 if it fails. try: - from mcp import ClientSession, StdioServerParameters + from mcp import ClientSession + from mcp import StdioServerParameters from mcp.types import ListToolsResult except ImportError as e: import sys @@ -38,67 +48,20 @@ except ImportError as e: from .mcp_tool import MCPTool -class MCPToolset: +class MCPToolset(BaseToolset): """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools. Usage: - Example 1: (using from_server helper): ``` - async def load_tools(): - return await MCPToolset.from_server( - connection_params=StdioServerParameters( - command='npx', - args=["-y", "@modelcontextprotocol/server-filesystem"], + root_agent = LlmAgent( + tools=MCPToolset( + connection_params=StdioServerParameters( + command='npx', + args=["-y", "@modelcontextprotocol/server-filesystem"], ) - ) - - # Use the tools in an LLM agent - tools, exit_stack = await load_tools() - agent = LlmAgent( - tools=tools - ) - ... - await exit_stack.aclose() - ``` - - Example 2: (using `async with`): - - ``` - async def load_tools(): - async with MCPToolset( - connection_params=SseServerParams(url="http://0.0.0.0:8090/sse") - ) as toolset: - tools = await toolset.load_tools() - - agent = LlmAgent( - ... - tools=tools ) + ) ``` - - Example 3: (provide AsyncExitStack): - ``` - async def load_tools(): - async_exit_stack = AsyncExitStack() - toolset = MCPToolset( - connection_params=StdioServerParameters(...), - ) - async_exit_stack.enter_async_context(toolset) - tools = await toolset.load_tools() - agent = LlmAgent( - ... - tools=tools - ) - ... - await async_exit_stack.aclose() - - ``` - - Attributes: - connection_params: The connection parameters to the MCP server. Can be - either `StdioServerParameters` or `SseServerParams`. - exit_stack: The async exit stack to manage the connection to the MCP server. - session: The MCP session being initialized with the connection. """ def __init__( @@ -106,140 +69,53 @@ class MCPToolset: *, connection_params: StdioServerParameters | SseServerParams, errlog: TextIO = sys.stderr, - exit_stack=AsyncExitStack(), + tool_predicate: Optional[BaseToolPredicate] = None, ): """Initializes the MCPToolset. - Usage: - Example 1: (using from_server helper): - ``` - async def load_tools(): - return await MCPToolset.from_server( - connection_params=StdioServerParameters( - command='npx', - args=["-y", "@modelcontextprotocol/server-filesystem"], - ) - ) - - # Use the tools in an LLM agent - tools, exit_stack = await load_tools() - agent = LlmAgent( - tools=tools - ) - ... - await exit_stack.aclose() - ``` - - Example 2: (using `async with`): - - ``` - async def load_tools(): - async with MCPToolset( - connection_params=SseServerParams(url="http://0.0.0.0:8090/sse") - ) as toolset: - tools = await toolset.load_tools() - - agent = LlmAgent( - ... - tools=tools - ) - ``` - - Example 3: (provide AsyncExitStack): - ``` - async def load_tools(): - async_exit_stack = AsyncExitStack() - toolset = MCPToolset( - connection_params=StdioServerParameters(...), - ) - async_exit_stack.enter_async_context(toolset) - tools = await toolset.load_tools() - agent = LlmAgent( - ... - tools=tools - ) - ... - await async_exit_stack.aclose() - - ``` - Args: connection_params: The connection parameters to the MCP server. Can be: `StdioServerParameters` for using local mcp server (e.g. using `npx` or `python3`); or `SseServerParams` for a local/remote SSE server. """ + if not connection_params: raise ValueError('Missing connection params in MCPToolset.') self.connection_params = connection_params self.errlog = errlog - self.exit_stack = exit_stack + self.exit_stack = AsyncExitStack() self.session_manager = MCPSessionManager( connection_params=self.connection_params, exit_stack=self.exit_stack, errlog=self.errlog, ) - - @classmethod - async def from_server( - cls, - *, - connection_params: StdioServerParameters | SseServerParams, - async_exit_stack: Optional[AsyncExitStack] = None, - errlog: TextIO = sys.stderr, - ) -> Tuple[List[MCPTool], AsyncExitStack]: - """Retrieve all tools from the MCP connection. - - Usage: - ``` - async def load_tools(): - tools, exit_stack = await MCPToolset.from_server( - connection_params=StdioServerParameters( - command='npx', - args=["-y", "@modelcontextprotocol/server-filesystem"], - ) - ) - ``` - - Args: - connection_params: The connection parameters to the MCP server. - async_exit_stack: The async exit stack to use. If not provided, a new - AsyncExitStack will be created. - - Returns: - A tuple of the list of MCPTools and the AsyncExitStack. - - tools: The list of MCPTools. - - async_exit_stack: The AsyncExitStack used to manage the connection to - the MCP server. Use `await async_exit_stack.aclose()` to close the - connection when server shuts down. - """ - async_exit_stack = async_exit_stack or AsyncExitStack() - toolset = cls( - connection_params=connection_params, - exit_stack=async_exit_stack, - errlog=errlog, - ) - - await async_exit_stack.enter_async_context(toolset) - tools = await toolset.load_tools() - return (tools, async_exit_stack) + self.session = None + self.tool_predicate = tool_predicate async def _initialize(self) -> ClientSession: """Connects to the MCP Server and initializes the ClientSession.""" self.session = await self.session_manager.create_session() return self.session - async def _exit(self): + @override + async def close(self): """Closes the connection to MCP Server.""" await self.exit_stack.aclose() @retry_on_closed_resource('_initialize') - async def load_tools(self) -> List[MCPTool]: + @override + async def get_tools( + self, + readony_context: ReadonlyContext = None, + ) -> List[MCPTool]: """Loads all tools from the MCP Server. Returns: A list of MCPTools imported from the MCP Server. """ + if not self.session: + await self._initialize() tools_response: ListToolsResult = await self.session.list_tools() return [ MCPTool( @@ -248,19 +124,6 @@ class MCPToolset: mcp_session_manager=self.session_manager, ) for tool in tools_response.tools + if self.tool_predicate is None + or self.tool_predicate(tool, readony_context) ] - - async def __aenter__(self): - try: - await self._initialize() - return self - except Exception as e: - raise e - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], - ) -> None: - await self._exit()