mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
adapt google api toolset and api hub toolset to new toolset interface
PiperOrigin-RevId: 757946541
This commit is contained in:
parent
27b229719e
commit
6a04ff84ba
@ -13,19 +13,25 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import override
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from ...agents.readonly_context import ReadonlyContext
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
from ..base_toolset import BaseToolset
|
||||
from ..base_toolset import ToolPredicate
|
||||
from ..openapi_tool.common.common import to_snake_case
|
||||
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from .clients.apihub_client import APIHubClient
|
||||
|
||||
|
||||
class APIHubToolset:
|
||||
class APIHubToolset(BaseToolset):
|
||||
"""APIHubTool generates tools from a given API Hub resource.
|
||||
|
||||
Examples:
|
||||
@ -34,16 +40,13 @@ class APIHubToolset:
|
||||
apihub_toolset = APIHubToolset(
|
||||
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
|
||||
service_account_json="...",
|
||||
tool_filter=lambda tool, ctx=None: tool.name in ('my_tool',
|
||||
'my_other_tool')
|
||||
)
|
||||
|
||||
# Get all available tools
|
||||
agent = LlmAgent(tools=apihub_toolset.get_tools())
|
||||
agent = LlmAgent(tools=apihub_toolset)
|
||||
|
||||
# Get a specific tool
|
||||
agent = LlmAgent(tools=[
|
||||
...
|
||||
apihub_toolset.get_tool('my_tool'),
|
||||
])
|
||||
```
|
||||
|
||||
**apihub_resource_name** is the resource name from API Hub. It must include
|
||||
@ -70,6 +73,7 @@ class APIHubToolset:
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
# Optionally, you can provide a custom API Hub client
|
||||
apihub_client: Optional[APIHubClient] = None,
|
||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||
):
|
||||
"""Initializes the APIHubTool with the given parameters.
|
||||
|
||||
@ -81,12 +85,17 @@ class APIHubToolset:
|
||||
)
|
||||
|
||||
# Get all available tools
|
||||
agent = LlmAgent(tools=apihub_toolset.get_tools())
|
||||
agent = LlmAgent(tools=[apihub_toolset])
|
||||
|
||||
apihub_toolset = APIHubToolset(
|
||||
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
|
||||
service_account_json="...",
|
||||
tool_filter = ['my_tool']
|
||||
)
|
||||
# Get a specific tool
|
||||
agent = LlmAgent(tools=[
|
||||
...
|
||||
apihub_toolset.get_tool('my_tool'),
|
||||
...,
|
||||
apihub_toolset,
|
||||
])
|
||||
```
|
||||
|
||||
@ -118,6 +127,9 @@ class APIHubToolset:
|
||||
lazy_load_spec: If True, the spec will be loaded lazily when needed.
|
||||
Otherwise, the spec will be loaded immediately and the tools will be
|
||||
generated during initialization.
|
||||
tool_filter: The filter used to filter the tools in the toolset. It can
|
||||
be either a tool predicate or a list of tool names of the tools to
|
||||
expose.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
@ -128,72 +140,36 @@ class APIHubToolset:
|
||||
service_account_json=service_account_json,
|
||||
)
|
||||
|
||||
self.generated_tools: Dict[str, RestApiTool] = {}
|
||||
self.openapi_toolset = None
|
||||
self.auth_scheme = auth_scheme
|
||||
self.auth_credential = auth_credential
|
||||
self.tool_filter = tool_filter
|
||||
|
||||
if not self.lazy_load_spec:
|
||||
self._prepare_tools()
|
||||
self._prepare_toolset()
|
||||
|
||||
def get_tool(self, name: str) -> Optional[RestApiTool]:
|
||||
"""Retrieves a specific tool by its name.
|
||||
|
||||
Example:
|
||||
```
|
||||
apihub_tool = apihub_toolset.get_tool('my_tool')
|
||||
```
|
||||
|
||||
Args:
|
||||
name: The name of the tool to retrieve.
|
||||
|
||||
Returns:
|
||||
The tool with the given name, or None if no such tool exists.
|
||||
"""
|
||||
if not self._are_tools_ready():
|
||||
self._prepare_tools()
|
||||
|
||||
return self.generated_tools[name] if name in self.generated_tools else None
|
||||
|
||||
def get_tools(self) -> List[RestApiTool]:
|
||||
@override
|
||||
async def get_tools(
|
||||
self, readonly_context: Optional[ReadonlyContext] = None
|
||||
) -> List[RestApiTool]:
|
||||
"""Retrieves all available tools.
|
||||
|
||||
Returns:
|
||||
A list of all available RestApiTool objects.
|
||||
"""
|
||||
if not self._are_tools_ready():
|
||||
self._prepare_tools()
|
||||
if not self.openapi_toolset:
|
||||
self._prepare_toolset()
|
||||
if not self.openapi_toolset:
|
||||
return []
|
||||
return await self.openapi_toolset.get_tools(readonly_context)
|
||||
|
||||
return list(self.generated_tools.values())
|
||||
|
||||
def _are_tools_ready(self) -> bool:
|
||||
return not self.lazy_load_spec or self.generated_tools
|
||||
|
||||
def _prepare_tools(self) -> str:
|
||||
"""Fetches the spec from API Hub and generates the tools.
|
||||
|
||||
Returns:
|
||||
True if the tools are ready, False otherwise.
|
||||
"""
|
||||
def _prepare_toolset(self) -> None:
|
||||
"""Fetches the spec from API Hub and generates the toolset."""
|
||||
# For each API, get the first version and the first spec of that version.
|
||||
spec = self.apihub_client.get_spec_content(self.apihub_resource_name)
|
||||
self.generated_tools: Dict[str, RestApiTool] = {}
|
||||
|
||||
tools = self._parse_spec_to_tools(spec)
|
||||
for tool in tools:
|
||||
self.generated_tools[tool.name] = tool
|
||||
|
||||
def _parse_spec_to_tools(self, spec_str: str) -> List[RestApiTool]:
|
||||
"""Parses the spec string to a list of RestApiTool.
|
||||
|
||||
Args:
|
||||
spec_str: The spec string to parse.
|
||||
|
||||
Returns:
|
||||
A list of RestApiTool objects.
|
||||
"""
|
||||
spec_str = self.apihub_client.get_spec_content(self.apihub_resource_name)
|
||||
spec_dict = yaml.safe_load(spec_str)
|
||||
if not spec_dict:
|
||||
return []
|
||||
return
|
||||
|
||||
self.name = self.name or to_snake_case(
|
||||
spec_dict.get('info', {}).get('title', 'unnamed')
|
||||
@ -201,9 +177,14 @@ class APIHubToolset:
|
||||
self.description = self.description or spec_dict.get('info', {}).get(
|
||||
'description', ''
|
||||
)
|
||||
tools = OpenAPIToolset(
|
||||
self.openapi_toolset = OpenAPIToolset(
|
||||
spec_dict=spec_dict,
|
||||
auth_credential=self.auth_credential,
|
||||
auth_scheme=self.auth_scheme,
|
||||
).get_tools()
|
||||
return tools
|
||||
tool_filter=self.tool_filter,
|
||||
)
|
||||
|
||||
@override
|
||||
async def close(self):
|
||||
if self.openapi_toolset:
|
||||
await self.openapi_toolset.close()
|
||||
|
@ -1,5 +1,6 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
|
||||
from google.adk.agents.readonly_context import ReadonlyContext
|
||||
@ -33,7 +34,7 @@ class BaseToolset(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def get_tools(
|
||||
self, readony_context: ReadonlyContext = None
|
||||
self, readonly_context: Optional[ReadonlyContext] = None
|
||||
) -> list[BaseTool]:
|
||||
"""Return all tools in the toolset based on the provided context.
|
||||
|
||||
|
@ -17,37 +17,67 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Final
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import override
|
||||
from typing import Type
|
||||
from typing import Union
|
||||
|
||||
from ...agents.readonly_context import ReadonlyContext
|
||||
from ...auth import OpenIdConnectWithConfig
|
||||
from ...tools.base_toolset import BaseToolset
|
||||
from ...tools.base_toolset import ToolPredicate
|
||||
from ..openapi_tool import OpenAPIToolset
|
||||
from ..openapi_tool import RestApiTool
|
||||
from .google_api_tool import GoogleApiTool
|
||||
from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
|
||||
|
||||
|
||||
class GoogleApiToolSet:
|
||||
"""Google API Tool Set."""
|
||||
class GoogleApiToolset(BaseToolset):
|
||||
"""Google API Toolset contains tools for interacting with Google APIs.
|
||||
|
||||
def __init__(self, tools: List[RestApiTool]):
|
||||
self.tools: Final[List[GoogleApiTool]] = [
|
||||
GoogleApiTool(tool) for tool in tools
|
||||
]
|
||||
Usually one toolsets will contains tools only replated to one Google API, e.g.
|
||||
Google Bigquery API toolset will contains tools only related to Google
|
||||
Bigquery API, like list dataset tool, list table tool etc.
|
||||
"""
|
||||
|
||||
def get_tools(self) -> List[GoogleApiTool]:
|
||||
def __init__(
|
||||
self,
|
||||
openapi_toolset: OpenAPIToolset,
|
||||
client_id: Optional[str] = None,
|
||||
client_secret: Optional[str] = None,
|
||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||
):
|
||||
self.openapi_toolset = openapi_toolset
|
||||
self.tool_filter = tool_filter
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
|
||||
@override
|
||||
async def get_tools(
|
||||
self, readonly_context: Optional[ReadonlyContext] = None
|
||||
) -> List[GoogleApiTool]:
|
||||
"""Get all tools in the toolset."""
|
||||
return self.tools
|
||||
tools = []
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[GoogleApiTool]:
|
||||
"""Get a tool by name."""
|
||||
matching_tool = filter(lambda t: t.name == tool_name, self.tools)
|
||||
return next(matching_tool, None)
|
||||
for tool in await self.openapi_toolset.get_tools(readonly_context):
|
||||
if self.tool_filter and (
|
||||
isinstance(self.tool_filter, ToolPredicate)
|
||||
and not self.tool_filter(tool, readonly_context)
|
||||
or isinstance(self.tool_filter, list)
|
||||
and tool.name not in self.tool_filter
|
||||
):
|
||||
continue
|
||||
google_api_tool = GoogleApiTool(tool)
|
||||
google_api_tool.configure_auth(self.client_id, self.client_secret)
|
||||
tools.append(google_api_tool)
|
||||
|
||||
return tools
|
||||
|
||||
def set_tool_filter(self, tool_filter: Union[ToolPredicate, List[str]]):
|
||||
self.tool_filter = tool_filter
|
||||
|
||||
@staticmethod
|
||||
def _load_tool_set_with_oidc_auth(
|
||||
def _load_toolset_with_oidc_auth(
|
||||
spec_file: Optional[str] = None,
|
||||
spec_dict: Optional[dict[str, Any]] = None,
|
||||
scopes: Optional[list[str]] = None,
|
||||
@ -64,7 +94,7 @@ class GoogleApiToolSet:
|
||||
yaml_path = os.path.join(caller_dir, spec_file)
|
||||
with open(yaml_path, 'r', encoding='utf-8') as file:
|
||||
spec_str = file.read()
|
||||
tool_set = OpenAPIToolset(
|
||||
toolset = OpenAPIToolset(
|
||||
spec_dict=spec_dict,
|
||||
spec_str=spec_str,
|
||||
spec_str_type='yaml',
|
||||
@ -85,18 +115,18 @@ class GoogleApiToolSet:
|
||||
scopes=scopes,
|
||||
),
|
||||
)
|
||||
return tool_set
|
||||
return toolset
|
||||
|
||||
def configure_auth(self, client_id: str, client_secret: str):
|
||||
for tool in self.tools:
|
||||
tool.configure_auth(client_id, client_secret)
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
|
||||
@classmethod
|
||||
def load_tool_set(
|
||||
cls: Type[GoogleApiToolSet],
|
||||
def load_toolset(
|
||||
cls: Type[GoogleApiToolset],
|
||||
api_name: str,
|
||||
api_version: str,
|
||||
) -> GoogleApiToolSet:
|
||||
) -> GoogleApiToolset:
|
||||
spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
|
||||
scope = list(
|
||||
spec_dict['components']['securitySchemes']['oauth2']['flows'][
|
||||
@ -104,7 +134,10 @@ class GoogleApiToolSet:
|
||||
]['scopes'].keys()
|
||||
)[0]
|
||||
return cls(
|
||||
cls._load_tool_set_with_oidc_auth(
|
||||
spec_dict=spec_dict, scopes=[scope]
|
||||
).get_tools()
|
||||
cls._load_toolset_with_oidc_auth(spec_dict=spec_dict, scopes=[scope])
|
||||
)
|
||||
|
||||
@override
|
||||
async def close(self):
|
||||
if self.openapi_toolset:
|
||||
await self.openapi_toolset.close()
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
import logging
|
||||
|
||||
from .google_api_tool_set import GoogleApiToolSet
|
||||
from .google_api_tool_set import GoogleApiToolset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -50,7 +50,7 @@ def __getattr__(name):
|
||||
match name:
|
||||
case "bigquery_tool_set":
|
||||
if _bigquery_tool_set is None:
|
||||
_bigquery_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
_bigquery_tool_set = GoogleApiToolset.load_toolset(
|
||||
api_name="bigquery",
|
||||
api_version="v2",
|
||||
)
|
||||
@ -59,7 +59,7 @@ def __getattr__(name):
|
||||
|
||||
case "calendar_tool_set":
|
||||
if _calendar_tool_set is None:
|
||||
_calendar_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
_calendar_tool_set = GoogleApiToolset.load_toolset(
|
||||
api_name="calendar",
|
||||
api_version="v3",
|
||||
)
|
||||
@ -68,7 +68,7 @@ def __getattr__(name):
|
||||
|
||||
case "gmail_tool_set":
|
||||
if _gmail_tool_set is None:
|
||||
_gmail_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
_gmail_tool_set = GoogleApiToolset.load_toolset(
|
||||
api_name="gmail",
|
||||
api_version="v1",
|
||||
)
|
||||
@ -77,7 +77,7 @@ def __getattr__(name):
|
||||
|
||||
case "youtube_tool_set":
|
||||
if _youtube_tool_set is None:
|
||||
_youtube_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
_youtube_tool_set = GoogleApiToolset.load_toolset(
|
||||
api_name="youtube",
|
||||
api_version="v3",
|
||||
)
|
||||
@ -86,7 +86,7 @@ def __getattr__(name):
|
||||
|
||||
case "slides_tool_set":
|
||||
if _slides_tool_set is None:
|
||||
_slides_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
_slides_tool_set = GoogleApiToolset.load_toolset(
|
||||
api_name="slides",
|
||||
api_version="v1",
|
||||
)
|
||||
@ -95,7 +95,7 @@ def __getattr__(name):
|
||||
|
||||
case "sheets_tool_set":
|
||||
if _sheets_tool_set is None:
|
||||
_sheets_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
_sheets_tool_set = GoogleApiToolset.load_toolset(
|
||||
api_name="sheets",
|
||||
api_version="v4",
|
||||
)
|
||||
@ -104,7 +104,7 @@ def __getattr__(name):
|
||||
|
||||
case "docs_tool_set":
|
||||
if _docs_tool_set is None:
|
||||
_docs_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
_docs_tool_set = GoogleApiToolset.load_toolset(
|
||||
api_name="docs",
|
||||
api_version="v1",
|
||||
)
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
from contextlib import AsyncExitStack
|
||||
import sys
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
from typing import Optional
|
||||
from typing import TextIO
|
||||
|
||||
@ -68,7 +68,7 @@ class MCPToolset(BaseToolset):
|
||||
*,
|
||||
connection_params: StdioServerParameters | SseServerParams,
|
||||
errlog: TextIO = sys.stderr,
|
||||
tool_predicate: Optional[ToolPredicate] = None,
|
||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||
):
|
||||
"""Initializes the MCPToolset.
|
||||
|
||||
@ -90,7 +90,7 @@ class MCPToolset(BaseToolset):
|
||||
errlog=self.errlog,
|
||||
)
|
||||
self.session = None
|
||||
self.tool_predicate = tool_predicate
|
||||
self.tool_filter = tool_filter
|
||||
|
||||
async def _initialize(self) -> ClientSession:
|
||||
"""Connects to the MCP Server and initializes the ClientSession."""
|
||||
@ -106,7 +106,7 @@ class MCPToolset(BaseToolset):
|
||||
@override
|
||||
async def get_tools(
|
||||
self,
|
||||
readony_context: ReadonlyContext = None,
|
||||
readonly_context: Optional[ReadonlyContext] = None,
|
||||
) -> List[MCPTool]:
|
||||
"""Loads all tools from the MCP Server.
|
||||
|
||||
@ -123,6 +123,5 @@ class MCPToolset(BaseToolset):
|
||||
mcp_session_manager=self.session_manager,
|
||||
)
|
||||
for tool in tools_response.tools
|
||||
if self.tool_predicate is None
|
||||
or self.tool_predicate(tool, readony_context)
|
||||
if self.tool_filter is None or self.tool_filter(tool, readonly_context)
|
||||
]
|
||||
|
@ -20,18 +20,23 @@ from typing import Final
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import override
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from ....agents.readonly_context import ReadonlyContext
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ...base_toolset import BaseToolset
|
||||
from ...base_toolset import ToolPredicate
|
||||
from .openapi_spec_parser import OpenApiSpecParser
|
||||
from .rest_api_tool import RestApiTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAPIToolset:
|
||||
class OpenAPIToolset(BaseToolset):
|
||||
"""Class for parsing OpenAPI spec into a list of RestApiTool.
|
||||
|
||||
Usage:
|
||||
@ -61,6 +66,7 @@ class OpenAPIToolset:
|
||||
spec_str_type: Literal["json", "yaml"] = "json",
|
||||
auth_scheme: Optional[AuthScheme] = None,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||
):
|
||||
"""Initializes the OpenAPIToolset.
|
||||
|
||||
@ -94,12 +100,15 @@ class OpenAPIToolset:
|
||||
auth_credential: The auth credential to use for all tools. Use
|
||||
AuthCredential or use helpers in
|
||||
`google.adk.tools.openapi_tool.auth.auth_helpers`
|
||||
tool_filter: The filter used to filter the tools in the toolset. It can be
|
||||
either a tool predicate or a list of tool names of the tools to expose
|
||||
"""
|
||||
if not spec_dict:
|
||||
spec_dict = self._load_spec(spec_str, spec_str_type)
|
||||
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
|
||||
if auth_scheme or auth_credential:
|
||||
self._configure_auth_all(auth_scheme, auth_credential)
|
||||
self.tool_filter = tool_filter
|
||||
|
||||
def _configure_auth_all(
|
||||
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
|
||||
@ -112,9 +121,21 @@ class OpenAPIToolset:
|
||||
if auth_credential:
|
||||
tool.configure_auth_credential(auth_credential)
|
||||
|
||||
def get_tools(self) -> List[RestApiTool]:
|
||||
@override
|
||||
async def get_tools(
|
||||
self, readonly_context: Optional[ReadonlyContext] = None
|
||||
) -> List[RestApiTool]:
|
||||
"""Get all tools in the toolset."""
|
||||
return self.tools
|
||||
return [
|
||||
tool
|
||||
for tool in self.tools
|
||||
if self.tool_filter is None
|
||||
or (
|
||||
self.tool_filter(tool, readonly_context)
|
||||
if isinstance(self.tool_filter, ToolPredicate)
|
||||
else tool.name in self.tool_filter
|
||||
)
|
||||
]
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
|
||||
"""Get a tool by name."""
|
||||
@ -142,3 +163,7 @@ class OpenAPIToolset:
|
||||
logger.info("Parsed tool: %s", tool.name)
|
||||
tools.append(tool)
|
||||
return tools
|
||||
|
||||
@override
|
||||
async def close(self):
|
||||
pass
|
||||
|
@ -77,22 +77,26 @@ def mock_auth_credential():
|
||||
|
||||
|
||||
# Test cases
|
||||
def test_apihub_toolset_initialization(basic_apihub_toolset):
|
||||
@pytest.mark.asyncio
|
||||
async def test_apihub_toolset_initialization(basic_apihub_toolset):
|
||||
assert basic_apihub_toolset.name == 'mock_api'
|
||||
assert basic_apihub_toolset.description == 'Mock API Description'
|
||||
assert basic_apihub_toolset.apihub_resource_name == 'test_resource'
|
||||
assert not basic_apihub_toolset.lazy_load_spec
|
||||
assert len(basic_apihub_toolset.generated_tools) == 1
|
||||
assert 'test_get' in basic_apihub_toolset.generated_tools
|
||||
generated_tools = await basic_apihub_toolset.get_tools()
|
||||
assert len(generated_tools) == 1
|
||||
assert 'test_get' == generated_tools[0].name
|
||||
|
||||
|
||||
def test_apihub_toolset_lazy_loading(lazy_apihub_toolset):
|
||||
@pytest.mark.asyncio
|
||||
async def test_apihub_toolset_lazy_loading(lazy_apihub_toolset):
|
||||
assert lazy_apihub_toolset.lazy_load_spec
|
||||
assert not lazy_apihub_toolset.generated_tools
|
||||
generated_tools = await lazy_apihub_toolset.get_tools()
|
||||
assert generated_tools
|
||||
|
||||
tools = lazy_apihub_toolset.get_tools()
|
||||
tools = await lazy_apihub_toolset.get_tools()
|
||||
assert len(tools) == 1
|
||||
assert lazy_apihub_toolset.get_tool('test_get') == tools[0]
|
||||
'test_get' == tools[0].name
|
||||
|
||||
|
||||
def test_apihub_toolset_no_title_in_spec(basic_apihub_toolset):
|
||||
@ -155,7 +159,8 @@ paths:
|
||||
assert toolset.description == ''
|
||||
|
||||
|
||||
def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential):
|
||||
apihub_client = MockAPIHubClient()
|
||||
tool = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
@ -163,11 +168,12 @@ def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential):
|
||||
auth_scheme=mock_auth_scheme,
|
||||
auth_credential=mock_auth_credential,
|
||||
)
|
||||
tools = tool.get_tools()
|
||||
tools = await tool.get_tools()
|
||||
assert len(tools) == 1
|
||||
|
||||
|
||||
def test_apihub_toolset_get_tools_lazy_load_empty_spec():
|
||||
@pytest.mark.asyncio
|
||||
async def test_apihub_toolset_get_tools_lazy_load_empty_spec():
|
||||
|
||||
class MockAPIHubClientEmptySpec(BaseAPIHubClient):
|
||||
|
||||
@ -180,11 +186,12 @@ def test_apihub_toolset_get_tools_lazy_load_empty_spec():
|
||||
apihub_client=apihub_client,
|
||||
lazy_load_spec=True,
|
||||
)
|
||||
tools = tool.get_tools()
|
||||
tools = await tool.get_tools()
|
||||
assert not tools
|
||||
|
||||
|
||||
def test_apihub_toolset_get_tools_invalid_yaml():
|
||||
@pytest.mark.asyncio
|
||||
async def test_apihub_toolset_get_tools_invalid_yaml():
|
||||
|
||||
class MockAPIHubClientInvalidYAML(BaseAPIHubClient):
|
||||
|
||||
@ -197,7 +204,7 @@ def test_apihub_toolset_get_tools_invalid_yaml():
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
)
|
||||
tool.get_tools()
|
||||
await tool.get_tools()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user