mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 15:14:50 -06:00
refactor: refactor toolset to extract tool_filter logic to base class
PiperOrigin-RevId: 761828251
This commit is contained in:
parent
e0851a1e57
commit
a2263b1808
@ -131,6 +131,7 @@ class APIHubToolset(BaseToolset):
|
||||
be either a tool predicate or a list of tool names of the tools to
|
||||
expose.
|
||||
"""
|
||||
super().__init__(tool_filter=tool_filter)
|
||||
self.name = name
|
||||
self.description = description
|
||||
self._apihub_resource_name = apihub_resource_name
|
||||
@ -143,7 +144,6 @@ class APIHubToolset(BaseToolset):
|
||||
self._openapi_toolset = None
|
||||
self._auth_scheme = auth_scheme
|
||||
self._auth_credential = auth_credential
|
||||
self.tool_filter = tool_filter
|
||||
|
||||
if not self._lazy_load_spec:
|
||||
self._prepare_toolset()
|
||||
|
@ -128,6 +128,7 @@ class ApplicationIntegrationToolset(BaseToolset):
|
||||
Exception: If there is an error during the initialization of the
|
||||
integration or connection client.
|
||||
"""
|
||||
super().__init__(tool_filter=tool_filter)
|
||||
self.project = project
|
||||
self.location = location
|
||||
self._integration = integration
|
||||
@ -140,7 +141,6 @@ class ApplicationIntegrationToolset(BaseToolset):
|
||||
self._service_account_json = service_account_json
|
||||
self._auth_scheme = auth_scheme
|
||||
self._auth_credential = auth_credential
|
||||
self.tool_filter = tool_filter
|
||||
|
||||
integration_client = IntegrationClient(
|
||||
project,
|
||||
@ -263,7 +263,11 @@ class ApplicationIntegrationToolset(BaseToolset):
|
||||
readonly_context: Optional[ReadonlyContext] = None,
|
||||
) -> List[RestApiTool]:
|
||||
return (
|
||||
self._tools
|
||||
[
|
||||
tool
|
||||
for tool in self._tools
|
||||
if self._is_tool_selected(tool, readonly_context)
|
||||
]
|
||||
if self._openapi_toolset is None
|
||||
else await self._openapi_toolset.get_tools(readonly_context)
|
||||
)
|
||||
|
@ -1,8 +1,10 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import runtime_checkable
|
||||
from typing import Union
|
||||
|
||||
from ..agents.readonly_context import ReadonlyContext
|
||||
from .base_tool import BaseTool
|
||||
@ -34,9 +36,15 @@ class BaseToolset(ABC):
|
||||
A toolset is a collection of tools that can be used by an agent.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None
|
||||
):
|
||||
self.tool_filter = tool_filter
|
||||
|
||||
@abstractmethod
|
||||
async def get_tools(
|
||||
self, readonly_context: Optional[ReadonlyContext] = None
|
||||
self,
|
||||
readonly_context: Optional[ReadonlyContext] = None,
|
||||
) -> list[BaseTool]:
|
||||
"""Return all tools in the toolset based on the provided context.
|
||||
|
||||
@ -57,3 +65,17 @@ class BaseToolset(ABC):
|
||||
should ensure that any open connections, files, or other managed
|
||||
resources are properly released to prevent leaks.
|
||||
"""
|
||||
|
||||
def _is_tool_selected(
|
||||
self, tool: BaseTool, readonly_context: ReadonlyContext
|
||||
) -> bool:
|
||||
if not self.tool_filter:
|
||||
return True
|
||||
|
||||
if isinstance(self.tool_filter, ToolPredicate):
|
||||
return self.tool_filter(tool, readonly_context)
|
||||
|
||||
if isinstance(self.tool_filter, list):
|
||||
return tool.name in self.tool_filter
|
||||
|
||||
return False
|
||||
|
@ -56,20 +56,6 @@ class GoogleApiToolset(BaseToolset):
|
||||
self._openapi_toolset = self._load_toolset_with_oidc_auth()
|
||||
self.tool_filter = tool_filter
|
||||
|
||||
def _is_tool_selected(
|
||||
self, tool: GoogleApiTool, readonly_context: ReadonlyContext
|
||||
) -> bool:
|
||||
if not self.tool_filter:
|
||||
return True
|
||||
|
||||
if isinstance(self.tool_filter, ToolPredicate):
|
||||
return self.tool_filter(tool, readonly_context)
|
||||
|
||||
if isinstance(self.tool_filter, list):
|
||||
return tool.name in self.tool_filter
|
||||
|
||||
return False
|
||||
|
||||
@override
|
||||
async def get_tools(
|
||||
self, readonly_context: Optional[ReadonlyContext] = None
|
||||
|
@ -90,6 +90,7 @@ class MCPToolset(BaseToolset):
|
||||
|
||||
if not connection_params:
|
||||
raise ValueError("Missing connection params in MCPToolset.")
|
||||
super().__init__(tool_filter=tool_filter)
|
||||
self._connection_params = connection_params
|
||||
self._errlog = errlog
|
||||
self._exit_stack = AsyncExitStack()
|
||||
@ -102,7 +103,6 @@ class MCPToolset(BaseToolset):
|
||||
errlog=self._errlog,
|
||||
)
|
||||
self._session = None
|
||||
self.tool_filter = tool_filter
|
||||
self._initialized = False
|
||||
|
||||
async def _initialize(self) -> ClientSession:
|
||||
@ -116,18 +116,6 @@ class MCPToolset(BaseToolset):
|
||||
self._initialized = True
|
||||
return self._session
|
||||
|
||||
def _is_selected(
|
||||
self, tool: BaseTool, readonly_context: Optional[ReadonlyContext]
|
||||
) -> bool:
|
||||
"""Checks if a tool should be selected based on the tool filter."""
|
||||
if self.tool_filter is None:
|
||||
return True
|
||||
if isinstance(self.tool_filter, ToolPredicate):
|
||||
return self.tool_filter(tool, readonly_context)
|
||||
if isinstance(self.tool_filter, list):
|
||||
return tool.name in self.tool_filter
|
||||
return False
|
||||
|
||||
@override
|
||||
async def close(self):
|
||||
"""Safely closes the connection to MCP Server with guaranteed resource cleanup."""
|
||||
@ -221,6 +209,6 @@ class MCPToolset(BaseToolset):
|
||||
mcp_session_manager=self._session_manager,
|
||||
)
|
||||
|
||||
if self._is_selected(mcp_tool, readonly_context):
|
||||
if self._is_tool_selected(mcp_tool, readonly_context):
|
||||
tools.append(mcp_tool)
|
||||
return tools
|
||||
|
@ -103,12 +103,12 @@ class OpenAPIToolset(BaseToolset):
|
||||
tool_filter: The filter used to filter the tools in the toolset. It can be
|
||||
either a tool predicate or a list of tool names of the tools to expose.
|
||||
"""
|
||||
super().__init__(tool_filter=tool_filter)
|
||||
if not spec_dict:
|
||||
spec_dict = self._load_spec(spec_str, spec_str_type)
|
||||
self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
|
||||
if auth_scheme or auth_credential:
|
||||
self._configure_auth_all(auth_scheme, auth_credential)
|
||||
self.tool_filter = tool_filter
|
||||
|
||||
def _configure_auth_all(
|
||||
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
|
||||
@ -129,12 +129,7 @@ class OpenAPIToolset(BaseToolset):
|
||||
return [
|
||||
tool
|
||||
for tool in self._tools
|
||||
if self.tool_filter is None
|
||||
or (
|
||||
self.tool_filter(tool, readonly_context)
|
||||
if isinstance(self.tool_filter, ToolPredicate)
|
||||
else tool.name in self.tool_filter
|
||||
)
|
||||
if self._is_tool_selected(tool, readonly_context)
|
||||
]
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
|
||||
|
Loading…
Reference in New Issue
Block a user