refactor: refactor toolset to extract tool_filter logic to base class

PiperOrigin-RevId: 761828251
This commit is contained in:
Xiang (Sean) Zhou
2025-05-21 23:54:01 -07:00
committed by Copybara-Service
parent e0851a1e57
commit a2263b1808
6 changed files with 34 additions and 39 deletions
@@ -103,12 +103,12 @@ class OpenAPIToolset(BaseToolset):
tool_filter: The filter used to filter the tools in the toolset. It can be
either a tool predicate or a list of tool names of the tools to expose.
"""
super().__init__(tool_filter=tool_filter)
if not spec_dict:
spec_dict = self._load_spec(spec_str, spec_str_type)
self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
if auth_scheme or auth_credential:
self._configure_auth_all(auth_scheme, auth_credential)
self.tool_filter = tool_filter
def _configure_auth_all(
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
@@ -129,12 +129,7 @@ class OpenAPIToolset(BaseToolset):
return [
tool
for tool in self._tools
if self.tool_filter is None
or (
self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, ToolPredicate)
else tool.name in self.tool_filter
)
if self._is_tool_selected(tool, readonly_context)
]
def get_tool(self, tool_name: str) -> Optional[RestApiTool]: