No public description

PiperOrigin-RevId: 748777998
This commit is contained in:
Google ADK Member
2025-04-17 19:50:22 +00:00
committed by hangfei
parent 290058eb05
commit 61d4be2d76
99 changed files with 2120 additions and 256 deletions

View File

@@ -0,0 +1,176 @@
from contextlib import AsyncExitStack
import functools
import sys
from typing import Any, TextIO
import anyio
from pydantic import BaseModel
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
except ImportError as e:
import sys
if sys.version_info < (3, 10):
raise ImportError(
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
' version.'
) from e
else:
raise e
class SseServerParams(BaseModel):
"""Parameters for the MCP SSE connection.
See MCP SSE Client documentation for more details.
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
"""
url: str
headers: dict[str, Any] | None = None
timeout: float = 5
sse_read_timeout: float = 60 * 5
def retry_on_closed_resource(async_reinit_func_name: str):
"""Decorator to automatically reinitialize session and retry action.
When MCP session was closed, the decorator will automatically recreate the
session and retry the action with the same parameters.
Note:
1. async_reinit_func_name is the name of the class member function that
reinitializes the MCP session.
2. Both the decorated function and the async_reinit_func_name must be async
functions.
Usage:
class MCPTool:
...
async def create_session(self):
self.session = ...
@retry_on_closed_resource('create_session')
async def use_session(self):
await self.session.call_tool()
Args:
async_reinit_func_name: The name of the async function to recreate session.
Returns:
The decorated function.
"""
def decorator(func):
@functools.wraps(
func
) # Preserves original function metadata (name, docstring)
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except anyio.ClosedResourceError:
try:
if hasattr(self, async_reinit_func_name) and callable(
getattr(self, async_reinit_func_name)
):
async_init_fn = getattr(self, async_reinit_func_name)
await async_init_fn()
else:
raise ValueError(
f'Function {async_reinit_func_name} does not exist in decorated'
' class. Please check the function name in'
' retry_on_closed_resource decorator.'
)
except Exception as reinit_err:
raise RuntimeError(
f'Error reinitializing: {reinit_err}'
) from reinit_err
return await func(self, *args, **kwargs)
return wrapper
return decorator
class MCPSessionManager:
"""Manages MCP client sessions.
This class provides methods for creating and initializing MCP client sessions,
handling different connection parameters (Stdio and SSE).
"""
def __init__(
self,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
) -> ClientSession:
"""Initializes the MCP session manager.
Example usage:
```
mcp_session_manager = MCPSessionManager(
connection_params=connection_params,
exit_stack=exit_stack,
)
session = await mcp_session_manager.create_session()
```
Args:
connection_params: Parameters for the MCP connection (Stdio or SSE).
exit_stack: AsyncExitStack to manage the session lifecycle.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
"""
self.connection_params = connection_params
self.exit_stack = exit_stack
self.errlog = errlog
async def create_session(self) -> ClientSession:
return await MCPSessionManager.initialize_session(
connection_params=self.connection_params,
exit_stack=self.exit_stack,
errlog=self.errlog,
)
@classmethod
async def initialize_session(
cls,
*,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
) -> ClientSession:
"""Initializes an MCP client session.
Args:
connection_params: Parameters for the MCP connection (Stdio or SSE).
exit_stack: AsyncExitStack to manage the session lifecycle.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
Returns:
ClientSession: The initialized MCP client session.
"""
if isinstance(connection_params, StdioServerParameters):
client = stdio_client(server=connection_params, errlog=errlog)
elif isinstance(connection_params, SseServerParams):
client = sse_client(
url=connection_params.url,
headers=connection_params.headers,
timeout=connection_params.timeout,
sse_read_timeout=connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {connection_params}'
)
transports = await exit_stack.enter_async_context(client)
session = await exit_stack.enter_async_context(ClientSession(*transports))
await session.initialize()
return session

View File

