mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-22 05:12:18 -06:00
No public description
PiperOrigin-RevId: 748777998
This commit is contained in:
committed by
hangfei
parent
290058eb05
commit
61d4be2d76
176
src/google/adk/tools/mcp_tool/mcp_session_manager.py
Normal file
176
src/google/adk/tools/mcp_tool/mcp_session_manager.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from contextlib import AsyncExitStack
|
||||
import functools
|
||||
import sys
|
||||
from typing import Any, TextIO
|
||||
import anyio
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
except ImportError as e:
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
raise ImportError(
|
||||
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
|
||||
' version.'
|
||||
) from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
class SseServerParams(BaseModel):
|
||||
"""Parameters for the MCP SSE connection.
|
||||
|
||||
See MCP SSE Client documentation for more details.
|
||||
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
|
||||
"""
|
||||
|
||||
url: str
|
||||
headers: dict[str, Any] | None = None
|
||||
timeout: float = 5
|
||||
sse_read_timeout: float = 60 * 5
|
||||
|
||||
|
||||
def retry_on_closed_resource(async_reinit_func_name: str):
|
||||
"""Decorator to automatically reinitialize session and retry action.
|
||||
|
||||
When MCP session was closed, the decorator will automatically recreate the
|
||||
session and retry the action with the same parameters.
|
||||
|
||||
Note:
|
||||
1. async_reinit_func_name is the name of the class member function that
|
||||
reinitializes the MCP session.
|
||||
2. Both the decorated function and the async_reinit_func_name must be async
|
||||
functions.
|
||||
|
||||
Usage:
|
||||
class MCPTool:
|
||||
...
|
||||
async def create_session(self):
|
||||
self.session = ...
|
||||
|
||||
@retry_on_closed_resource('create_session')
|
||||
async def use_session(self):
|
||||
await self.session.call_tool()
|
||||
|
||||
Args:
|
||||
async_reinit_func_name: The name of the async function to recreate session.
|
||||
|
||||
Returns:
|
||||
The decorated function.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@functools.wraps(
|
||||
func
|
||||
) # Preserves original function metadata (name, docstring)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
return await func(self, *args, **kwargs)
|
||||
except anyio.ClosedResourceError:
|
||||
try:
|
||||
if hasattr(self, async_reinit_func_name) and callable(
|
||||
getattr(self, async_reinit_func_name)
|
||||
):
|
||||
async_init_fn = getattr(self, async_reinit_func_name)
|
||||
await async_init_fn()
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Function {async_reinit_func_name} does not exist in decorated'
|
||||
' class. Please check the function name in'
|
||||
' retry_on_closed_resource decorator.'
|
||||
)
|
||||
except Exception as reinit_err:
|
||||
raise RuntimeError(
|
||||
f'Error reinitializing: {reinit_err}'
|
||||
) from reinit_err
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class MCPSessionManager:
|
||||
"""Manages MCP client sessions.
|
||||
|
||||
This class provides methods for creating and initializing MCP client sessions,
|
||||
handling different connection parameters (Stdio and SSE).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_params: StdioServerParameters | SseServerParams,
|
||||
exit_stack: AsyncExitStack,
|
||||
errlog: TextIO = sys.stderr,
|
||||
) -> ClientSession:
|
||||
"""Initializes the MCP session manager.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
mcp_session_manager = MCPSessionManager(
|
||||
connection_params=connection_params,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
session = await mcp_session_manager.create_session()
|
||||
```
|
||||
|
||||
Args:
|
||||
connection_params: Parameters for the MCP connection (Stdio or SSE).
|
||||
exit_stack: AsyncExitStack to manage the session lifecycle.
|
||||
errlog: (Optional) TextIO stream for error logging. Use only for
|
||||
initializing a local stdio MCP session.
|
||||
"""
|
||||
self.connection_params = connection_params
|
||||
self.exit_stack = exit_stack
|
||||
self.errlog = errlog
|
||||
|
||||
async def create_session(self) -> ClientSession:
|
||||
return await MCPSessionManager.initialize_session(
|
||||
connection_params=self.connection_params,
|
||||
exit_stack=self.exit_stack,
|
||||
errlog=self.errlog,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def initialize_session(
|
||||
cls,
|
||||
*,
|
||||
connection_params: StdioServerParameters | SseServerParams,
|
||||
exit_stack: AsyncExitStack,
|
||||
errlog: TextIO = sys.stderr,
|
||||
) -> ClientSession:
|
||||
"""Initializes an MCP client session.
|
||||
|
||||
Args:
|
||||
connection_params: Parameters for the MCP connection (Stdio or SSE).
|
||||
exit_stack: AsyncExitStack to manage the session lifecycle.
|
||||
errlog: (Optional) TextIO stream for error logging. Use only for
|
||||
initializing a local stdio MCP session.
|
||||
|
||||
Returns:
|
||||
ClientSession: The initialized MCP client session.
|
||||
"""
|
||||
if isinstance(connection_params, StdioServerParameters):
|
||||
client = stdio_client(server=connection_params, errlog=errlog)
|
||||
elif isinstance(connection_params, SseServerParams):
|
||||
client = sse_client(
|
||||
url=connection_params.url,
|
||||
headers=connection_params.headers,
|
||||
timeout=connection_params.timeout,
|
||||
sse_read_timeout=connection_params.sse_read_timeout,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unable to initialize connection. Connection should be'
|
||||
' StdioServerParameters or SseServerParams, but got'
|
||||
f' {connection_params}'
|
||||
)
|
||||
|
||||
transports = await exit_stack.enter_async_context(client)
|
||||
session = await exit_stack.enter_async_context(ClientSession(*transports))
|
||||
await session.initialize()
|
||||
return session
|
||||
@@ -17,6 +17,8 @@ from typing import Optional
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from typing_extensions import override
|
||||
|
||||
from .mcp_session_manager import MCPSessionManager, retry_on_closed_resource
|
||||
|
||||
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
||||
# their Python version to 3.10 if it fails.
|
||||
try:
|
||||
@@ -33,6 +35,7 @@ except ImportError as e:
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
from ..base_tool import BaseTool
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
@@ -51,6 +54,7 @@ class MCPTool(BaseTool):
|
||||
self,
|
||||
mcp_tool: McpBaseTool,
|
||||
mcp_session: ClientSession,
|
||||
mcp_session_manager: MCPSessionManager,
|
||||
auth_scheme: Optional[AuthScheme] = None,
|
||||
auth_credential: Optional[AuthCredential] | None = None,
|
||||
):
|
||||
@@ -79,10 +83,14 @@ class MCPTool(BaseTool):
|
||||
self.description = mcp_tool.description if mcp_tool.description else ""
|
||||
self.mcp_tool = mcp_tool
|
||||
self.mcp_session = mcp_session
|
||||
self.mcp_session_manager = mcp_session_manager
|
||||
# TODO(cheliu): Support passing auth to MCP Server.
|
||||
self.auth_scheme = auth_scheme
|
||||
self.auth_credential = auth_credential
|
||||
|
||||
async def _reinitialize_session(self):
|
||||
self.mcp_session = await self.mcp_session_manager.create_session()
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> FunctionDeclaration:
|
||||
"""Gets the function declaration for the tool.
|
||||
@@ -98,6 +106,7 @@ class MCPTool(BaseTool):
|
||||
return function_decl
|
||||
|
||||
@override
|
||||
@retry_on_closed_resource("_reinitialize_session")
|
||||
async def run_async(self, *, args, tool_context: ToolContext):
|
||||
"""Runs the tool asynchronously.
|
||||
|
||||
@@ -109,5 +118,9 @@ class MCPTool(BaseTool):
|
||||
Any: The response from the tool.
|
||||
"""
|
||||
# TODO(cheliu): Support passing tool context to MCP Server.
|
||||
response = await self.mcp_session.call_tool(self.name, arguments=args)
|
||||
return response
|
||||
try:
|
||||
response = await self.mcp_session.call_tool(self.name, arguments=args)
|
||||
return response
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
@@ -13,15 +13,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import AsyncExitStack
|
||||
import sys
|
||||
from types import TracebackType
|
||||
from typing import Any, List, Optional, Tuple, Type
|
||||
from typing import List, Optional, TextIO, Tuple, Type
|
||||
|
||||
from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource
|
||||
|
||||
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
||||
# their Python version to 3.10 if it fails.
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.types import ListToolsResult
|
||||
except ImportError as e:
|
||||
import sys
|
||||
@@ -34,18 +35,9 @@ except ImportError as e:
|
||||
else:
|
||||
raise e
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .mcp_tool import MCPTool
|
||||
|
||||
|
||||
class SseServerParams(BaseModel):
|
||||
url: str
|
||||
headers: dict[str, Any] | None = None
|
||||
timeout: float = 5
|
||||
sse_read_timeout: float = 60 * 5
|
||||
|
||||
|
||||
class MCPToolset:
|
||||
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
|
||||
|
||||
@@ -110,7 +102,11 @@ class MCPToolset:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *, connection_params: StdioServerParameters | SseServerParams
|
||||
self,
|
||||
*,
|
||||
connection_params: StdioServerParameters | SseServerParams,
|
||||
errlog: TextIO = sys.stderr,
|
||||
exit_stack=AsyncExitStack(),
|
||||
):
|
||||
"""Initializes the MCPToolset.
|
||||
|
||||
@@ -175,7 +171,14 @@ class MCPToolset:
|
||||
if not connection_params:
|
||||
raise ValueError('Missing connection params in MCPToolset.')
|
||||
self.connection_params = connection_params
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.errlog = errlog
|
||||
self.exit_stack = exit_stack
|
||||
|
||||
self.session_manager = MCPSessionManager(
|
||||
connection_params=self.connection_params,
|
||||
exit_stack=self.exit_stack,
|
||||
errlog=self.errlog,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def from_server(
|
||||
@@ -183,6 +186,7 @@ class MCPToolset:
|
||||
*,
|
||||
connection_params: StdioServerParameters | SseServerParams,
|
||||
async_exit_stack: Optional[AsyncExitStack] = None,
|
||||
errlog: TextIO = sys.stderr,
|
||||
) -> Tuple[List[MCPTool], AsyncExitStack]:
|
||||
"""Retrieve all tools from the MCP connection.
|
||||
|
||||
@@ -209,41 +213,27 @@ class MCPToolset:
|
||||
the MCP server. Use `await async_exit_stack.aclose()` to close the
|
||||
connection when server shuts down.
|
||||
"""
|
||||
toolset = cls(connection_params=connection_params)
|
||||
async_exit_stack = async_exit_stack or AsyncExitStack()
|
||||
toolset = cls(
|
||||
connection_params=connection_params,
|
||||
exit_stack=async_exit_stack,
|
||||
errlog=errlog,
|
||||
)
|
||||
|
||||
await async_exit_stack.enter_async_context(toolset)
|
||||
tools = await toolset.load_tools()
|
||||
return (tools, async_exit_stack)
|
||||
|
||||
async def _initialize(self) -> ClientSession:
|
||||
"""Connects to the MCP Server and initializes the ClientSession."""
|
||||
if isinstance(self.connection_params, StdioServerParameters):
|
||||
client = stdio_client(self.connection_params)
|
||||
elif isinstance(self.connection_params, SseServerParams):
|
||||
client = sse_client(
|
||||
url=self.connection_params.url,
|
||||
headers=self.connection_params.headers,
|
||||
timeout=self.connection_params.timeout,
|
||||
sse_read_timeout=self.connection_params.sse_read_timeout,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unable to initialize connection. Connection should be'
|
||||
' StdioServerParameters or SseServerParams, but got'
|
||||
f' {self.connection_params}'
|
||||
)
|
||||
|
||||
transports = await self.exit_stack.enter_async_context(client)
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(*transports)
|
||||
)
|
||||
await self.session.initialize()
|
||||
self.session = await self.session_manager.create_session()
|
||||
return self.session
|
||||
|
||||
async def _exit(self):
|
||||
"""Closes the connection to MCP Server."""
|
||||
await self.exit_stack.aclose()
|
||||
|
||||
@retry_on_closed_resource('_initialize')
|
||||
async def load_tools(self) -> List[MCPTool]:
|
||||
"""Loads all tools from the MCP Server.
|
||||
|
||||
@@ -252,7 +242,11 @@ class MCPToolset:
|
||||
"""
|
||||
tools_response: ListToolsResult = await self.session.list_tools()
|
||||
return [
|
||||
MCPTool(mcp_tool=tool, mcp_session=self.session)
|
||||
MCPTool(
|
||||
mcp_tool=tool,
|
||||
mcp_session=self.session,
|
||||
mcp_session_manager=self.session_manager,
|
||||
)
|
||||
for tool in tools_response.tools
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user