adapt google api toolset and api hub toolset to new toolset interface

PiperOrigin-RevId: 757946541
This commit is contained in:
Xiang (Sean) Zhou
2025-05-12 15:57:56 -07:00
committed by Copybara-Service
parent 27b229719e
commit 6a04ff84ba
7 changed files with 168 additions and 122 deletions

View File

@@ -13,19 +13,25 @@
# limitations under the License.
from typing import Dict, List, Optional
from typing import List
from typing import Optional
from typing import override
from typing import Union
import yaml
from ...agents.readonly_context import ReadonlyContext
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
from ..base_toolset import BaseToolset
from ..base_toolset import ToolPredicate
from ..openapi_tool.common.common import to_snake_case
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from .clients.apihub_client import APIHubClient
class APIHubToolset:
class APIHubToolset(BaseToolset):
"""APIHubTool generates tools from a given API Hub resource.
Examples:
@@ -34,16 +40,13 @@ class APIHubToolset:
apihub_toolset = APIHubToolset(
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
service_account_json="...",
tool_filter=lambda tool, ctx=None: tool.name in ('my_tool',
'my_other_tool')
)
# Get all available tools
agent = LlmAgent(tools=apihub_toolset.get_tools())
agent = LlmAgent(tools=apihub_toolset)
# Get a specific tool
agent = LlmAgent(tools=[
...
apihub_toolset.get_tool('my_tool'),
])
```
**apihub_resource_name** is the resource name from API Hub. It must include
@@ -70,6 +73,7 @@ class APIHubToolset:
auth_credential: Optional[AuthCredential] = None,
# Optionally, you can provide a custom API Hub client
apihub_client: Optional[APIHubClient] = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
):
"""Initializes the APIHubTool with the given parameters.
@@ -81,12 +85,17 @@ class APIHubToolset:
)
# Get all available tools
agent = LlmAgent(tools=apihub_toolset.get_tools())
agent = LlmAgent(tools=[apihub_toolset])
apihub_toolset = APIHubToolset(
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
service_account_json="...",
tool_filter = ['my_tool']
)
# Get a specific tool
agent = LlmAgent(tools=[
...
apihub_toolset.get_tool('my_tool'),
...,
apihub_toolset,
])
```
@@ -118,6 +127,9 @@ class APIHubToolset:
lazy_load_spec: If True, the spec will be loaded lazily when needed.
Otherwise, the spec will be loaded immediately and the tools will be
generated during initialization.
tool_filter: The filter used to filter the tools in the toolset. It can
be either a tool predicate or a list of tool names of the tools to
expose.
"""
self.name = name
self.description = description
@@ -128,72 +140,36 @@ class APIHubToolset:
service_account_json=service_account_json,
)
self.generated_tools: Dict[str, RestApiTool] = {}
self.openapi_toolset = None
self.auth_scheme = auth_scheme
self.auth_credential = auth_credential
self.tool_filter = tool_filter
if not self.lazy_load_spec:
self._prepare_tools()
self._prepare_toolset()
def get_tool(self, name: str) -> Optional[RestApiTool]:
"""Retrieves a specific tool by its name.
Example:
```
apihub_tool = apihub_toolset.get_tool('my_tool')
```
Args:
name: The name of the tool to retrieve.
Returns:
The tool with the given name, or None if no such tool exists.
"""
if not self._are_tools_ready():
self._prepare_tools()
return self.generated_tools[name] if name in self.generated_tools else None
def get_tools(self) -> List[RestApiTool]:
@override
async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
) -> List[RestApiTool]:
"""Retrieves all available tools.
Returns:
A list of all available RestApiTool objects.
"""
if not self._are_tools_ready():
self._prepare_tools()
if not self.openapi_toolset:
self._prepare_toolset()
if not self.openapi_toolset:
return []
return await self.openapi_toolset.get_tools(readonly_context)
return list(self.generated_tools.values())
def _are_tools_ready(self) -> bool:
return not self.lazy_load_spec or self.generated_tools
def _prepare_tools(self) -> str:
"""Fetches the spec from API Hub and generates the tools.
Returns:
True if the tools are ready, False otherwise.
"""
def _prepare_toolset(self) -> None:
"""Fetches the spec from API Hub and generates the toolset."""
# For each API, get the first version and the first spec of that version.
spec = self.apihub_client.get_spec_content(self.apihub_resource_name)
self.generated_tools: Dict[str, RestApiTool] = {}
tools = self._parse_spec_to_tools(spec)
for tool in tools:
self.generated_tools[tool.name] = tool
def _parse_spec_to_tools(self, spec_str: str) -> List[RestApiTool]:
"""Parses the spec string to a list of RestApiTool.
Args:
spec_str: The spec string to parse.
Returns:
A list of RestApiTool objects.
"""
spec_str = self.apihub_client.get_spec_content(self.apihub_resource_name)
spec_dict = yaml.safe_load(spec_str)
if not spec_dict:
return []
return
self.name = self.name or to_snake_case(
spec_dict.get('info', {}).get('title', 'unnamed')
@@ -201,9 +177,14 @@ class APIHubToolset:
self.description = self.description or spec_dict.get('info', {}).get(
'description', ''
)
tools = OpenAPIToolset(
self.openapi_toolset = OpenAPIToolset(
spec_dict=spec_dict,
auth_credential=self.auth_credential,
auth_scheme=self.auth_scheme,
).get_tools()
return tools
tool_filter=self.tool_filter,
)
@override
async def close(self):
if self.openapi_toolset:
await self.openapi_toolset.close()