adapt google api toolset and api hub toolset to new toolset interface

PiperOrigin-RevId: 757946541
This commit is contained in:
Xiang (Sean) Zhou 2025-05-12 15:57:56 -07:00 committed by Copybara-Service
parent 27b229719e
commit 6a04ff84ba
7 changed files with 168 additions and 122 deletions

View File

@ -13,19 +13,25 @@
# limitations under the License.
from typing import Dict, List, Optional
from typing import List
from typing import Optional
from typing import override
from typing import Union
import yaml
from ...agents.readonly_context import ReadonlyContext
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
from ..base_toolset import BaseToolset
from ..base_toolset import ToolPredicate
from ..openapi_tool.common.common import to_snake_case
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from .clients.apihub_client import APIHubClient
class APIHubToolset:
class APIHubToolset(BaseToolset):
"""APIHubTool generates tools from a given API Hub resource.
Examples:
@ -34,16 +40,13 @@ class APIHubToolset:
apihub_toolset = APIHubToolset(
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
service_account_json="...",
tool_filter=lambda tool, ctx=None: tool.name in ('my_tool',
'my_other_tool')
)
# Get all available tools
agent = LlmAgent(tools=apihub_toolset.get_tools())
agent = LlmAgent(tools=apihub_toolset)
# Get a specific tool
agent = LlmAgent(tools=[
...
apihub_toolset.get_tool('my_tool'),
])
```
**apihub_resource_name** is the resource name from API Hub. It must include
@ -70,6 +73,7 @@ class APIHubToolset:
auth_credential: Optional[AuthCredential] = None,
# Optionally, you can provide a custom API Hub client
apihub_client: Optional[APIHubClient] = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
):
"""Initializes the APIHubTool with the given parameters.
@ -81,12 +85,17 @@ class APIHubToolset:
)
# Get all available tools
agent = LlmAgent(tools=apihub_toolset.get_tools())
agent = LlmAgent(tools=[apihub_toolset])
apihub_toolset = APIHubToolset(
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
service_account_json="...",
tool_filter = ['my_tool']
)
# Get a specific tool
agent = LlmAgent(tools=[
...
apihub_toolset.get_tool('my_tool'),
...,
apihub_toolset,
])
```
@ -118,6 +127,9 @@ class APIHubToolset:
lazy_load_spec: If True, the spec will be loaded lazily when needed.
Otherwise, the spec will be loaded immediately and the tools will be
generated during initialization.
tool_filter: The filter used to filter the tools in the toolset. It can
be either a tool predicate or a list of tool names of the tools to
expose.
"""
self.name = name
self.description = description
@ -128,72 +140,36 @@ class APIHubToolset:
service_account_json=service_account_json,
)
self.generated_tools: Dict[str, RestApiTool] = {}
self.openapi_toolset = None
self.auth_scheme = auth_scheme
self.auth_credential = auth_credential
self.tool_filter = tool_filter
if not self.lazy_load_spec:
self._prepare_tools()
self._prepare_toolset()
def get_tool(self, name: str) -> Optional[RestApiTool]:
"""Retrieves a specific tool by its name.
Example:
```
apihub_tool = apihub_toolset.get_tool('my_tool')
```
Args:
name: The name of the tool to retrieve.
Returns:
The tool with the given name, or None if no such tool exists.
"""
if not self._are_tools_ready():
self._prepare_tools()
return self.generated_tools[name] if name in self.generated_tools else None
def get_tools(self) -> List[RestApiTool]:
@override
async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
) -> List[RestApiTool]:
"""Retrieves all available tools.
Returns:
A list of all available RestApiTool objects.
"""
if not self._are_tools_ready():
self._prepare_tools()
if not self.openapi_toolset:
self._prepare_toolset()
if not self.openapi_toolset:
return []
return await self.openapi_toolset.get_tools(readonly_context)
return list(self.generated_tools.values())
def _are_tools_ready(self) -> bool:
return not self.lazy_load_spec or self.generated_tools
def _prepare_tools(self) -> str:
"""Fetches the spec from API Hub and generates the tools.
Returns:
True if the tools are ready, False otherwise.
"""
def _prepare_toolset(self) -> None:
"""Fetches the spec from API Hub and generates the toolset."""
# For each API, get the first version and the first spec of that version.
spec = self.apihub_client.get_spec_content(self.apihub_resource_name)
self.generated_tools: Dict[str, RestApiTool] = {}
tools = self._parse_spec_to_tools(spec)
for tool in tools:
self.generated_tools[tool.name] = tool
def _parse_spec_to_tools(self, spec_str: str) -> List[RestApiTool]:
"""Parses the spec string to a list of RestApiTool.
Args:
spec_str: The spec string to parse.
Returns:
A list of RestApiTool objects.
"""
spec_str = self.apihub_client.get_spec_content(self.apihub_resource_name)
spec_dict = yaml.safe_load(spec_str)
if not spec_dict:
return []
return
self.name = self.name or to_snake_case(
spec_dict.get('info', {}).get('title', 'unnamed')
@ -201,9 +177,14 @@ class APIHubToolset:
self.description = self.description or spec_dict.get('info', {}).get(
'description', ''
)
tools = OpenAPIToolset(
self.openapi_toolset = OpenAPIToolset(
spec_dict=spec_dict,
auth_credential=self.auth_credential,
auth_scheme=self.auth_scheme,
).get_tools()
return tools
tool_filter=self.tool_filter,
)
@override
async def close(self):
if self.openapi_toolset:
await self.openapi_toolset.close()

View File

@ -1,5 +1,6 @@
from abc import ABC
from abc import abstractmethod
from typing import Optional
from typing import Protocol
from google.adk.agents.readonly_context import ReadonlyContext
@ -33,7 +34,7 @@ class BaseToolset(ABC):
@abstractmethod
async def get_tools(
self, readony_context: ReadonlyContext = None
self, readonly_context: Optional[ReadonlyContext] = None
) -> list[BaseTool]:
"""Return all tools in the toolset based on the provided context.

View File

@ -17,37 +17,67 @@ from __future__ import annotations
import inspect
import os
from typing import Any
from typing import Final
from typing import List
from typing import Optional
from typing import override
from typing import Type
from typing import Union
from ...agents.readonly_context import ReadonlyContext
from ...auth import OpenIdConnectWithConfig
from ...tools.base_toolset import BaseToolset
from ...tools.base_toolset import ToolPredicate
from ..openapi_tool import OpenAPIToolset
from ..openapi_tool import RestApiTool
from .google_api_tool import GoogleApiTool
from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
class GoogleApiToolSet:
"""Google API Tool Set."""
class GoogleApiToolset(BaseToolset):
"""Google API Toolset contains tools for interacting with Google APIs.
def __init__(self, tools: List[RestApiTool]):
self.tools: Final[List[GoogleApiTool]] = [
GoogleApiTool(tool) for tool in tools
]
Usually one toolsets will contains tools only replated to one Google API, e.g.
Google Bigquery API toolset will contains tools only related to Google
Bigquery API, like list dataset tool, list table tool etc.
"""
def get_tools(self) -> List[GoogleApiTool]:
def __init__(
self,
openapi_toolset: OpenAPIToolset,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
):
self.openapi_toolset = openapi_toolset
self.tool_filter = tool_filter
self.client_id = client_id
self.client_secret = client_secret
@override
async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
) -> List[GoogleApiTool]:
"""Get all tools in the toolset."""
return self.tools
tools = []
def get_tool(self, tool_name: str) -> Optional[GoogleApiTool]:
"""Get a tool by name."""
matching_tool = filter(lambda t: t.name == tool_name, self.tools)
return next(matching_tool, None)
for tool in await self.openapi_toolset.get_tools(readonly_context):
if self.tool_filter and (
isinstance(self.tool_filter, ToolPredicate)
and not self.tool_filter(tool, readonly_context)
or isinstance(self.tool_filter, list)
and tool.name not in self.tool_filter
):
continue
google_api_tool = GoogleApiTool(tool)
google_api_tool.configure_auth(self.client_id, self.client_secret)
tools.append(google_api_tool)
return tools
def set_tool_filter(self, tool_filter: Union[ToolPredicate, List[str]]):
self.tool_filter = tool_filter
@staticmethod
def _load_tool_set_with_oidc_auth(
def _load_toolset_with_oidc_auth(
spec_file: Optional[str] = None,
spec_dict: Optional[dict[str, Any]] = None,
scopes: Optional[list[str]] = None,
@ -64,7 +94,7 @@ class GoogleApiToolSet:
yaml_path = os.path.join(caller_dir, spec_file)
with open(yaml_path, 'r', encoding='utf-8') as file:
spec_str = file.read()
tool_set = OpenAPIToolset(
toolset = OpenAPIToolset(
spec_dict=spec_dict,
spec_str=spec_str,
spec_str_type='yaml',
@ -85,18 +115,18 @@ class GoogleApiToolSet:
scopes=scopes,
),
)
return tool_set
return toolset
def configure_auth(self, client_id: str, client_secret: str):
for tool in self.tools:
tool.configure_auth(client_id, client_secret)
self.client_id = client_id
self.client_secret = client_secret
@classmethod
def load_tool_set(
cls: Type[GoogleApiToolSet],
def load_toolset(
cls: Type[GoogleApiToolset],
api_name: str,
api_version: str,
) -> GoogleApiToolSet:
) -> GoogleApiToolset:
spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
scope = list(
spec_dict['components']['securitySchemes']['oauth2']['flows'][
@ -104,7 +134,10 @@ class GoogleApiToolSet:
]['scopes'].keys()
)[0]
return cls(
cls._load_tool_set_with_oidc_auth(
spec_dict=spec_dict, scopes=[scope]
).get_tools()
cls._load_toolset_with_oidc_auth(spec_dict=spec_dict, scopes=[scope])
)
@override
async def close(self):
if self.openapi_toolset:
await self.openapi_toolset.close()

View File

@ -15,7 +15,7 @@
import logging
from .google_api_tool_set import GoogleApiToolSet
from .google_api_tool_set import GoogleApiToolset
logger = logging.getLogger(__name__)
@ -50,7 +50,7 @@ def __getattr__(name):
match name:
case "bigquery_tool_set":
if _bigquery_tool_set is None:
_bigquery_tool_set = GoogleApiToolSet.load_tool_set(
_bigquery_tool_set = GoogleApiToolset.load_toolset(
api_name="bigquery",
api_version="v2",
)
@ -59,7 +59,7 @@ def __getattr__(name):
case "calendar_tool_set":
if _calendar_tool_set is None:
_calendar_tool_set = GoogleApiToolSet.load_tool_set(
_calendar_tool_set = GoogleApiToolset.load_toolset(
api_name="calendar",
api_version="v3",
)
@ -68,7 +68,7 @@ def __getattr__(name):
case "gmail_tool_set":
if _gmail_tool_set is None:
_gmail_tool_set = GoogleApiToolSet.load_tool_set(
_gmail_tool_set = GoogleApiToolset.load_toolset(
api_name="gmail",
api_version="v1",
)
@ -77,7 +77,7 @@ def __getattr__(name):
case "youtube_tool_set":
if _youtube_tool_set is None:
_youtube_tool_set = GoogleApiToolSet.load_tool_set(
_youtube_tool_set = GoogleApiToolset.load_toolset(
api_name="youtube",
api_version="v3",
)
@ -86,7 +86,7 @@ def __getattr__(name):
case "slides_tool_set":
if _slides_tool_set is None:
_slides_tool_set = GoogleApiToolSet.load_tool_set(
_slides_tool_set = GoogleApiToolset.load_toolset(
api_name="slides",
api_version="v1",
)
@ -95,7 +95,7 @@ def __getattr__(name):
case "sheets_tool_set":
if _sheets_tool_set is None:
_sheets_tool_set = GoogleApiToolSet.load_tool_set(
_sheets_tool_set = GoogleApiToolset.load_toolset(
api_name="sheets",
api_version="v4",
)
@ -104,7 +104,7 @@ def __getattr__(name):
case "docs_tool_set":
if _docs_tool_set is None:
_docs_tool_set = GoogleApiToolSet.load_tool_set(
_docs_tool_set = GoogleApiToolset.load_toolset(
api_name="docs",
api_version="v1",
)

View File

@ -14,7 +14,7 @@
from contextlib import AsyncExitStack
import sys
from typing import List
from typing import List, Union
from typing import Optional
from typing import TextIO
@ -68,7 +68,7 @@ class MCPToolset(BaseToolset):
*,
connection_params: StdioServerParameters | SseServerParams,
errlog: TextIO = sys.stderr,
tool_predicate: Optional[ToolPredicate] = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
):
"""Initializes the MCPToolset.
@ -90,7 +90,7 @@ class MCPToolset(BaseToolset):
errlog=self.errlog,
)
self.session = None
self.tool_predicate = tool_predicate
self.tool_filter = tool_filter
async def _initialize(self) -> ClientSession:
"""Connects to the MCP Server and initializes the ClientSession."""
@ -106,7 +106,7 @@ class MCPToolset(BaseToolset):
@override
async def get_tools(
self,
readony_context: ReadonlyContext = None,
readonly_context: Optional[ReadonlyContext] = None,
) -> List[MCPTool]:
"""Loads all tools from the MCP Server.
@ -123,6 +123,5 @@ class MCPToolset(BaseToolset):
mcp_session_manager=self.session_manager,
)
for tool in tools_response.tools
if self.tool_predicate is None
or self.tool_predicate(tool, readony_context)
if self.tool_filter is None or self.tool_filter(tool, readonly_context)
]

View File

@ -20,18 +20,23 @@ from typing import Final
from typing import List
from typing import Literal
from typing import Optional
from typing import override
from typing import Union
import yaml
from ....agents.readonly_context import ReadonlyContext
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ...base_toolset import BaseToolset
from ...base_toolset import ToolPredicate
from .openapi_spec_parser import OpenApiSpecParser
from .rest_api_tool import RestApiTool
logger = logging.getLogger(__name__)
class OpenAPIToolset:
class OpenAPIToolset(BaseToolset):
"""Class for parsing OpenAPI spec into a list of RestApiTool.
Usage:
@ -61,6 +66,7 @@ class OpenAPIToolset:
spec_str_type: Literal["json", "yaml"] = "json",
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
):
"""Initializes the OpenAPIToolset.
@ -94,12 +100,15 @@ class OpenAPIToolset:
auth_credential: The auth credential to use for all tools. Use
AuthCredential or use helpers in
`google.adk.tools.openapi_tool.auth.auth_helpers`
tool_filter: The filter used to filter the tools in the toolset. It can be
either a tool predicate or a list of tool names of the tools to expose
"""
if not spec_dict:
spec_dict = self._load_spec(spec_str, spec_str_type)
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
if auth_scheme or auth_credential:
self._configure_auth_all(auth_scheme, auth_credential)
self.tool_filter = tool_filter
def _configure_auth_all(
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
@ -112,9 +121,21 @@ class OpenAPIToolset:
if auth_credential:
tool.configure_auth_credential(auth_credential)
def get_tools(self) -> List[RestApiTool]:
@override
async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
) -> List[RestApiTool]:
"""Get all tools in the toolset."""
return self.tools
return [
tool
for tool in self.tools
if self.tool_filter is None
or (
self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, ToolPredicate)
else tool.name in self.tool_filter
)
]
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
"""Get a tool by name."""
@ -142,3 +163,7 @@ class OpenAPIToolset:
logger.info("Parsed tool: %s", tool.name)
tools.append(tool)
return tools
@override
async def close(self):
pass

View File

@ -77,22 +77,26 @@ def mock_auth_credential():
# Test cases
def test_apihub_toolset_initialization(basic_apihub_toolset):
@pytest.mark.asyncio
async def test_apihub_toolset_initialization(basic_apihub_toolset):
assert basic_apihub_toolset.name == 'mock_api'
assert basic_apihub_toolset.description == 'Mock API Description'
assert basic_apihub_toolset.apihub_resource_name == 'test_resource'
assert not basic_apihub_toolset.lazy_load_spec
assert len(basic_apihub_toolset.generated_tools) == 1
assert 'test_get' in basic_apihub_toolset.generated_tools
generated_tools = await basic_apihub_toolset.get_tools()
assert len(generated_tools) == 1
assert 'test_get' == generated_tools[0].name
def test_apihub_toolset_lazy_loading(lazy_apihub_toolset):
@pytest.mark.asyncio
async def test_apihub_toolset_lazy_loading(lazy_apihub_toolset):
assert lazy_apihub_toolset.lazy_load_spec
assert not lazy_apihub_toolset.generated_tools
generated_tools = await lazy_apihub_toolset.get_tools()
assert generated_tools
tools = lazy_apihub_toolset.get_tools()
tools = await lazy_apihub_toolset.get_tools()
assert len(tools) == 1
assert lazy_apihub_toolset.get_tool('test_get') == tools[0]
'test_get' == tools[0].name
def test_apihub_toolset_no_title_in_spec(basic_apihub_toolset):
@ -155,7 +159,8 @@ paths:
assert toolset.description == ''
def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential):
@pytest.mark.asyncio
async def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential):
apihub_client = MockAPIHubClient()
tool = APIHubToolset(
apihub_resource_name='test_resource',
@ -163,11 +168,12 @@ def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential):
auth_scheme=mock_auth_scheme,
auth_credential=mock_auth_credential,
)
tools = tool.get_tools()
tools = await tool.get_tools()
assert len(tools) == 1
def test_apihub_toolset_get_tools_lazy_load_empty_spec():
@pytest.mark.asyncio
async def test_apihub_toolset_get_tools_lazy_load_empty_spec():
class MockAPIHubClientEmptySpec(BaseAPIHubClient):
@ -180,11 +186,12 @@ def test_apihub_toolset_get_tools_lazy_load_empty_spec():
apihub_client=apihub_client,
lazy_load_spec=True,
)
tools = tool.get_tools()
tools = await tool.get_tools()
assert not tools
def test_apihub_toolset_get_tools_invalid_yaml():
@pytest.mark.asyncio
async def test_apihub_toolset_get_tools_invalid_yaml():
class MockAPIHubClientInvalidYAML(BaseAPIHubClient):
@ -197,7 +204,7 @@ def test_apihub_toolset_get_tools_invalid_yaml():
apihub_resource_name='test_resource',
apihub_client=apihub_client,
)
tool.get_tools()
await tool.get_tools()
if __name__ == '__main__':