From a2263b18083fccd176b3ed61c368a65d9b7efde4 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 21 May 2025 23:54:01 -0700 Subject: [PATCH] refactor: refactor toolset to extract tool_filter logic to base class PiperOrigin-RevId: 761828251 --- .../adk/tools/apihub_tool/apihub_toolset.py | 2 +- .../application_integration_toolset.py | 8 +++++-- src/google/adk/tools/base_toolset.py | 24 ++++++++++++++++++- .../google_api_tool/google_api_toolset.py | 14 ----------- src/google/adk/tools/mcp_tool/mcp_toolset.py | 16 ++----------- .../openapi_spec_parser/openapi_toolset.py | 9 ++----- 6 files changed, 34 insertions(+), 39 deletions(-) diff --git a/src/google/adk/tools/apihub_tool/apihub_toolset.py b/src/google/adk/tools/apihub_tool/apihub_toolset.py index 9bd2a34..62c183a 100644 --- a/src/google/adk/tools/apihub_tool/apihub_toolset.py +++ b/src/google/adk/tools/apihub_tool/apihub_toolset.py @@ -131,6 +131,7 @@ class APIHubToolset(BaseToolset): be either a tool predicate or a list of tool names of the tools to expose. """ + super().__init__(tool_filter=tool_filter) self.name = name self.description = description self._apihub_resource_name = apihub_resource_name @@ -143,7 +144,6 @@ class APIHubToolset(BaseToolset): self._openapi_toolset = None self._auth_scheme = auth_scheme self._auth_credential = auth_credential - self.tool_filter = tool_filter if not self._lazy_load_spec: self._prepare_toolset() diff --git a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py index a26b24f..c3d7c28 100644 --- a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py +++ b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py @@ -128,6 +128,7 @@ class ApplicationIntegrationToolset(BaseToolset): Exception: If there is an error during the initialization of the integration or connection client. """ + super().__init__(tool_filter=tool_filter) self.project = project self.location = location self._integration = integration @@ -140,7 +141,6 @@ class ApplicationIntegrationToolset(BaseToolset): self._service_account_json = service_account_json self._auth_scheme = auth_scheme self._auth_credential = auth_credential - self.tool_filter = tool_filter integration_client = IntegrationClient( project, @@ -263,7 +263,11 @@ class ApplicationIntegrationToolset(BaseToolset): readonly_context: Optional[ReadonlyContext] = None, ) -> List[RestApiTool]: return ( - self._tools + [ + tool + for tool in self._tools + if self._is_tool_selected(tool, readonly_context) + ] if self._openapi_toolset is None else await self._openapi_toolset.get_tools(readonly_context) ) diff --git a/src/google/adk/tools/base_toolset.py b/src/google/adk/tools/base_toolset.py index a761f2c..f5ca333 100644 --- a/src/google/adk/tools/base_toolset.py +++ b/src/google/adk/tools/base_toolset.py @@ -1,8 +1,10 @@ from abc import ABC from abc import abstractmethod +from typing import List from typing import Optional from typing import Protocol from typing import runtime_checkable +from typing import Union from ..agents.readonly_context import ReadonlyContext from .base_tool import BaseTool @@ -34,9 +36,15 @@ class BaseToolset(ABC): A toolset is a collection of tools that can be used by an agent. """ + def __init__( + self, *, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None + ): + self.tool_filter = tool_filter + @abstractmethod async def get_tools( - self, readonly_context: Optional[ReadonlyContext] = None + self, + readonly_context: Optional[ReadonlyContext] = None, ) -> list[BaseTool]: """Return all tools in the toolset based on the provided context. @@ -57,3 +65,17 @@ class BaseToolset(ABC): should ensure that any open connections, files, or other managed resources are properly released to prevent leaks. """ + + def _is_tool_selected( + self, tool: BaseTool, readonly_context: ReadonlyContext + ) -> bool: + if not self.tool_filter: + return True + + if isinstance(self.tool_filter, ToolPredicate): + return self.tool_filter(tool, readonly_context) + + if isinstance(self.tool_filter, list): + return tool.name in self.tool_filter + + return False diff --git a/src/google/adk/tools/google_api_tool/google_api_toolset.py b/src/google/adk/tools/google_api_tool/google_api_toolset.py index 7131ebc..2cb00fa 100644 --- a/src/google/adk/tools/google_api_tool/google_api_toolset.py +++ b/src/google/adk/tools/google_api_tool/google_api_toolset.py @@ -56,20 +56,6 @@ class GoogleApiToolset(BaseToolset): self._openapi_toolset = self._load_toolset_with_oidc_auth() self.tool_filter = tool_filter - def _is_tool_selected( - self, tool: GoogleApiTool, readonly_context: ReadonlyContext - ) -> bool: - if not self.tool_filter: - return True - - if isinstance(self.tool_filter, ToolPredicate): - return self.tool_filter(tool, readonly_context) - - if isinstance(self.tool_filter, list): - return tool.name in self.tool_filter - - return False - @override async def get_tools( self, readonly_context: Optional[ReadonlyContext] = None diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index e013bcb..994f6b9 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -90,6 +90,7 @@ class MCPToolset(BaseToolset): if not connection_params: raise ValueError("Missing connection params in MCPToolset.") + super().__init__(tool_filter=tool_filter) self._connection_params = connection_params self._errlog = errlog self._exit_stack = AsyncExitStack() @@ -102,7 +103,6 @@ class MCPToolset(BaseToolset): errlog=self._errlog, ) self._session = None - self.tool_filter = tool_filter self._initialized = False async def _initialize(self) -> ClientSession: @@ -116,18 +116,6 @@ class MCPToolset(BaseToolset): self._initialized = True return self._session - def _is_selected( - self, tool: BaseTool, readonly_context: Optional[ReadonlyContext] - ) -> bool: - """Checks if a tool should be selected based on the tool filter.""" - if self.tool_filter is None: - return True - if isinstance(self.tool_filter, ToolPredicate): - return self.tool_filter(tool, readonly_context) - if isinstance(self.tool_filter, list): - return tool.name in self.tool_filter - return False - @override async def close(self): """Safely closes the connection to MCP Server with guaranteed resource cleanup.""" @@ -221,6 +209,6 @@ class MCPToolset(BaseToolset): mcp_session_manager=self._session_manager, ) - if self._is_selected(mcp_tool, readonly_context): + if self._is_tool_selected(mcp_tool, readonly_context): tools.append(mcp_tool) return tools diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py index 495a228..8b01218 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py @@ -103,12 +103,12 @@ class OpenAPIToolset(BaseToolset): tool_filter: The filter used to filter the tools in the toolset. It can be either a tool predicate or a list of tool names of the tools to expose. """ + super().__init__(tool_filter=tool_filter) if not spec_dict: spec_dict = self._load_spec(spec_str, spec_str_type) self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict)) if auth_scheme or auth_credential: self._configure_auth_all(auth_scheme, auth_credential) - self.tool_filter = tool_filter def _configure_auth_all( self, auth_scheme: AuthScheme, auth_credential: AuthCredential @@ -129,12 +129,7 @@ class OpenAPIToolset(BaseToolset): return [ tool for tool in self._tools - if self.tool_filter is None - or ( - self.tool_filter(tool, readonly_context) - if isinstance(self.tool_filter, ToolPredicate) - else tool.name in self.tool_filter - ) + if self._is_tool_selected(tool, readonly_context) ] def get_tool(self, tool_name: str) -> Optional[RestApiTool]: