refactor: refactor toolset to extract tool_filter logic to base class

PiperOrigin-RevId: 761828251
This commit is contained in:
Xiang (Sean) Zhou 2025-05-21 23:54:01 -07:00 committed by Copybara-Service
parent e0851a1e57
commit a2263b1808
6 changed files with 34 additions and 39 deletions

View File

@ -131,6 +131,7 @@ class APIHubToolset(BaseToolset):
be either a tool predicate or a list of tool names of the tools to
expose.
"""
super().__init__(tool_filter=tool_filter)
self.name = name
self.description = description
self._apihub_resource_name = apihub_resource_name
@ -143,7 +144,6 @@ class APIHubToolset(BaseToolset):
self._openapi_toolset = None
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
self.tool_filter = tool_filter
if not self._lazy_load_spec:
self._prepare_toolset()

View File

@ -128,6 +128,7 @@ class ApplicationIntegrationToolset(BaseToolset):
Exception: If there is an error during the initialization of the
integration or connection client.
"""
super().__init__(tool_filter=tool_filter)
self.project = project
self.location = location
self._integration = integration
@ -140,7 +141,6 @@ class ApplicationIntegrationToolset(BaseToolset):
self._service_account_json = service_account_json
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
self.tool_filter = tool_filter
integration_client = IntegrationClient(
project,
@ -263,7 +263,11 @@ class ApplicationIntegrationToolset(BaseToolset):
readonly_context: Optional[ReadonlyContext] = None,
) -> List[RestApiTool]:
return (
self._tools
[
tool
for tool in self._tools
if self._is_tool_selected(tool, readonly_context)
]
if self._openapi_toolset is None
else await self._openapi_toolset.get_tools(readonly_context)
)

View File

@ -1,8 +1,10 @@
from abc import ABC
from abc import abstractmethod
from typing import List
from typing import Optional
from typing import Protocol
from typing import runtime_checkable
from typing import Union
from ..agents.readonly_context import ReadonlyContext
from .base_tool import BaseTool
@ -34,9 +36,15 @@ class BaseToolset(ABC):
A toolset is a collection of tools that can be used by an agent.
"""
def __init__(
self, *, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None
):
self.tool_filter = tool_filter
@abstractmethod
async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
self,
readonly_context: Optional[ReadonlyContext] = None,
) -> list[BaseTool]:
"""Return all tools in the toolset based on the provided context.
@ -57,3 +65,17 @@ class BaseToolset(ABC):
should ensure that any open connections, files, or other managed
resources are properly released to prevent leaks.
"""
def _is_tool_selected(
self, tool: BaseTool, readonly_context: ReadonlyContext
) -> bool:
if not self.tool_filter:
return True
if isinstance(self.tool_filter, ToolPredicate):
return self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, list):
return tool.name in self.tool_filter
return False

View File

@ -56,20 +56,6 @@ class GoogleApiToolset(BaseToolset):
self._openapi_toolset = self._load_toolset_with_oidc_auth()
self.tool_filter = tool_filter
def _is_tool_selected(
self, tool: GoogleApiTool, readonly_context: ReadonlyContext
) -> bool:
if not self.tool_filter:
return True
if isinstance(self.tool_filter, ToolPredicate):
return self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, list):
return tool.name in self.tool_filter
return False
@override
async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None

View File

@ -90,6 +90,7 @@ class MCPToolset(BaseToolset):
if not connection_params:
raise ValueError("Missing connection params in MCPToolset.")
super().__init__(tool_filter=tool_filter)
self._connection_params = connection_params
self._errlog = errlog
self._exit_stack = AsyncExitStack()
@ -102,7 +103,6 @@ class MCPToolset(BaseToolset):
errlog=self._errlog,
)
self._session = None
self.tool_filter = tool_filter
self._initialized = False
async def _initialize(self) -> ClientSession:
@ -116,18 +116,6 @@ class MCPToolset(BaseToolset):
self._initialized = True
return self._session
def _is_selected(
self, tool: BaseTool, readonly_context: Optional[ReadonlyContext]
) -> bool:
"""Checks if a tool should be selected based on the tool filter."""
if self.tool_filter is None:
return True
if isinstance(self.tool_filter, ToolPredicate):
return self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, list):
return tool.name in self.tool_filter
return False
@override
async def close(self):
"""Safely closes the connection to MCP Server with guaranteed resource cleanup."""
@ -221,6 +209,6 @@ class MCPToolset(BaseToolset):
mcp_session_manager=self._session_manager,
)
if self._is_selected(mcp_tool, readonly_context):
if self._is_tool_selected(mcp_tool, readonly_context):
tools.append(mcp_tool)
return tools

View File

@ -103,12 +103,12 @@ class OpenAPIToolset(BaseToolset):
tool_filter: The filter used to filter the tools in the toolset. It can be
either a tool predicate or a list of tool names of the tools to expose.
"""
super().__init__(tool_filter=tool_filter)
if not spec_dict:
spec_dict = self._load_spec(spec_str, spec_str_type)
self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
if auth_scheme or auth_credential:
self._configure_auth_all(auth_scheme, auth_credential)
self.tool_filter = tool_filter
def _configure_auth_all(
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
@ -129,12 +129,7 @@ class OpenAPIToolset(BaseToolset):
return [
tool
for tool in self._tools
if self.tool_filter is None
or (
self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, ToolPredicate)
else tool.name in self.tool_filter
)
if self._is_tool_selected(tool, readonly_context)
]
def get_tool(self, tool_name: str) -> Optional[RestApiTool]: