mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 09:51:25 -06:00
adapt google api toolset and api hub toolset to new toolset interface
PiperOrigin-RevId: 757946541
This commit is contained in:
parent
27b229719e
commit
6a04ff84ba
@ -13,19 +13,25 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import override
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from ...agents.readonly_context import ReadonlyContext
|
||||||
from ...auth.auth_credential import AuthCredential
|
from ...auth.auth_credential import AuthCredential
|
||||||
from ...auth.auth_schemes import AuthScheme
|
from ...auth.auth_schemes import AuthScheme
|
||||||
|
from ..base_toolset import BaseToolset
|
||||||
|
from ..base_toolset import ToolPredicate
|
||||||
from ..openapi_tool.common.common import to_snake_case
|
from ..openapi_tool.common.common import to_snake_case
|
||||||
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||||
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||||
from .clients.apihub_client import APIHubClient
|
from .clients.apihub_client import APIHubClient
|
||||||
|
|
||||||
|
|
||||||
class APIHubToolset:
|
class APIHubToolset(BaseToolset):
|
||||||
"""APIHubTool generates tools from a given API Hub resource.
|
"""APIHubTool generates tools from a given API Hub resource.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
@ -34,16 +40,13 @@ class APIHubToolset:
|
|||||||
apihub_toolset = APIHubToolset(
|
apihub_toolset = APIHubToolset(
|
||||||
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
|
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
|
||||||
service_account_json="...",
|
service_account_json="...",
|
||||||
|
tool_filter=lambda tool, ctx=None: tool.name in ('my_tool',
|
||||||
|
'my_other_tool')
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all available tools
|
# Get all available tools
|
||||||
agent = LlmAgent(tools=apihub_toolset.get_tools())
|
agent = LlmAgent(tools=apihub_toolset)
|
||||||
|
|
||||||
# Get a specific tool
|
|
||||||
agent = LlmAgent(tools=[
|
|
||||||
...
|
|
||||||
apihub_toolset.get_tool('my_tool'),
|
|
||||||
])
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**apihub_resource_name** is the resource name from API Hub. It must include
|
**apihub_resource_name** is the resource name from API Hub. It must include
|
||||||
@ -70,6 +73,7 @@ class APIHubToolset:
|
|||||||
auth_credential: Optional[AuthCredential] = None,
|
auth_credential: Optional[AuthCredential] = None,
|
||||||
# Optionally, you can provide a custom API Hub client
|
# Optionally, you can provide a custom API Hub client
|
||||||
apihub_client: Optional[APIHubClient] = None,
|
apihub_client: Optional[APIHubClient] = None,
|
||||||
|
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||||
):
|
):
|
||||||
"""Initializes the APIHubTool with the given parameters.
|
"""Initializes the APIHubTool with the given parameters.
|
||||||
|
|
||||||
@ -81,12 +85,17 @@ class APIHubToolset:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get all available tools
|
# Get all available tools
|
||||||
agent = LlmAgent(tools=apihub_toolset.get_tools())
|
agent = LlmAgent(tools=[apihub_toolset])
|
||||||
|
|
||||||
|
apihub_toolset = APIHubToolset(
|
||||||
|
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
|
||||||
|
service_account_json="...",
|
||||||
|
tool_filter = ['my_tool']
|
||||||
|
)
|
||||||
# Get a specific tool
|
# Get a specific tool
|
||||||
agent = LlmAgent(tools=[
|
agent = LlmAgent(tools=[
|
||||||
...
|
...,
|
||||||
apihub_toolset.get_tool('my_tool'),
|
apihub_toolset,
|
||||||
])
|
])
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -118,6 +127,9 @@ class APIHubToolset:
|
|||||||
lazy_load_spec: If True, the spec will be loaded lazily when needed.
|
lazy_load_spec: If True, the spec will be loaded lazily when needed.
|
||||||
Otherwise, the spec will be loaded immediately and the tools will be
|
Otherwise, the spec will be loaded immediately and the tools will be
|
||||||
generated during initialization.
|
generated during initialization.
|
||||||
|
tool_filter: The filter used to filter the tools in the toolset. It can
|
||||||
|
be either a tool predicate or a list of tool names of the tools to
|
||||||
|
expose.
|
||||||
"""
|
"""
|
||||||
self.name = name
|
self.name = name
|
||||||
self.description = description
|
self.description = description
|
||||||
@ -128,72 +140,36 @@ class APIHubToolset:
|
|||||||
service_account_json=service_account_json,
|
service_account_json=service_account_json,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.generated_tools: Dict[str, RestApiTool] = {}
|
self.openapi_toolset = None
|
||||||
self.auth_scheme = auth_scheme
|
self.auth_scheme = auth_scheme
|
||||||
self.auth_credential = auth_credential
|
self.auth_credential = auth_credential
|
||||||
|
self.tool_filter = tool_filter
|
||||||
|
|
||||||
if not self.lazy_load_spec:
|
if not self.lazy_load_spec:
|
||||||
self._prepare_tools()
|
self._prepare_toolset()
|
||||||
|
|
||||||
def get_tool(self, name: str) -> Optional[RestApiTool]:
|
@override
|
||||||
"""Retrieves a specific tool by its name.
|
async def get_tools(
|
||||||
|
self, readonly_context: Optional[ReadonlyContext] = None
|
||||||
Example:
|
) -> List[RestApiTool]:
|
||||||
```
|
|
||||||
apihub_tool = apihub_toolset.get_tool('my_tool')
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The name of the tool to retrieve.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The tool with the given name, or None if no such tool exists.
|
|
||||||
"""
|
|
||||||
if not self._are_tools_ready():
|
|
||||||
self._prepare_tools()
|
|
||||||
|
|
||||||
return self.generated_tools[name] if name in self.generated_tools else None
|
|
||||||
|
|
||||||
def get_tools(self) -> List[RestApiTool]:
|
|
||||||
"""Retrieves all available tools.
|
"""Retrieves all available tools.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of all available RestApiTool objects.
|
A list of all available RestApiTool objects.
|
||||||
"""
|
"""
|
||||||
if not self._are_tools_ready():
|
if not self.openapi_toolset:
|
||||||
self._prepare_tools()
|
self._prepare_toolset()
|
||||||
|
if not self.openapi_toolset:
|
||||||
|
return []
|
||||||
|
return await self.openapi_toolset.get_tools(readonly_context)
|
||||||
|
|
||||||
return list(self.generated_tools.values())
|
def _prepare_toolset(self) -> None:
|
||||||
|
"""Fetches the spec from API Hub and generates the toolset."""
|
||||||
def _are_tools_ready(self) -> bool:
|
|
||||||
return not self.lazy_load_spec or self.generated_tools
|
|
||||||
|
|
||||||
def _prepare_tools(self) -> str:
|
|
||||||
"""Fetches the spec from API Hub and generates the tools.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the tools are ready, False otherwise.
|
|
||||||
"""
|
|
||||||
# For each API, get the first version and the first spec of that version.
|
# For each API, get the first version and the first spec of that version.
|
||||||
spec = self.apihub_client.get_spec_content(self.apihub_resource_name)
|
spec_str = self.apihub_client.get_spec_content(self.apihub_resource_name)
|
||||||
self.generated_tools: Dict[str, RestApiTool] = {}
|
|
||||||
|
|
||||||
tools = self._parse_spec_to_tools(spec)
|
|
||||||
for tool in tools:
|
|
||||||
self.generated_tools[tool.name] = tool
|
|
||||||
|
|
||||||
def _parse_spec_to_tools(self, spec_str: str) -> List[RestApiTool]:
|
|
||||||
"""Parses the spec string to a list of RestApiTool.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
spec_str: The spec string to parse.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of RestApiTool objects.
|
|
||||||
"""
|
|
||||||
spec_dict = yaml.safe_load(spec_str)
|
spec_dict = yaml.safe_load(spec_str)
|
||||||
if not spec_dict:
|
if not spec_dict:
|
||||||
return []
|
return
|
||||||
|
|
||||||
self.name = self.name or to_snake_case(
|
self.name = self.name or to_snake_case(
|
||||||
spec_dict.get('info', {}).get('title', 'unnamed')
|
spec_dict.get('info', {}).get('title', 'unnamed')
|
||||||
@ -201,9 +177,14 @@ class APIHubToolset:
|
|||||||
self.description = self.description or spec_dict.get('info', {}).get(
|
self.description = self.description or spec_dict.get('info', {}).get(
|
||||||
'description', ''
|
'description', ''
|
||||||
)
|
)
|
||||||
tools = OpenAPIToolset(
|
self.openapi_toolset = OpenAPIToolset(
|
||||||
spec_dict=spec_dict,
|
spec_dict=spec_dict,
|
||||||
auth_credential=self.auth_credential,
|
auth_credential=self.auth_credential,
|
||||||
auth_scheme=self.auth_scheme,
|
auth_scheme=self.auth_scheme,
|
||||||
).get_tools()
|
tool_filter=self.tool_filter,
|
||||||
return tools
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def close(self):
|
||||||
|
if self.openapi_toolset:
|
||||||
|
await self.openapi_toolset.close()
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from typing import Optional
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
from google.adk.agents.readonly_context import ReadonlyContext
|
from google.adk.agents.readonly_context import ReadonlyContext
|
||||||
@ -33,7 +34,7 @@ class BaseToolset(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_tools(
|
async def get_tools(
|
||||||
self, readony_context: ReadonlyContext = None
|
self, readonly_context: Optional[ReadonlyContext] = None
|
||||||
) -> list[BaseTool]:
|
) -> list[BaseTool]:
|
||||||
"""Return all tools in the toolset based on the provided context.
|
"""Return all tools in the toolset based on the provided context.
|
||||||
|
|
||||||
|
@ -17,37 +17,67 @@ from __future__ import annotations
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Final
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from typing import override
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from ...agents.readonly_context import ReadonlyContext
|
||||||
from ...auth import OpenIdConnectWithConfig
|
from ...auth import OpenIdConnectWithConfig
|
||||||
|
from ...tools.base_toolset import BaseToolset
|
||||||
|
from ...tools.base_toolset import ToolPredicate
|
||||||
from ..openapi_tool import OpenAPIToolset
|
from ..openapi_tool import OpenAPIToolset
|
||||||
from ..openapi_tool import RestApiTool
|
|
||||||
from .google_api_tool import GoogleApiTool
|
from .google_api_tool import GoogleApiTool
|
||||||
from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
|
from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
|
||||||
|
|
||||||
|
|
||||||
class GoogleApiToolSet:
|
class GoogleApiToolset(BaseToolset):
|
||||||
"""Google API Tool Set."""
|
"""Google API Toolset contains tools for interacting with Google APIs.
|
||||||
|
|
||||||
def __init__(self, tools: List[RestApiTool]):
|
Usually one toolsets will contains tools only replated to one Google API, e.g.
|
||||||
self.tools: Final[List[GoogleApiTool]] = [
|
Google Bigquery API toolset will contains tools only related to Google
|
||||||
GoogleApiTool(tool) for tool in tools
|
Bigquery API, like list dataset tool, list table tool etc.
|
||||||
]
|
"""
|
||||||
|
|
||||||
def get_tools(self) -> List[GoogleApiTool]:
|
def __init__(
|
||||||
|
self,
|
||||||
|
openapi_toolset: OpenAPIToolset,
|
||||||
|
client_id: Optional[str] = None,
|
||||||
|
client_secret: Optional[str] = None,
|
||||||
|
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||||
|
):
|
||||||
|
self.openapi_toolset = openapi_toolset
|
||||||
|
self.tool_filter = tool_filter
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_secret = client_secret
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def get_tools(
|
||||||
|
self, readonly_context: Optional[ReadonlyContext] = None
|
||||||
|
) -> List[GoogleApiTool]:
|
||||||
"""Get all tools in the toolset."""
|
"""Get all tools in the toolset."""
|
||||||
return self.tools
|
tools = []
|
||||||
|
|
||||||
def get_tool(self, tool_name: str) -> Optional[GoogleApiTool]:
|
for tool in await self.openapi_toolset.get_tools(readonly_context):
|
||||||
"""Get a tool by name."""
|
if self.tool_filter and (
|
||||||
matching_tool = filter(lambda t: t.name == tool_name, self.tools)
|
isinstance(self.tool_filter, ToolPredicate)
|
||||||
return next(matching_tool, None)
|
and not self.tool_filter(tool, readonly_context)
|
||||||
|
or isinstance(self.tool_filter, list)
|
||||||
|
and tool.name not in self.tool_filter
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
google_api_tool = GoogleApiTool(tool)
|
||||||
|
google_api_tool.configure_auth(self.client_id, self.client_secret)
|
||||||
|
tools.append(google_api_tool)
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
def set_tool_filter(self, tool_filter: Union[ToolPredicate, List[str]]):
|
||||||
|
self.tool_filter = tool_filter
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_tool_set_with_oidc_auth(
|
def _load_toolset_with_oidc_auth(
|
||||||
spec_file: Optional[str] = None,
|
spec_file: Optional[str] = None,
|
||||||
spec_dict: Optional[dict[str, Any]] = None,
|
spec_dict: Optional[dict[str, Any]] = None,
|
||||||
scopes: Optional[list[str]] = None,
|
scopes: Optional[list[str]] = None,
|
||||||
@ -64,7 +94,7 @@ class GoogleApiToolSet:
|
|||||||
yaml_path = os.path.join(caller_dir, spec_file)
|
yaml_path = os.path.join(caller_dir, spec_file)
|
||||||
with open(yaml_path, 'r', encoding='utf-8') as file:
|
with open(yaml_path, 'r', encoding='utf-8') as file:
|
||||||
spec_str = file.read()
|
spec_str = file.read()
|
||||||
tool_set = OpenAPIToolset(
|
toolset = OpenAPIToolset(
|
||||||
spec_dict=spec_dict,
|
spec_dict=spec_dict,
|
||||||
spec_str=spec_str,
|
spec_str=spec_str,
|
||||||
spec_str_type='yaml',
|
spec_str_type='yaml',
|
||||||
@ -85,18 +115,18 @@ class GoogleApiToolSet:
|
|||||||
scopes=scopes,
|
scopes=scopes,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return tool_set
|
return toolset
|
||||||
|
|
||||||
def configure_auth(self, client_id: str, client_secret: str):
|
def configure_auth(self, client_id: str, client_secret: str):
|
||||||
for tool in self.tools:
|
self.client_id = client_id
|
||||||
tool.configure_auth(client_id, client_secret)
|
self.client_secret = client_secret
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_tool_set(
|
def load_toolset(
|
||||||
cls: Type[GoogleApiToolSet],
|
cls: Type[GoogleApiToolset],
|
||||||
api_name: str,
|
api_name: str,
|
||||||
api_version: str,
|
api_version: str,
|
||||||
) -> GoogleApiToolSet:
|
) -> GoogleApiToolset:
|
||||||
spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
|
spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
|
||||||
scope = list(
|
scope = list(
|
||||||
spec_dict['components']['securitySchemes']['oauth2']['flows'][
|
spec_dict['components']['securitySchemes']['oauth2']['flows'][
|
||||||
@ -104,7 +134,10 @@ class GoogleApiToolSet:
|
|||||||
]['scopes'].keys()
|
]['scopes'].keys()
|
||||||
)[0]
|
)[0]
|
||||||
return cls(
|
return cls(
|
||||||
cls._load_tool_set_with_oidc_auth(
|
cls._load_toolset_with_oidc_auth(spec_dict=spec_dict, scopes=[scope])
|
||||||
spec_dict=spec_dict, scopes=[scope]
|
|
||||||
).get_tools()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def close(self):
|
||||||
|
if self.openapi_toolset:
|
||||||
|
await self.openapi_toolset.close()
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .google_api_tool_set import GoogleApiToolSet
|
from .google_api_tool_set import GoogleApiToolset
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ def __getattr__(name):
|
|||||||
match name:
|
match name:
|
||||||
case "bigquery_tool_set":
|
case "bigquery_tool_set":
|
||||||
if _bigquery_tool_set is None:
|
if _bigquery_tool_set is None:
|
||||||
_bigquery_tool_set = GoogleApiToolSet.load_tool_set(
|
_bigquery_tool_set = GoogleApiToolset.load_toolset(
|
||||||
api_name="bigquery",
|
api_name="bigquery",
|
||||||
api_version="v2",
|
api_version="v2",
|
||||||
)
|
)
|
||||||
@ -59,7 +59,7 @@ def __getattr__(name):
|
|||||||
|
|
||||||
case "calendar_tool_set":
|
case "calendar_tool_set":
|
||||||
if _calendar_tool_set is None:
|
if _calendar_tool_set is None:
|
||||||
_calendar_tool_set = GoogleApiToolSet.load_tool_set(
|
_calendar_tool_set = GoogleApiToolset.load_toolset(
|
||||||
api_name="calendar",
|
api_name="calendar",
|
||||||
api_version="v3",
|
api_version="v3",
|
||||||
)
|
)
|
||||||
@ -68,7 +68,7 @@ def __getattr__(name):
|
|||||||
|
|
||||||
case "gmail_tool_set":
|
case "gmail_tool_set":
|
||||||
if _gmail_tool_set is None:
|
if _gmail_tool_set is None:
|
||||||
_gmail_tool_set = GoogleApiToolSet.load_tool_set(
|
_gmail_tool_set = GoogleApiToolset.load_toolset(
|
||||||
api_name="gmail",
|
api_name="gmail",
|
||||||
api_version="v1",
|
api_version="v1",
|
||||||
)
|
)
|
||||||
@ -77,7 +77,7 @@ def __getattr__(name):
|
|||||||
|
|
||||||
case "youtube_tool_set":
|
case "youtube_tool_set":
|
||||||
if _youtube_tool_set is None:
|
if _youtube_tool_set is None:
|
||||||
_youtube_tool_set = GoogleApiToolSet.load_tool_set(
|
_youtube_tool_set = GoogleApiToolset.load_toolset(
|
||||||
api_name="youtube",
|
api_name="youtube",
|
||||||
api_version="v3",
|
api_version="v3",
|
||||||
)
|
)
|
||||||
@ -86,7 +86,7 @@ def __getattr__(name):
|
|||||||
|
|
||||||
case "slides_tool_set":
|
case "slides_tool_set":
|
||||||
if _slides_tool_set is None:
|
if _slides_tool_set is None:
|
||||||
_slides_tool_set = GoogleApiToolSet.load_tool_set(
|
_slides_tool_set = GoogleApiToolset.load_toolset(
|
||||||
api_name="slides",
|
api_name="slides",
|
||||||
api_version="v1",
|
api_version="v1",
|
||||||
)
|
)
|
||||||
@ -95,7 +95,7 @@ def __getattr__(name):
|
|||||||
|
|
||||||
case "sheets_tool_set":
|
case "sheets_tool_set":
|
||||||
if _sheets_tool_set is None:
|
if _sheets_tool_set is None:
|
||||||
_sheets_tool_set = GoogleApiToolSet.load_tool_set(
|
_sheets_tool_set = GoogleApiToolset.load_toolset(
|
||||||
api_name="sheets",
|
api_name="sheets",
|
||||||
api_version="v4",
|
api_version="v4",
|
||||||
)
|
)
|
||||||
@ -104,7 +104,7 @@ def __getattr__(name):
|
|||||||
|
|
||||||
case "docs_tool_set":
|
case "docs_tool_set":
|
||||||
if _docs_tool_set is None:
|
if _docs_tool_set is None:
|
||||||
_docs_tool_set = GoogleApiToolSet.load_tool_set(
|
_docs_tool_set = GoogleApiToolset.load_toolset(
|
||||||
api_name="docs",
|
api_name="docs",
|
||||||
api_version="v1",
|
api_version="v1",
|
||||||
)
|
)
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
import sys
|
import sys
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import TextIO
|
from typing import TextIO
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ class MCPToolset(BaseToolset):
|
|||||||
*,
|
*,
|
||||||
connection_params: StdioServerParameters | SseServerParams,
|
connection_params: StdioServerParameters | SseServerParams,
|
||||||
errlog: TextIO = sys.stderr,
|
errlog: TextIO = sys.stderr,
|
||||||
tool_predicate: Optional[ToolPredicate] = None,
|
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||||
):
|
):
|
||||||
"""Initializes the MCPToolset.
|
"""Initializes the MCPToolset.
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ class MCPToolset(BaseToolset):
|
|||||||
errlog=self.errlog,
|
errlog=self.errlog,
|
||||||
)
|
)
|
||||||
self.session = None
|
self.session = None
|
||||||
self.tool_predicate = tool_predicate
|
self.tool_filter = tool_filter
|
||||||
|
|
||||||
async def _initialize(self) -> ClientSession:
|
async def _initialize(self) -> ClientSession:
|
||||||
"""Connects to the MCP Server and initializes the ClientSession."""
|
"""Connects to the MCP Server and initializes the ClientSession."""
|
||||||
@ -106,7 +106,7 @@ class MCPToolset(BaseToolset):
|
|||||||
@override
|
@override
|
||||||
async def get_tools(
|
async def get_tools(
|
||||||
self,
|
self,
|
||||||
readony_context: ReadonlyContext = None,
|
readonly_context: Optional[ReadonlyContext] = None,
|
||||||
) -> List[MCPTool]:
|
) -> List[MCPTool]:
|
||||||
"""Loads all tools from the MCP Server.
|
"""Loads all tools from the MCP Server.
|
||||||
|
|
||||||
@ -123,6 +123,5 @@ class MCPToolset(BaseToolset):
|
|||||||
mcp_session_manager=self.session_manager,
|
mcp_session_manager=self.session_manager,
|
||||||
)
|
)
|
||||||
for tool in tools_response.tools
|
for tool in tools_response.tools
|
||||||
if self.tool_predicate is None
|
if self.tool_filter is None or self.tool_filter(tool, readonly_context)
|
||||||
or self.tool_predicate(tool, readony_context)
|
|
||||||
]
|
]
|
||||||
|
@ -20,18 +20,23 @@ from typing import Final
|
|||||||
from typing import List
|
from typing import List
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from typing import override
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from ....agents.readonly_context import ReadonlyContext
|
||||||
from ....auth.auth_credential import AuthCredential
|
from ....auth.auth_credential import AuthCredential
|
||||||
from ....auth.auth_schemes import AuthScheme
|
from ....auth.auth_schemes import AuthScheme
|
||||||
|
from ...base_toolset import BaseToolset
|
||||||
|
from ...base_toolset import ToolPredicate
|
||||||
from .openapi_spec_parser import OpenApiSpecParser
|
from .openapi_spec_parser import OpenApiSpecParser
|
||||||
from .rest_api_tool import RestApiTool
|
from .rest_api_tool import RestApiTool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OpenAPIToolset:
|
class OpenAPIToolset(BaseToolset):
|
||||||
"""Class for parsing OpenAPI spec into a list of RestApiTool.
|
"""Class for parsing OpenAPI spec into a list of RestApiTool.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@ -61,6 +66,7 @@ class OpenAPIToolset:
|
|||||||
spec_str_type: Literal["json", "yaml"] = "json",
|
spec_str_type: Literal["json", "yaml"] = "json",
|
||||||
auth_scheme: Optional[AuthScheme] = None,
|
auth_scheme: Optional[AuthScheme] = None,
|
||||||
auth_credential: Optional[AuthCredential] = None,
|
auth_credential: Optional[AuthCredential] = None,
|
||||||
|
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||||
):
|
):
|
||||||
"""Initializes the OpenAPIToolset.
|
"""Initializes the OpenAPIToolset.
|
||||||
|
|
||||||
@ -94,12 +100,15 @@ class OpenAPIToolset:
|
|||||||
auth_credential: The auth credential to use for all tools. Use
|
auth_credential: The auth credential to use for all tools. Use
|
||||||
AuthCredential or use helpers in
|
AuthCredential or use helpers in
|
||||||
`google.adk.tools.openapi_tool.auth.auth_helpers`
|
`google.adk.tools.openapi_tool.auth.auth_helpers`
|
||||||
|
tool_filter: The filter used to filter the tools in the toolset. It can be
|
||||||
|
either a tool predicate or a list of tool names of the tools to expose
|
||||||
"""
|
"""
|
||||||
if not spec_dict:
|
if not spec_dict:
|
||||||
spec_dict = self._load_spec(spec_str, spec_str_type)
|
spec_dict = self._load_spec(spec_str, spec_str_type)
|
||||||
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
|
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
|
||||||
if auth_scheme or auth_credential:
|
if auth_scheme or auth_credential:
|
||||||
self._configure_auth_all(auth_scheme, auth_credential)
|
self._configure_auth_all(auth_scheme, auth_credential)
|
||||||
|
self.tool_filter = tool_filter
|
||||||
|
|
||||||
def _configure_auth_all(
|
def _configure_auth_all(
|
||||||
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
|
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
|
||||||
@ -112,9 +121,21 @@ class OpenAPIToolset:
|
|||||||
if auth_credential:
|
if auth_credential:
|
||||||
tool.configure_auth_credential(auth_credential)
|
tool.configure_auth_credential(auth_credential)
|
||||||
|
|
||||||
def get_tools(self) -> List[RestApiTool]:
|
@override
|
||||||
|
async def get_tools(
|
||||||
|
self, readonly_context: Optional[ReadonlyContext] = None
|
||||||
|
) -> List[RestApiTool]:
|
||||||
"""Get all tools in the toolset."""
|
"""Get all tools in the toolset."""
|
||||||
return self.tools
|
return [
|
||||||
|
tool
|
||||||
|
for tool in self.tools
|
||||||
|
if self.tool_filter is None
|
||||||
|
or (
|
||||||
|
self.tool_filter(tool, readonly_context)
|
||||||
|
if isinstance(self.tool_filter, ToolPredicate)
|
||||||
|
else tool.name in self.tool_filter
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
|
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
|
||||||
"""Get a tool by name."""
|
"""Get a tool by name."""
|
||||||
@ -142,3 +163,7 @@ class OpenAPIToolset:
|
|||||||
logger.info("Parsed tool: %s", tool.name)
|
logger.info("Parsed tool: %s", tool.name)
|
||||||
tools.append(tool)
|
tools.append(tool)
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def close(self):
|
||||||
|
pass
|
||||||
|
@ -77,22 +77,26 @@ def mock_auth_credential():
|
|||||||
|
|
||||||
|
|
||||||
# Test cases
|
# Test cases
|
||||||
def test_apihub_toolset_initialization(basic_apihub_toolset):
|
@pytest.mark.asyncio
|
||||||
|
async def test_apihub_toolset_initialization(basic_apihub_toolset):
|
||||||
assert basic_apihub_toolset.name == 'mock_api'
|
assert basic_apihub_toolset.name == 'mock_api'
|
||||||
assert basic_apihub_toolset.description == 'Mock API Description'
|
assert basic_apihub_toolset.description == 'Mock API Description'
|
||||||
assert basic_apihub_toolset.apihub_resource_name == 'test_resource'
|
assert basic_apihub_toolset.apihub_resource_name == 'test_resource'
|
||||||
assert not basic_apihub_toolset.lazy_load_spec
|
assert not basic_apihub_toolset.lazy_load_spec
|
||||||
assert len(basic_apihub_toolset.generated_tools) == 1
|
generated_tools = await basic_apihub_toolset.get_tools()
|
||||||
assert 'test_get' in basic_apihub_toolset.generated_tools
|
assert len(generated_tools) == 1
|
||||||
|
assert 'test_get' == generated_tools[0].name
|
||||||
|
|
||||||
|
|
||||||
def test_apihub_toolset_lazy_loading(lazy_apihub_toolset):
|
@pytest.mark.asyncio
|
||||||
|
async def test_apihub_toolset_lazy_loading(lazy_apihub_toolset):
|
||||||
assert lazy_apihub_toolset.lazy_load_spec
|
assert lazy_apihub_toolset.lazy_load_spec
|
||||||
assert not lazy_apihub_toolset.generated_tools
|
generated_tools = await lazy_apihub_toolset.get_tools()
|
||||||
|
assert generated_tools
|
||||||
|
|
||||||
tools = lazy_apihub_toolset.get_tools()
|
tools = await lazy_apihub_toolset.get_tools()
|
||||||
assert len(tools) == 1
|
assert len(tools) == 1
|
||||||
assert lazy_apihub_toolset.get_tool('test_get') == tools[0]
|
'test_get' == tools[0].name
|
||||||
|
|
||||||
|
|
||||||
def test_apihub_toolset_no_title_in_spec(basic_apihub_toolset):
|
def test_apihub_toolset_no_title_in_spec(basic_apihub_toolset):
|
||||||
@ -155,7 +159,8 @@ paths:
|
|||||||
assert toolset.description == ''
|
assert toolset.description == ''
|
||||||
|
|
||||||
|
|
||||||
def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential):
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential):
|
||||||
apihub_client = MockAPIHubClient()
|
apihub_client = MockAPIHubClient()
|
||||||
tool = APIHubToolset(
|
tool = APIHubToolset(
|
||||||
apihub_resource_name='test_resource',
|
apihub_resource_name='test_resource',
|
||||||
@ -163,11 +168,12 @@ def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential):
|
|||||||
auth_scheme=mock_auth_scheme,
|
auth_scheme=mock_auth_scheme,
|
||||||
auth_credential=mock_auth_credential,
|
auth_credential=mock_auth_credential,
|
||||||
)
|
)
|
||||||
tools = tool.get_tools()
|
tools = await tool.get_tools()
|
||||||
assert len(tools) == 1
|
assert len(tools) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_apihub_toolset_get_tools_lazy_load_empty_spec():
|
@pytest.mark.asyncio
|
||||||
|
async def test_apihub_toolset_get_tools_lazy_load_empty_spec():
|
||||||
|
|
||||||
class MockAPIHubClientEmptySpec(BaseAPIHubClient):
|
class MockAPIHubClientEmptySpec(BaseAPIHubClient):
|
||||||
|
|
||||||
@ -180,11 +186,12 @@ def test_apihub_toolset_get_tools_lazy_load_empty_spec():
|
|||||||
apihub_client=apihub_client,
|
apihub_client=apihub_client,
|
||||||
lazy_load_spec=True,
|
lazy_load_spec=True,
|
||||||
)
|
)
|
||||||
tools = tool.get_tools()
|
tools = await tool.get_tools()
|
||||||
assert not tools
|
assert not tools
|
||||||
|
|
||||||
|
|
||||||
def test_apihub_toolset_get_tools_invalid_yaml():
|
@pytest.mark.asyncio
|
||||||
|
async def test_apihub_toolset_get_tools_invalid_yaml():
|
||||||
|
|
||||||
class MockAPIHubClientInvalidYAML(BaseAPIHubClient):
|
class MockAPIHubClientInvalidYAML(BaseAPIHubClient):
|
||||||
|
|
||||||
@ -197,7 +204,7 @@ def test_apihub_toolset_get_tools_invalid_yaml():
|
|||||||
apihub_resource_name='test_resource',
|
apihub_resource_name='test_resource',
|
||||||
apihub_client=apihub_client,
|
apihub_client=apihub_client,
|
||||||
)
|
)
|
||||||
tool.get_tools()
|
await tool.get_tools()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
Reference in New Issue
Block a user