adapt google api toolset and api hub toolset to new toolset interface

PiperOrigin-RevId: 757946541
This commit is contained in:
Xiang (Sean) Zhou
2025-05-12 15:57:56 -07:00
committed by Copybara-Service
parent 27b229719e
commit 6a04ff84ba
7 changed files with 168 additions and 122 deletions

View File

@@ -20,18 +20,23 @@ from typing import Final
from typing import List
from typing import Literal
from typing import Optional
from typing import override
from typing import Union
import yaml
from ....agents.readonly_context import ReadonlyContext
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ...base_toolset import BaseToolset
from ...base_toolset import ToolPredicate
from .openapi_spec_parser import OpenApiSpecParser
from .rest_api_tool import RestApiTool
logger = logging.getLogger(__name__)
class OpenAPIToolset:
class OpenAPIToolset(BaseToolset):
"""Class for parsing OpenAPI spec into a list of RestApiTool.
Usage:
@@ -61,6 +66,7 @@ class OpenAPIToolset:
spec_str_type: Literal["json", "yaml"] = "json",
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
):
"""Initializes the OpenAPIToolset.
@@ -94,12 +100,15 @@ class OpenAPIToolset:
auth_credential: The auth credential to use for all tools. Use
AuthCredential or use helpers in
`google.adk.tools.openapi_tool.auth.auth_helpers`
tool_filter: The filter used to filter the tools in the toolset. It can be
either a tool predicate or a list of tool names of the tools to expose
"""
if not spec_dict:
spec_dict = self._load_spec(spec_str, spec_str_type)
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
if auth_scheme or auth_credential:
self._configure_auth_all(auth_scheme, auth_credential)
self.tool_filter = tool_filter
def _configure_auth_all(
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
@@ -112,9 +121,21 @@ class OpenAPIToolset:
if auth_credential:
tool.configure_auth_credential(auth_credential)
def get_tools(self) -> List[RestApiTool]:
@override
async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
) -> List[RestApiTool]:
"""Get all tools in the toolset."""
return self.tools
return [
tool
for tool in self.tools
if self.tool_filter is None
or (
self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, ToolPredicate)
else tool.name in self.tool_filter
)
]
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
"""Get a tool by name."""
@@ -142,3 +163,7 @@ class OpenAPIToolset:
logger.info("Parsed tool: %s", tool.name)
tools.append(tool)
return tools
@override
async def close(self):
pass