refactor: refactor toolset to extract tool_filter logic to base class

PiperOrigin-RevId: 761828251
This commit is contained in:
Xiang (Sean) Zhou 2025-05-21 23:54:01 -07:00 committed by Copybara-Service
parent e0851a1e57
commit a2263b1808
6 changed files with 34 additions and 39 deletions

View File

@ -131,6 +131,7 @@ class APIHubToolset(BaseToolset):
be either a tool predicate or a list of tool names of the tools to be either a tool predicate or a list of tool names of the tools to
expose. expose.
""" """
super().__init__(tool_filter=tool_filter)
self.name = name self.name = name
self.description = description self.description = description
self._apihub_resource_name = apihub_resource_name self._apihub_resource_name = apihub_resource_name
@ -143,7 +144,6 @@ class APIHubToolset(BaseToolset):
self._openapi_toolset = None self._openapi_toolset = None
self._auth_scheme = auth_scheme self._auth_scheme = auth_scheme
self._auth_credential = auth_credential self._auth_credential = auth_credential
self.tool_filter = tool_filter
if not self._lazy_load_spec: if not self._lazy_load_spec:
self._prepare_toolset() self._prepare_toolset()

View File

@ -128,6 +128,7 @@ class ApplicationIntegrationToolset(BaseToolset):
Exception: If there is an error during the initialization of the Exception: If there is an error during the initialization of the
integration or connection client. integration or connection client.
""" """
super().__init__(tool_filter=tool_filter)
self.project = project self.project = project
self.location = location self.location = location
self._integration = integration self._integration = integration
@ -140,7 +141,6 @@ class ApplicationIntegrationToolset(BaseToolset):
self._service_account_json = service_account_json self._service_account_json = service_account_json
self._auth_scheme = auth_scheme self._auth_scheme = auth_scheme
self._auth_credential = auth_credential self._auth_credential = auth_credential
self.tool_filter = tool_filter
integration_client = IntegrationClient( integration_client = IntegrationClient(
project, project,
@ -263,7 +263,11 @@ class ApplicationIntegrationToolset(BaseToolset):
readonly_context: Optional[ReadonlyContext] = None, readonly_context: Optional[ReadonlyContext] = None,
) -> List[RestApiTool]: ) -> List[RestApiTool]:
return ( return (
self._tools [
tool
for tool in self._tools
if self._is_tool_selected(tool, readonly_context)
]
if self._openapi_toolset is None if self._openapi_toolset is None
else await self._openapi_toolset.get_tools(readonly_context) else await self._openapi_toolset.get_tools(readonly_context)
) )

View File

@ -1,8 +1,10 @@
from abc import ABC from abc import ABC
from abc import abstractmethod from abc import abstractmethod
from typing import List
from typing import Optional from typing import Optional
from typing import Protocol from typing import Protocol
from typing import runtime_checkable from typing import runtime_checkable
from typing import Union
from ..agents.readonly_context import ReadonlyContext from ..agents.readonly_context import ReadonlyContext
from .base_tool import BaseTool from .base_tool import BaseTool
@ -34,9 +36,15 @@ class BaseToolset(ABC):
A toolset is a collection of tools that can be used by an agent. A toolset is a collection of tools that can be used by an agent.
""" """
def __init__(
self, *, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None
):
self.tool_filter = tool_filter
@abstractmethod @abstractmethod
async def get_tools( async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None self,
readonly_context: Optional[ReadonlyContext] = None,
) -> list[BaseTool]: ) -> list[BaseTool]:
"""Return all tools in the toolset based on the provided context. """Return all tools in the toolset based on the provided context.
@ -57,3 +65,17 @@ class BaseToolset(ABC):
should ensure that any open connections, files, or other managed should ensure that any open connections, files, or other managed
resources are properly released to prevent leaks. resources are properly released to prevent leaks.
""" """
def _is_tool_selected(
self, tool: BaseTool, readonly_context: ReadonlyContext
) -> bool:
if not self.tool_filter:
return True
if isinstance(self.tool_filter, ToolPredicate):
return self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, list):
return tool.name in self.tool_filter
return False

