refactor: refactor toolset to extract tool_filter logic to base class

PiperOrigin-RevId: 761828251
This commit is contained in:
Xiang (Sean) Zhou
2025-05-21 23:54:01 -07:00
committed by Copybara-Service
parent e0851a1e57
commit a2263b1808
6 changed files with 34 additions and 39 deletions
+2 -14
View File
@@ -90,6 +90,7 @@ class MCPToolset(BaseToolset):
if not connection_params:
raise ValueError("Missing connection params in MCPToolset.")
super().__init__(tool_filter=tool_filter)
self._connection_params = connection_params
self._errlog = errlog
self._exit_stack = AsyncExitStack()
@@ -102,7 +103,6 @@ class MCPToolset(BaseToolset):
errlog=self._errlog,
)
self._session = None
self.tool_filter = tool_filter
self._initialized = False
async def _initialize(self) -> ClientSession:
@@ -116,18 +116,6 @@ class MCPToolset(BaseToolset):
self._initialized = True
return self._session
def _is_selected(
self, tool: BaseTool, readonly_context: Optional[ReadonlyContext]
) -> bool:
"""Checks if a tool should be selected based on the tool filter."""
if self.tool_filter is None:
return True
if isinstance(self.tool_filter, ToolPredicate):
return self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, list):
return tool.name in self.tool_filter
return False
@override
async def close(self):
"""Safely closes the connection to MCP Server with guaranteed resource cleanup."""
@@ -221,6 +209,6 @@ class MCPToolset(BaseToolset):
mcp_session_manager=self._session_manager,
)
if self._is_selected(mcp_tool, readonly_context):
if self._is_tool_selected(mcp_tool, readonly_context):
tools.append(mcp_tool)
return tools