@@ -17,6 +17,8 @@ from typing import Optional
from google.genai.types import FunctionDeclaration
from typing_extensions import override
from .mcp_session_manager import MCPSessionManager, retry_on_closed_resource
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
@@ -33,6 +35,7 @@ except ImportError as e:
else:
raise e
from ..base_tool import BaseTool
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
@@ -51,6 +54,7 @@ class MCPTool(BaseTool):
self,
mcp_tool: McpBaseTool,
mcp_session: ClientSession,
mcp_session_manager: MCPSessionManager,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] | None = None,
):
@@ -79,10 +83,14 @@ class MCPTool(BaseTool):
self.description = mcp_tool.description if mcp_tool.description else ""
self.mcp_tool = mcp_tool
self.mcp_session = mcp_session
self.mcp_session_manager = mcp_session_manager
# TODO(cheliu): Support passing auth to MCP Server.
self.auth_scheme = auth_scheme
self.auth_credential = auth_credential
async def _reinitialize_session(self):
self.mcp_session = await self.mcp_session_manager.create_session()
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Gets the function declaration for the tool.
@@ -98,6 +106,7 @@ class MCPTool(BaseTool):
return function_decl
@override
@retry_on_closed_resource("_reinitialize_session")
async def run_async(self, *, args, tool_context: ToolContext):
"""Runs the tool asynchronously.
@@ -109,5 +118,9 @@ class MCPTool(BaseTool):
Any: The response from the tool.
"""
# TODO(cheliu): Support passing tool context to MCP Server.
response = await self.mcp_session.call_tool(self.name, arguments=args)
return response
try:
response = await self.mcp_session.call_tool(self.name, arguments=args)
return response
except Exception as e:
print(e)
raise e

View File

@@ -13,15 +13,16 @@
# limitations under the License.
from contextlib import AsyncExitStack
import sys
from types import TracebackType
from typing import Any, List, Optional, Tuple, Type
from typing import List, Optional, TextIO, Tuple, Type
from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import ListToolsResult
except ImportError as e:
import sys
@@ -34,18 +35,9 @@ except ImportError as e:
else:
raise e
from pydantic import BaseModel
from .mcp_tool import MCPTool
class SseServerParams(BaseModel):
url: str
headers: dict[str, Any] | None = None
timeout: float = 5
sse_read_timeout: float = 60 * 5
class MCPToolset:
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
@@ -110,7 +102,11 @@ class MCPToolset:
"""
def __init__(
self, *, connection_params: StdioServerParameters | SseServerParams
self,
*,
connection_params: StdioServerParameters | SseServerParams,
errlog: TextIO = sys.stderr,
exit_stack=AsyncExitStack(),
):
"""Initializes the MCPToolset.
@@ -175,7 +171,14 @@ class MCPToolset:
if not connection_params:
raise ValueError('Missing connection params in MCPToolset.')
self.connection_params = connection_params
self.exit_stack = AsyncExitStack()
self.errlog = errlog
self.exit_stack = exit_stack
self.session_manager = MCPSessionManager(
connection_params=self.connection_params,
exit_stack=self.exit_stack,
errlog=self.errlog,
)
@classmethod
async def from_server(
@@ -183,6 +186,7 @@ class MCPToolset:
*,
connection_params: StdioServerParameters | SseServerParams,
async_exit_stack: Optional[AsyncExitStack] = None,
errlog: TextIO = sys.stderr,
) -> Tuple[List[MCPTool], AsyncExitStack]:
"""Retrieve all tools from the MCP connection.
@@ -209,41 +213,27 @@ class MCPToolset:
the MCP server. Use `await async_exit_stack.aclose()` to close the
connection when server shuts down.
"""
toolset = cls(connection_params=connection_params)
async_exit_stack = async_exit_stack or AsyncExitStack()
toolset = cls(
connection_params=connection_params,
exit_stack=async_exit_stack,
errlog=errlog,
)
await async_exit_stack.enter_async_context(toolset)
tools = await toolset.load_tools()
return (tools, async_exit_stack)
async def _initialize(self) -> ClientSession:
"""Connects to the MCP Server and initializes the ClientSession."""
if isinstance(self.connection_params, StdioServerParameters):
client = stdio_client(self.connection_params)
elif isinstance(self.connection_params, SseServerParams):
client = sse_client(
url=self.connection_params.url,
headers=self.connection_params.headers,
timeout=self.connection_params.timeout,
sse_read_timeout=self.connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {self.connection_params}'
)
transports = await self.exit_stack.enter_async_context(client)
self.session = await self.exit_stack.enter_async_context(
ClientSession(*transports)
)
await self.session.initialize()
self.session = await self.session_manager.create_session()
return self.session
async def _exit(self):
"""Closes the connection to MCP Server."""
await self.exit_stack.aclose()
@retry_on_closed_resource('_initialize')
async def load_tools(self) -> List[MCPTool]:
"""Loads all tools from the MCP Server.
@@ -252,7 +242,11 @@ class MCPToolset:
"""
tools_response: ListToolsResult = await self.session.list_tools()
return [
MCPTool(mcp_tool=tool, mcp_session=self.session)
MCPTool(
mcp_tool=tool,
mcp_session=self.session,
mcp_session_manager=self.session_manager,
)
for tool in tools_response.tools
]