mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-22 13:22:19 -06:00
adapt google api toolset and api hub toolset to new toolset interface
PiperOrigin-RevId: 757946541
This commit is contained in:
committed by
Copybara-Service
parent
27b229719e
commit
6a04ff84ba
@@ -20,18 +20,23 @@ from typing import Final
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import override
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from ....agents.readonly_context import ReadonlyContext
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ...base_toolset import BaseToolset
|
||||
from ...base_toolset import ToolPredicate
|
||||
from .openapi_spec_parser import OpenApiSpecParser
|
||||
from .rest_api_tool import RestApiTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAPIToolset:
|
||||
class OpenAPIToolset(BaseToolset):
|
||||
"""Class for parsing OpenAPI spec into a list of RestApiTool.
|
||||
|
||||
Usage:
|
||||
@@ -61,6 +66,7 @@ class OpenAPIToolset:
|
||||
spec_str_type: Literal["json", "yaml"] = "json",
|
||||
auth_scheme: Optional[AuthScheme] = None,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||
):
|
||||
"""Initializes the OpenAPIToolset.
|
||||
|
||||
@@ -94,12 +100,15 @@ class OpenAPIToolset:
|
||||
auth_credential: The auth credential to use for all tools. Use
|
||||
AuthCredential or use helpers in
|
||||
`google.adk.tools.openapi_tool.auth.auth_helpers`
|
||||
tool_filter: The filter used to filter the tools in the toolset. It can be
|
||||
either a tool predicate or a list of tool names of the tools to expose
|
||||
"""
|
||||
if not spec_dict:
|
||||
spec_dict = self._load_spec(spec_str, spec_str_type)
|
||||
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
|
||||
if auth_scheme or auth_credential:
|
||||
self._configure_auth_all(auth_scheme, auth_credential)
|
||||
self.tool_filter = tool_filter
|
||||
|
||||
def _configure_auth_all(
|
||||
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
|
||||
@@ -112,9 +121,21 @@ class OpenAPIToolset:
|
||||
if auth_credential:
|
||||
tool.configure_auth_credential(auth_credential)
|
||||
|
||||
def get_tools(self) -> List[RestApiTool]:
|
||||
@override
|
||||
async def get_tools(
|
||||
self, readonly_context: Optional[ReadonlyContext] = None
|
||||
) -> List[RestApiTool]:
|
||||
"""Get all tools in the toolset."""
|
||||
return self.tools
|
||||
return [
|
||||
tool
|
||||
for tool in self.tools
|
||||
if self.tool_filter is None
|
||||
or (
|
||||
self.tool_filter(tool, readonly_context)
|
||||
if isinstance(self.tool_filter, ToolPredicate)
|
||||
else tool.name in self.tool_filter
|
||||
)
|
||||
]
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
|
||||
"""Get a tool by name."""
|
||||
@@ -142,3 +163,7 @@ class OpenAPIToolset:
|
||||
logger.info("Parsed tool: %s", tool.name)
|
||||
tools.append(tool)
|
||||
return tools
|
||||
|
||||
@override
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user