View File

@ -56,20 +56,6 @@ class GoogleApiToolset(BaseToolset):
self._openapi_toolset = self._load_toolset_with_oidc_auth() self._openapi_toolset = self._load_toolset_with_oidc_auth()
self.tool_filter = tool_filter self.tool_filter = tool_filter
def _is_tool_selected(
self, tool: GoogleApiTool, readonly_context: ReadonlyContext
) -> bool:
if not self.tool_filter:
return True
if isinstance(self.tool_filter, ToolPredicate):
return self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, list):
return tool.name in self.tool_filter
return False
@override @override
async def get_tools( async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None self, readonly_context: Optional[ReadonlyContext] = None

View File

@ -90,6 +90,7 @@ class MCPToolset(BaseToolset):
if not connection_params: if not connection_params:
raise ValueError("Missing connection params in MCPToolset.") raise ValueError("Missing connection params in MCPToolset.")
super().__init__(tool_filter=tool_filter)
self._connection_params = connection_params self._connection_params = connection_params
self._errlog = errlog self._errlog = errlog
self._exit_stack = AsyncExitStack() self._exit_stack = AsyncExitStack()
@ -102,7 +103,6 @@ class MCPToolset(BaseToolset):
errlog=self._errlog, errlog=self._errlog,
) )
self._session = None self._session = None
self.tool_filter = tool_filter
self._initialized = False self._initialized = False
async def _initialize(self) -> ClientSession: async def _initialize(self) -> ClientSession:
@ -116,18 +116,6 @@ class MCPToolset(BaseToolset):
self._initialized = True self._initialized = True
return self._session return self._session
def _is_selected(
self, tool: BaseTool, readonly_context: Optional[ReadonlyContext]
) -> bool:
"""Checks if a tool should be selected based on the tool filter."""
if self.tool_filter is None:
return True
if isinstance(self.tool_filter, ToolPredicate):
return self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, list):
return tool.name in self.tool_filter
return False
@override @override
async def close(self): async def close(self):
"""Safely closes the connection to MCP Server with guaranteed resource cleanup.""" """Safely closes the connection to MCP Server with guaranteed resource cleanup."""
@ -221,6 +209,6 @@ class MCPToolset(BaseToolset):
mcp_session_manager=self._session_manager, mcp_session_manager=self._session_manager,
) )
if self._is_selected(mcp_tool, readonly_context): if self._is_tool_selected(mcp_tool, readonly_context):
tools.append(mcp_tool) tools.append(mcp_tool)
return tools return tools

View File

@ -103,12 +103,12 @@ class OpenAPIToolset(BaseToolset):
tool_filter: The filter used to filter the tools in the toolset. It can be tool_filter: The filter used to filter the tools in the toolset. It can be
either a tool predicate or a list of tool names of the tools to expose. either a tool predicate or a list of tool names of the tools to expose.
""" """
super().__init__(tool_filter=tool_filter)
if not spec_dict: if not spec_dict:
spec_dict = self._load_spec(spec_str, spec_str_type) spec_dict = self._load_spec(spec_str, spec_str_type)
self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict)) self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
if auth_scheme or auth_credential: if auth_scheme or auth_credential:
self._configure_auth_all(auth_scheme, auth_credential) self._configure_auth_all(auth_scheme, auth_credential)
self.tool_filter = tool_filter
def _configure_auth_all( def _configure_auth_all(
self, auth_scheme: AuthScheme, auth_credential: AuthCredential self, auth_scheme: AuthScheme, auth_credential: AuthCredential
@ -129,12 +129,7 @@ class OpenAPIToolset(BaseToolset):
return [ return [
tool tool
for tool in self._tools for tool in self._tools
if self.tool_filter is None if self._is_tool_selected(tool, readonly_context)
or (
self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, ToolPredicate)
else tool.name in self.tool_filter
)
] ]
def get_tool(self, tool_name: str) -> Optional[RestApiTool]: def get_tool(self, tool_name: str) -> Optional[RestApiTool]: