mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-16 04:02:55 -06:00
refactor: refactor toolset to extract tool_filter logic to base class
PiperOrigin-RevId: 761828251
This commit is contained in:
parent
e0851a1e57
commit
a2263b1808
@ -131,6 +131,7 @@ class APIHubToolset(BaseToolset):
|
|||||||
be either a tool predicate or a list of tool names of the tools to
|
be either a tool predicate or a list of tool names of the tools to
|
||||||
expose.
|
expose.
|
||||||
"""
|
"""
|
||||||
|
super().__init__(tool_filter=tool_filter)
|
||||||
self.name = name
|
self.name = name
|
||||||
self.description = description
|
self.description = description
|
||||||
self._apihub_resource_name = apihub_resource_name
|
self._apihub_resource_name = apihub_resource_name
|
||||||
@ -143,7 +144,6 @@ class APIHubToolset(BaseToolset):
|
|||||||
self._openapi_toolset = None
|
self._openapi_toolset = None
|
||||||
self._auth_scheme = auth_scheme
|
self._auth_scheme = auth_scheme
|
||||||
self._auth_credential = auth_credential
|
self._auth_credential = auth_credential
|
||||||
self.tool_filter = tool_filter
|
|
||||||
|
|
||||||
if not self._lazy_load_spec:
|
if not self._lazy_load_spec:
|
||||||
self._prepare_toolset()
|
self._prepare_toolset()
|
||||||
|
@ -128,6 +128,7 @@ class ApplicationIntegrationToolset(BaseToolset):
|
|||||||
Exception: If there is an error during the initialization of the
|
Exception: If there is an error during the initialization of the
|
||||||
integration or connection client.
|
integration or connection client.
|
||||||
"""
|
"""
|
||||||
|
super().__init__(tool_filter=tool_filter)
|
||||||
self.project = project
|
self.project = project
|
||||||
self.location = location
|
self.location = location
|
||||||
self._integration = integration
|
self._integration = integration
|
||||||
@ -140,7 +141,6 @@ class ApplicationIntegrationToolset(BaseToolset):
|
|||||||
self._service_account_json = service_account_json
|
self._service_account_json = service_account_json
|
||||||
self._auth_scheme = auth_scheme
|
self._auth_scheme = auth_scheme
|
||||||
self._auth_credential = auth_credential
|
self._auth_credential = auth_credential
|
||||||
self.tool_filter = tool_filter
|
|
||||||
|
|
||||||
integration_client = IntegrationClient(
|
integration_client = IntegrationClient(
|
||||||
project,
|
project,
|
||||||
@ -263,7 +263,11 @@ class ApplicationIntegrationToolset(BaseToolset):
|
|||||||
readonly_context: Optional[ReadonlyContext] = None,
|
readonly_context: Optional[ReadonlyContext] = None,
|
||||||
) -> List[RestApiTool]:
|
) -> List[RestApiTool]:
|
||||||
return (
|
return (
|
||||||
self._tools
|
[
|
||||||
|
tool
|
||||||
|
for tool in self._tools
|
||||||
|
if self._is_tool_selected(tool, readonly_context)
|
||||||
|
]
|
||||||
if self._openapi_toolset is None
|
if self._openapi_toolset is None
|
||||||
else await self._openapi_toolset.get_tools(readonly_context)
|
else await self._openapi_toolset.get_tools(readonly_context)
|
||||||
)
|
)
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from typing import List
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
from typing import runtime_checkable
|
from typing import runtime_checkable
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from ..agents.readonly_context import ReadonlyContext
|
from ..agents.readonly_context import ReadonlyContext
|
||||||
from .base_tool import BaseTool
|
from .base_tool import BaseTool
|
||||||
@ -34,9 +36,15 @@ class BaseToolset(ABC):
|
|||||||
A toolset is a collection of tools that can be used by an agent.
|
A toolset is a collection of tools that can be used by an agent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, *, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None
|
||||||
|
):
|
||||||
|
self.tool_filter = tool_filter
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_tools(
|
async def get_tools(
|
||||||
self, readonly_context: Optional[ReadonlyContext] = None
|
self,
|
||||||
|
readonly_context: Optional[ReadonlyContext] = None,
|
||||||
) -> list[BaseTool]:
|
) -> list[BaseTool]:
|
||||||
"""Return all tools in the toolset based on the provided context.
|
"""Return all tools in the toolset based on the provided context.
|
||||||
|
|
||||||
@ -57,3 +65,17 @@ class BaseToolset(ABC):
|
|||||||
should ensure that any open connections, files, or other managed
|
should ensure that any open connections, files, or other managed
|
||||||
resources are properly released to prevent leaks.
|
resources are properly released to prevent leaks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _is_tool_selected(
|
||||||
|
self, tool: BaseTool, readonly_context: ReadonlyContext
|
||||||
|
) -> bool:
|
||||||
|
if not self.tool_filter:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if isinstance(self.tool_filter, ToolPredicate):
|
||||||
|
return self.tool_filter(tool, readonly_context)
|
||||||
|
|
||||||
|
if isinstance(self.tool_filter, list):
|
||||||
|
return tool.name in self.tool_filter
|
||||||
|
|
||||||
|
return False
|
||||||
|
@ -56,20 +56,6 @@ class GoogleApiToolset(BaseToolset):
|
|||||||
self._openapi_toolset = self._load_toolset_with_oidc_auth()
|
self._openapi_toolset = self._load_toolset_with_oidc_auth()
|
||||||
self.tool_filter = tool_filter
|
self.tool_filter = tool_filter
|
||||||
|
|
||||||
def _is_tool_selected(
|
|
||||||
self, tool: GoogleApiTool, readonly_context: ReadonlyContext
|
|
||||||
) -> bool:
|
|
||||||
if not self.tool_filter:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if isinstance(self.tool_filter, ToolPredicate):
|
|
||||||
return self.tool_filter(tool, readonly_context)
|
|
||||||
|
|
||||||
if isinstance(self.tool_filter, list):
|
|
||||||
return tool.name in self.tool_filter
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def get_tools(
|
async def get_tools(
|
||||||
self, readonly_context: Optional[ReadonlyContext] = None
|
self, readonly_context: Optional[ReadonlyContext] = None
|
||||||
|
@ -90,6 +90,7 @@ class MCPToolset(BaseToolset):
|
|||||||
|
|
||||||
if not connection_params:
|
if not connection_params:
|
||||||
raise ValueError("Missing connection params in MCPToolset.")
|
raise ValueError("Missing connection params in MCPToolset.")
|
||||||
|
super().__init__(tool_filter=tool_filter)
|
||||||
self._connection_params = connection_params
|
self._connection_params = connection_params
|
||||||
self._errlog = errlog
|
self._errlog = errlog
|
||||||
self._exit_stack = AsyncExitStack()
|
self._exit_stack = AsyncExitStack()
|
||||||
@ -102,7 +103,6 @@ class MCPToolset(BaseToolset):
|
|||||||
errlog=self._errlog,
|
errlog=self._errlog,
|
||||||
)
|
)
|
||||||
self._session = None
|
self._session = None
|
||||||
self.tool_filter = tool_filter
|
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
async def _initialize(self) -> ClientSession:
|
async def _initialize(self) -> ClientSession:
|
||||||
@ -116,18 +116,6 @@ class MCPToolset(BaseToolset):
|
|||||||
self._initialized = True
|
self._initialized = True
|
||||||
return self._session
|
return self._session
|
||||||
|
|
||||||
def _is_selected(
|
|
||||||
self, tool: BaseTool, readonly_context: Optional[ReadonlyContext]
|
|
||||||
) -> bool:
|
|
||||||
"""Checks if a tool should be selected based on the tool filter."""
|
|
||||||
if self.tool_filter is None:
|
|
||||||
return True
|
|
||||||
if isinstance(self.tool_filter, ToolPredicate):
|
|
||||||
return self.tool_filter(tool, readonly_context)
|
|
||||||
if isinstance(self.tool_filter, list):
|
|
||||||
return tool.name in self.tool_filter
|
|
||||||
return False
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""Safely closes the connection to MCP Server with guaranteed resource cleanup."""
|
"""Safely closes the connection to MCP Server with guaranteed resource cleanup."""
|
||||||
@ -221,6 +209,6 @@ class MCPToolset(BaseToolset):
|
|||||||
mcp_session_manager=self._session_manager,
|
mcp_session_manager=self._session_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._is_selected(mcp_tool, readonly_context):
|
if self._is_tool_selected(mcp_tool, readonly_context):
|
||||||
tools.append(mcp_tool)
|
tools.append(mcp_tool)
|
||||||
return tools
|
return tools
|
||||||
|
@ -103,12 +103,12 @@ class OpenAPIToolset(BaseToolset):
|
|||||||
tool_filter: The filter used to filter the tools in the toolset. It can be
|
tool_filter: The filter used to filter the tools in the toolset. It can be
|
||||||
either a tool predicate or a list of tool names of the tools to expose.
|
either a tool predicate or a list of tool names of the tools to expose.
|
||||||
"""
|
"""
|
||||||
|
super().__init__(tool_filter=tool_filter)
|
||||||
if not spec_dict:
|
if not spec_dict:
|
||||||
spec_dict = self._load_spec(spec_str, spec_str_type)
|
spec_dict = self._load_spec(spec_str, spec_str_type)
|
||||||
self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
|
self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
|
||||||
if auth_scheme or auth_credential:
|
if auth_scheme or auth_credential:
|
||||||
self._configure_auth_all(auth_scheme, auth_credential)
|
self._configure_auth_all(auth_scheme, auth_credential)
|
||||||
self.tool_filter = tool_filter
|
|
||||||
|
|
||||||
def _configure_auth_all(
|
def _configure_auth_all(
|
||||||
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
|
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
|
||||||
@ -129,12 +129,7 @@ class OpenAPIToolset(BaseToolset):
|
|||||||
return [
|
return [
|
||||||
tool
|
tool
|
||||||
for tool in self._tools
|
for tool in self._tools
|
||||||
if self.tool_filter is None
|
if self._is_tool_selected(tool, readonly_context)
|
||||||
or (
|
|
||||||
self.tool_filter(tool, readonly_context)
|
|
||||||
if isinstance(self.tool_filter, ToolPredicate)
|
|
||||||
else tool.name in self.tool_filter
|
|
||||||
)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
|
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
|
||||||
|
Loading…
Reference in New Issue
Block a user