refactor: refactor mcp toolset to hide non-public field

PiperOrigin-RevId: 758494601
This commit is contained in:
Xiang (Sean) Zhou 2025-05-13 21:05:22 -07:00 committed by Copybara-Service
parent 14cf910ce6
commit fc40226ec0
3 changed files with 31 additions and 29 deletions

View File

@ -138,15 +138,15 @@ class MCPSessionManager:
errlog: (Optional) TextIO stream for error logging. Use only for errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session. initializing a local stdio MCP session.
""" """
self.connection_params = connection_params self._connection_params = connection_params
self.exit_stack = exit_stack self._exit_stack = exit_stack
self.errlog = errlog self._errlog = errlog
async def create_session(self) -> ClientSession: async def create_session(self) -> ClientSession:
return await MCPSessionManager.initialize_session( return await MCPSessionManager.initialize_session(
connection_params=self.connection_params, connection_params=self._connection_params,
exit_stack=self.exit_stack, exit_stack=self._exit_stack,
errlog=self.errlog, errlog=self._errlog,
) )
@classmethod @classmethod

View File

@ -81,15 +81,15 @@ class MCPTool(BaseTool):
raise ValueError("mcp_session cannot be None") raise ValueError("mcp_session cannot be None")
self.name = mcp_tool.name self.name = mcp_tool.name
self.description = mcp_tool.description if mcp_tool.description else "" self.description = mcp_tool.description if mcp_tool.description else ""
self.mcp_tool = mcp_tool self._mcp_tool = mcp_tool
self.mcp_session = mcp_session self._mcp_session = mcp_session
self.mcp_session_manager = mcp_session_manager self._mcp_session_manager = mcp_session_manager
# TODO(cheliu): Support passing auth to MCP Server. # TODO(cheliu): Support passing auth to MCP Server.
self.auth_scheme = auth_scheme self._auth_scheme = auth_scheme
self.auth_credential = auth_credential self._auth_credential = auth_credential
async def _reinitialize_session(self): async def _reinitialize_session(self):
self.mcp_session = await self.mcp_session_manager.create_session() self._mcp_session = await self._mcp_session_manager.create_session()
@override @override
def _get_declaration(self) -> FunctionDeclaration: def _get_declaration(self) -> FunctionDeclaration:
@ -98,7 +98,7 @@ class MCPTool(BaseTool):
Returns: Returns:
FunctionDeclaration: The Gemini function declaration for the tool. FunctionDeclaration: The Gemini function declaration for the tool.
""" """
schema_dict = self.mcp_tool.inputSchema schema_dict = self._mcp_tool.inputSchema
parameters = to_gemini_schema(schema_dict) parameters = to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration( function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters name=self.name, description=self.description, parameters=parameters
@ -119,7 +119,7 @@ class MCPTool(BaseTool):
""" """
# TODO(cheliu): Support passing tool context to MCP Server. # TODO(cheliu): Support passing tool context to MCP Server.
try: try:
response = await self.mcp_session.call_tool(self.name, arguments=args) response = await self._mcp_session.call_tool(self.name, arguments=args)
return response return response
except Exception as e: except Exception as e:
print(e) print(e)

View File

@ -76,26 +76,28 @@ class MCPToolset(BaseToolset):
connection_params: The connection parameters to the MCP server. Can be: connection_params: The connection parameters to the MCP server. Can be:
`StdioServerParameters` for using local mcp server (e.g. using `npx` or `StdioServerParameters` for using local mcp server (e.g. using `npx` or
`python3`); or `SseServerParams` for a local/remote SSE server. `python3`); or `SseServerParams` for a local/remote SSE server.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
""" """
if not connection_params: if not connection_params:
raise ValueError('Missing connection params in MCPToolset.') raise ValueError('Missing connection params in MCPToolset.')
self.connection_params = connection_params self._connection_params = connection_params
self.errlog = errlog self._errlog = errlog
self.exit_stack = AsyncExitStack() self._exit_stack = AsyncExitStack()
self.session_manager = MCPSessionManager( self._session_manager = MCPSessionManager(
connection_params=self.connection_params, connection_params=self._connection_params,
exit_stack=self.exit_stack, exit_stack=self._exit_stack,
errlog=self.errlog, errlogger=self._errlog,
) )
self.session = None self._session = None
self.tool_filter = tool_filter self.tool_filter = tool_filter
async def _initialize(self) -> ClientSession: async def _initialize(self) -> ClientSession:
"""Connects to the MCP Server and initializes the ClientSession.""" """Connects to the MCP Server and initializes the ClientSession."""
self.session = await self.session_manager.create_session() self._session = await self._session_manager.create_session()
return self.session return self._session
def _is_selected( def _is_selected(
self, tool: ..., readonly_context: Optional[ReadonlyContext] self, tool: ..., readonly_context: Optional[ReadonlyContext]
@ -112,7 +114,7 @@ class MCPToolset(BaseToolset):
@override @override
async def close(self): async def close(self):
"""Closes the connection to MCP Server.""" """Closes the connection to MCP Server."""
await self.exit_stack.aclose() await self._exit_stack.aclose()
@retry_on_closed_resource('_initialize') @retry_on_closed_resource('_initialize')
@override @override
@ -125,14 +127,14 @@ class MCPToolset(BaseToolset):
Returns: Returns:
A list of MCPTools imported from the MCP Server. A list of MCPTools imported from the MCP Server.
""" """
if not self.session: if not self._session:
await self._initialize() await self._initialize()
tools_response: ListToolsResult = await self.session.list_tools() tools_response: ListToolsResult = await self._session.list_tools()
return [ return [
MCPTool( MCPTool(
mcp_tool=tool, mcp_tool=tool,
mcp_session=self.session, mcp_session=self._session,
mcp_session_manager=self.session_manager, mcp_session_manager=self._session_manager,
) )
for tool in tools_response.tools for tool in tools_response.tools
if self._is_selected(tool, readonly_context) if self._is_selected(tool, readonly_context)