adapt MCP toolset to the latest Toolset interface

PiperOrigin-RevId: 756611140
This commit is contained in:
Xiang (Sean) Zhou 2025-05-08 22:47:11 -07:00 committed by Copybara-Service
parent 4d7298e4f2
commit 7dffb59096

View File

@ -15,14 +15,24 @@
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
import sys import sys
from types import TracebackType from types import TracebackType
from typing import List, Optional, TextIO, Tuple, Type from typing import List
from typing import Optional
from typing import override
from typing import TextIO
from typing import Type
from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource from ...agents.readonly_context import ReadonlyContext
from ..base_toolset import BaseToolPredicate
from ..base_toolset import BaseToolset
from .mcp_session_manager import MCPSessionManager
from .mcp_session_manager import retry_on_closed_resource
from .mcp_session_manager import SseServerParams
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade # Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails. # their Python version to 3.10 if it fails.
try: try:
from mcp import ClientSession, StdioServerParameters from mcp import ClientSession
from mcp import StdioServerParameters
from mcp.types import ListToolsResult from mcp.types import ListToolsResult
except ImportError as e: except ImportError as e:
import sys import sys
@ -38,67 +48,20 @@ except ImportError as e:
from .mcp_tool import MCPTool from .mcp_tool import MCPTool
class MCPToolset: class MCPToolset(BaseToolset):
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools. """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
Usage: Usage:
Example 1: (using from_server helper):
``` ```
async def load_tools(): root_agent = LlmAgent(
return await MCPToolset.from_server( tools=MCPToolset(
connection_params=StdioServerParameters( connection_params=StdioServerParameters(
command='npx', command='npx',
args=["-y", "@modelcontextprotocol/server-filesystem"], args=["-y", "@modelcontextprotocol/server-filesystem"],
) )
)
# Use the tools in an LLM agent
tools, exit_stack = await load_tools()
agent = LlmAgent(
tools=tools
)
...
await exit_stack.aclose()
```
Example 2: (using `async with`):
```
async def load_tools():
async with MCPToolset(
connection_params=SseServerParams(url="http://0.0.0.0:8090/sse")
) as toolset:
tools = await toolset.load_tools()
agent = LlmAgent(
...
tools=tools
) )
)
``` ```
Example 3: (provide AsyncExitStack):
```
async def load_tools():
async_exit_stack = AsyncExitStack()
toolset = MCPToolset(
connection_params=StdioServerParameters(...),
)
async_exit_stack.enter_async_context(toolset)
tools = await toolset.load_tools()
agent = LlmAgent(
...
tools=tools
)
...
await async_exit_stack.aclose()
```
Attributes:
connection_params: The connection parameters to the MCP server. Can be
either `StdioServerParameters` or `SseServerParams`.
exit_stack: The async exit stack to manage the connection to the MCP server.
session: The MCP session being initialized with the connection.
""" """
def __init__( def __init__(
@ -106,140 +69,53 @@ class MCPToolset:
*, *,
connection_params: StdioServerParameters | SseServerParams, connection_params: StdioServerParameters | SseServerParams,
errlog: TextIO = sys.stderr, errlog: TextIO = sys.stderr,
exit_stack=AsyncExitStack(), tool_predicate: Optional[BaseToolPredicate] = None,
): ):
"""Initializes the MCPToolset. """Initializes the MCPToolset.
Usage:
Example 1: (using from_server helper):
```
async def load_tools():
return await MCPToolset.from_server(
connection_params=StdioServerParameters(
command='npx',
args=["-y", "@modelcontextprotocol/server-filesystem"],
)
)
# Use the tools in an LLM agent
tools, exit_stack = await load_tools()
agent = LlmAgent(
tools=tools
)
...
await exit_stack.aclose()
```
Example 2: (using `async with`):
```
async def load_tools():
async with MCPToolset(
connection_params=SseServerParams(url="http://0.0.0.0:8090/sse")
) as toolset:
tools = await toolset.load_tools()
agent = LlmAgent(
...
tools=tools
)
```
Example 3: (provide AsyncExitStack):
```
async def load_tools():
async_exit_stack = AsyncExitStack()
toolset = MCPToolset(
connection_params=StdioServerParameters(...),
)
async_exit_stack.enter_async_context(toolset)
tools = await toolset.load_tools()
agent = LlmAgent(
...
tools=tools
)
...
await async_exit_stack.aclose()
```
Args: Args:
connection_params: The connection parameters to the MCP server. Can be: connection_params: The connection parameters to the MCP server. Can be:
`StdioServerParameters` for using local mcp server (e.g. using `npx` or `StdioServerParameters` for using local mcp server (e.g. using `npx` or
`python3`); or `SseServerParams` for a local/remote SSE server. `python3`); or `SseServerParams` for a local/remote SSE server.
""" """
if not connection_params: if not connection_params:
raise ValueError('Missing connection params in MCPToolset.') raise ValueError('Missing connection params in MCPToolset.')
self.connection_params = connection_params self.connection_params = connection_params
self.errlog = errlog self.errlog = errlog
self.exit_stack = exit_stack self.exit_stack = AsyncExitStack()
self.session_manager = MCPSessionManager( self.session_manager = MCPSessionManager(
connection_params=self.connection_params, connection_params=self.connection_params,
exit_stack=self.exit_stack, exit_stack=self.exit_stack,
errlog=self.errlog, errlog=self.errlog,
) )
self.session = None
@classmethod self.tool_predicate = tool_predicate
async def from_server(
cls,
*,
connection_params: StdioServerParameters | SseServerParams,
async_exit_stack: Optional[AsyncExitStack] = None,
errlog: TextIO = sys.stderr,
) -> Tuple[List[MCPTool], AsyncExitStack]:
"""Retrieve all tools from the MCP connection.
Usage:
```
async def load_tools():
tools, exit_stack = await MCPToolset.from_server(
connection_params=StdioServerParameters(
command='npx',
args=["-y", "@modelcontextprotocol/server-filesystem"],
)
)
```
Args:
connection_params: The connection parameters to the MCP server.
async_exit_stack: The async exit stack to use. If not provided, a new
AsyncExitStack will be created.
Returns:
A tuple of the list of MCPTools and the AsyncExitStack.
- tools: The list of MCPTools.
- async_exit_stack: The AsyncExitStack used to manage the connection to
the MCP server. Use `await async_exit_stack.aclose()` to close the
connection when server shuts down.
"""
async_exit_stack = async_exit_stack or AsyncExitStack()
toolset = cls(
connection_params=connection_params,
exit_stack=async_exit_stack,
errlog=errlog,
)
await async_exit_stack.enter_async_context(toolset)
tools = await toolset.load_tools()
return (tools, async_exit_stack)
async def _initialize(self) -> ClientSession: async def _initialize(self) -> ClientSession:
"""Connects to the MCP Server and initializes the ClientSession.""" """Connects to the MCP Server and initializes the ClientSession."""
self.session = await self.session_manager.create_session() self.session = await self.session_manager.create_session()
return self.session return self.session
async def _exit(self): @override
async def close(self):
"""Closes the connection to MCP Server.""" """Closes the connection to MCP Server."""
await self.exit_stack.aclose() await self.exit_stack.aclose()
@retry_on_closed_resource('_initialize') @retry_on_closed_resource('_initialize')
async def load_tools(self) -> List[MCPTool]: @override
async def get_tools(
self,
readony_context: ReadonlyContext = None,
) -> List[MCPTool]:
"""Loads all tools from the MCP Server. """Loads all tools from the MCP Server.
Returns: Returns:
A list of MCPTools imported from the MCP Server. A list of MCPTools imported from the MCP Server.
""" """
if not self.session:
await self._initialize()
tools_response: ListToolsResult = await self.session.list_tools() tools_response: ListToolsResult = await self.session.list_tools()
return [ return [
MCPTool( MCPTool(
@ -248,19 +124,6 @@ class MCPToolset:
mcp_session_manager=self.session_manager, mcp_session_manager=self.session_manager,
) )
for tool in tools_response.tools for tool in tools_response.tools
if self.tool_predicate is None
or self.tool_predicate(tool, readony_context)
] ]
async def __aenter__(self):
try:
await self._initialize()
return self
except Exception as e:
raise e
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
await self._exit()