diff --git a/src/google/adk/tools/apihub_tool/apihub_toolset.py b/src/google/adk/tools/apihub_tool/apihub_toolset.py index 0cf160e..8acf1b7 100644 --- a/src/google/adk/tools/apihub_tool/apihub_toolset.py +++ b/src/google/adk/tools/apihub_tool/apihub_toolset.py @@ -13,19 +13,25 @@ # limitations under the License. -from typing import Dict, List, Optional +from typing import List +from typing import Optional +from typing import override +from typing import Union import yaml +from ...agents.readonly_context import ReadonlyContext from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme +from ..base_toolset import BaseToolset +from ..base_toolset import ToolPredicate from ..openapi_tool.common.common import to_snake_case from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool from .clients.apihub_client import APIHubClient -class APIHubToolset: +class APIHubToolset(BaseToolset): """APIHubTool generates tools from a given API Hub resource. Examples: @@ -34,16 +40,13 @@ class APIHubToolset: apihub_toolset = APIHubToolset( apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api", service_account_json="...", + tool_filter=lambda tool, ctx=None: tool.name in ('my_tool', + 'my_other_tool') ) # Get all available tools - agent = LlmAgent(tools=apihub_toolset.get_tools()) + agent = LlmAgent(tools=apihub_toolset) - # Get a specific tool - agent = LlmAgent(tools=[ - ... - apihub_toolset.get_tool('my_tool'), - ]) ``` **apihub_resource_name** is the resource name from API Hub. It must include @@ -70,6 +73,7 @@ class APIHubToolset: auth_credential: Optional[AuthCredential] = None, # Optionally, you can provide a custom API Hub client apihub_client: Optional[APIHubClient] = None, + tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, ): """Initializes the APIHubTool with the given parameters. @@ -81,12 +85,17 @@ class APIHubToolset: ) # Get all available tools - agent = LlmAgent(tools=apihub_toolset.get_tools()) + agent = LlmAgent(tools=[apihub_toolset]) + apihub_toolset = APIHubToolset( + apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api", + service_account_json="...", + tool_filter = ['my_tool'] + ) # Get a specific tool agent = LlmAgent(tools=[ - ... - apihub_toolset.get_tool('my_tool'), + ..., + apihub_toolset, ]) ``` @@ -118,6 +127,9 @@ class APIHubToolset: lazy_load_spec: If True, the spec will be loaded lazily when needed. Otherwise, the spec will be loaded immediately and the tools will be generated during initialization. + tool_filter: The filter used to filter the tools in the toolset. It can + be either a tool predicate or a list of tool names of the tools to + expose. """ self.name = name self.description = description @@ -128,72 +140,36 @@ class APIHubToolset: service_account_json=service_account_json, ) - self.generated_tools: Dict[str, RestApiTool] = {} + self.openapi_toolset = None self.auth_scheme = auth_scheme self.auth_credential = auth_credential + self.tool_filter = tool_filter if not self.lazy_load_spec: - self._prepare_tools() + self._prepare_toolset() - def get_tool(self, name: str) -> Optional[RestApiTool]: - """Retrieves a specific tool by its name. - - Example: - ``` - apihub_tool = apihub_toolset.get_tool('my_tool') - ``` - - Args: - name: The name of the tool to retrieve. - - Returns: - The tool with the given name, or None if no such tool exists. - """ - if not self._are_tools_ready(): - self._prepare_tools() - - return self.generated_tools[name] if name in self.generated_tools else None - - def get_tools(self) -> List[RestApiTool]: + @override + async def get_tools( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> List[RestApiTool]: """Retrieves all available tools. Returns: A list of all available RestApiTool objects. """ - if not self._are_tools_ready(): - self._prepare_tools() + if not self.openapi_toolset: + self._prepare_toolset() + if not self.openapi_toolset: + return [] + return await self.openapi_toolset.get_tools(readonly_context) - return list(self.generated_tools.values()) - - def _are_tools_ready(self) -> bool: - return not self.lazy_load_spec or self.generated_tools - - def _prepare_tools(self) -> str: - """Fetches the spec from API Hub and generates the tools. - - Returns: - True if the tools are ready, False otherwise. - """ + def _prepare_toolset(self) -> None: + """Fetches the spec from API Hub and generates the toolset.""" # For each API, get the first version and the first spec of that version. - spec = self.apihub_client.get_spec_content(self.apihub_resource_name) - self.generated_tools: Dict[str, RestApiTool] = {} - - tools = self._parse_spec_to_tools(spec) - for tool in tools: - self.generated_tools[tool.name] = tool - - def _parse_spec_to_tools(self, spec_str: str) -> List[RestApiTool]: - """Parses the spec string to a list of RestApiTool. - - Args: - spec_str: The spec string to parse. - - Returns: - A list of RestApiTool objects. - """ + spec_str = self.apihub_client.get_spec_content(self.apihub_resource_name) spec_dict = yaml.safe_load(spec_str) if not spec_dict: - return [] + return self.name = self.name or to_snake_case( spec_dict.get('info', {}).get('title', 'unnamed') @@ -201,9 +177,14 @@ class APIHubToolset: self.description = self.description or spec_dict.get('info', {}).get( 'description', '' ) - tools = OpenAPIToolset( + self.openapi_toolset = OpenAPIToolset( spec_dict=spec_dict, auth_credential=self.auth_credential, auth_scheme=self.auth_scheme, - ).get_tools() - return tools + tool_filter=self.tool_filter, + ) + + @override + async def close(self): + if self.openapi_toolset: + await self.openapi_toolset.close() diff --git a/src/google/adk/tools/base_toolset.py b/src/google/adk/tools/base_toolset.py index d603a8e..10ea2fa 100644 --- a/src/google/adk/tools/base_toolset.py +++ b/src/google/adk/tools/base_toolset.py @@ -1,5 +1,6 @@ from abc import ABC from abc import abstractmethod +from typing import Optional from typing import Protocol from google.adk.agents.readonly_context import ReadonlyContext @@ -33,7 +34,7 @@ class BaseToolset(ABC): @abstractmethod async def get_tools( - self, readony_context: ReadonlyContext = None + self, readonly_context: Optional[ReadonlyContext] = None ) -> list[BaseTool]: """Return all tools in the toolset based on the provided context. diff --git a/src/google/adk/tools/google_api_tool/google_api_tool_set.py b/src/google/adk/tools/google_api_tool/google_api_tool_set.py index 5409593..7707106 100644 --- a/src/google/adk/tools/google_api_tool/google_api_tool_set.py +++ b/src/google/adk/tools/google_api_tool/google_api_tool_set.py @@ -17,37 +17,67 @@ from __future__ import annotations import inspect import os from typing import Any -from typing import Final from typing import List from typing import Optional +from typing import override from typing import Type +from typing import Union +from ...agents.readonly_context import ReadonlyContext from ...auth import OpenIdConnectWithConfig +from ...tools.base_toolset import BaseToolset +from ...tools.base_toolset import ToolPredicate from ..openapi_tool import OpenAPIToolset -from ..openapi_tool import RestApiTool from .google_api_tool import GoogleApiTool from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter -class GoogleApiToolSet: - """Google API Tool Set.""" +class GoogleApiToolset(BaseToolset): + """Google API Toolset contains tools for interacting with Google APIs. - def __init__(self, tools: List[RestApiTool]): - self.tools: Final[List[GoogleApiTool]] = [ - GoogleApiTool(tool) for tool in tools - ] + Usually one toolsets will contains tools only replated to one Google API, e.g. + Google Bigquery API toolset will contains tools only related to Google + Bigquery API, like list dataset tool, list table tool etc. + """ - def get_tools(self) -> List[GoogleApiTool]: + def __init__( + self, + openapi_toolset: OpenAPIToolset, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + ): + self.openapi_toolset = openapi_toolset + self.tool_filter = tool_filter + self.client_id = client_id + self.client_secret = client_secret + + @override + async def get_tools( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> List[GoogleApiTool]: """Get all tools in the toolset.""" - return self.tools + tools = [] - def get_tool(self, tool_name: str) -> Optional[GoogleApiTool]: - """Get a tool by name.""" - matching_tool = filter(lambda t: t.name == tool_name, self.tools) - return next(matching_tool, None) + for tool in await self.openapi_toolset.get_tools(readonly_context): + if self.tool_filter and ( + isinstance(self.tool_filter, ToolPredicate) + and not self.tool_filter(tool, readonly_context) + or isinstance(self.tool_filter, list) + and tool.name not in self.tool_filter + ): + continue + google_api_tool = GoogleApiTool(tool) + google_api_tool.configure_auth(self.client_id, self.client_secret) + tools.append(google_api_tool) + + return tools + + def set_tool_filter(self, tool_filter: Union[ToolPredicate, List[str]]): + self.tool_filter = tool_filter @staticmethod - def _load_tool_set_with_oidc_auth( + def _load_toolset_with_oidc_auth( spec_file: Optional[str] = None, spec_dict: Optional[dict[str, Any]] = None, scopes: Optional[list[str]] = None, @@ -64,7 +94,7 @@ class GoogleApiToolSet: yaml_path = os.path.join(caller_dir, spec_file) with open(yaml_path, 'r', encoding='utf-8') as file: spec_str = file.read() - tool_set = OpenAPIToolset( + toolset = OpenAPIToolset( spec_dict=spec_dict, spec_str=spec_str, spec_str_type='yaml', @@ -85,18 +115,18 @@ class GoogleApiToolSet: scopes=scopes, ), ) - return tool_set + return toolset def configure_auth(self, client_id: str, client_secret: str): - for tool in self.tools: - tool.configure_auth(client_id, client_secret) + self.client_id = client_id + self.client_secret = client_secret @classmethod - def load_tool_set( - cls: Type[GoogleApiToolSet], + def load_toolset( + cls: Type[GoogleApiToolset], api_name: str, api_version: str, - ) -> GoogleApiToolSet: + ) -> GoogleApiToolset: spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert() scope = list( spec_dict['components']['securitySchemes']['oauth2']['flows'][ @@ -104,7 +134,10 @@ class GoogleApiToolSet: ]['scopes'].keys() )[0] return cls( - cls._load_tool_set_with_oidc_auth( - spec_dict=spec_dict, scopes=[scope] - ).get_tools() + cls._load_toolset_with_oidc_auth(spec_dict=spec_dict, scopes=[scope]) ) + + @override + async def close(self): + if self.openapi_toolset: + await self.openapi_toolset.close() diff --git a/src/google/adk/tools/google_api_tool/google_api_tool_sets.py b/src/google/adk/tools/google_api_tool/google_api_tool_sets.py index 5b099d7..6835b02 100644 --- a/src/google/adk/tools/google_api_tool/google_api_tool_sets.py +++ b/src/google/adk/tools/google_api_tool/google_api_tool_sets.py @@ -15,7 +15,7 @@ import logging -from .google_api_tool_set import GoogleApiToolSet +from .google_api_tool_set import GoogleApiToolset logger = logging.getLogger(__name__) @@ -50,7 +50,7 @@ def __getattr__(name): match name: case "bigquery_tool_set": if _bigquery_tool_set is None: - _bigquery_tool_set = GoogleApiToolSet.load_tool_set( + _bigquery_tool_set = GoogleApiToolset.load_toolset( api_name="bigquery", api_version="v2", ) @@ -59,7 +59,7 @@ def __getattr__(name): case "calendar_tool_set": if _calendar_tool_set is None: - _calendar_tool_set = GoogleApiToolSet.load_tool_set( + _calendar_tool_set = GoogleApiToolset.load_toolset( api_name="calendar", api_version="v3", ) @@ -68,7 +68,7 @@ def __getattr__(name): case "gmail_tool_set": if _gmail_tool_set is None: - _gmail_tool_set = GoogleApiToolSet.load_tool_set( + _gmail_tool_set = GoogleApiToolset.load_toolset( api_name="gmail", api_version="v1", ) @@ -77,7 +77,7 @@ def __getattr__(name): case "youtube_tool_set": if _youtube_tool_set is None: - _youtube_tool_set = GoogleApiToolSet.load_tool_set( + _youtube_tool_set = GoogleApiToolset.load_toolset( api_name="youtube", api_version="v3", ) @@ -86,7 +86,7 @@ def __getattr__(name): case "slides_tool_set": if _slides_tool_set is None: - _slides_tool_set = GoogleApiToolSet.load_tool_set( + _slides_tool_set = GoogleApiToolset.load_toolset( api_name="slides", api_version="v1", ) @@ -95,7 +95,7 @@ def __getattr__(name): case "sheets_tool_set": if _sheets_tool_set is None: - _sheets_tool_set = GoogleApiToolSet.load_tool_set( + _sheets_tool_set = GoogleApiToolset.load_toolset( api_name="sheets", api_version="v4", ) @@ -104,7 +104,7 @@ def __getattr__(name): case "docs_tool_set": if _docs_tool_set is None: - _docs_tool_set = GoogleApiToolSet.load_tool_set( + _docs_tool_set = GoogleApiToolset.load_toolset( api_name="docs", api_version="v1", ) diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 541ab53..4beb25a 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -14,7 +14,7 @@ from contextlib import AsyncExitStack import sys -from typing import List +from typing import List, Union from typing import Optional from typing import TextIO @@ -68,7 +68,7 @@ class MCPToolset(BaseToolset): *, connection_params: StdioServerParameters | SseServerParams, errlog: TextIO = sys.stderr, - tool_predicate: Optional[ToolPredicate] = None, + tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, ): """Initializes the MCPToolset. @@ -90,7 +90,7 @@ class MCPToolset(BaseToolset): errlog=self.errlog, ) self.session = None - self.tool_predicate = tool_predicate + self.tool_filter = tool_filter async def _initialize(self) -> ClientSession: """Connects to the MCP Server and initializes the ClientSession.""" @@ -106,7 +106,7 @@ class MCPToolset(BaseToolset): @override async def get_tools( self, - readony_context: ReadonlyContext = None, + readonly_context: Optional[ReadonlyContext] = None, ) -> List[MCPTool]: """Loads all tools from the MCP Server. @@ -123,6 +123,5 @@ class MCPToolset(BaseToolset): mcp_session_manager=self.session_manager, ) for tool in tools_response.tools - if self.tool_predicate is None - or self.tool_predicate(tool, readony_context) + if self.tool_filter is None or self.tool_filter(tool, readonly_context) ] diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py index 6bd0f08..1cef0fe 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py @@ -20,18 +20,23 @@ from typing import Final from typing import List from typing import Literal from typing import Optional +from typing import override +from typing import Union import yaml +from ....agents.readonly_context import ReadonlyContext from ....auth.auth_credential import AuthCredential from ....auth.auth_schemes import AuthScheme +from ...base_toolset import BaseToolset +from ...base_toolset import ToolPredicate from .openapi_spec_parser import OpenApiSpecParser from .rest_api_tool import RestApiTool logger = logging.getLogger(__name__) -class OpenAPIToolset: +class OpenAPIToolset(BaseToolset): """Class for parsing OpenAPI spec into a list of RestApiTool. Usage: @@ -61,6 +66,7 @@ class OpenAPIToolset: spec_str_type: Literal["json", "yaml"] = "json", auth_scheme: Optional[AuthScheme] = None, auth_credential: Optional[AuthCredential] = None, + tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, ): """Initializes the OpenAPIToolset. @@ -94,12 +100,15 @@ class OpenAPIToolset: auth_credential: The auth credential to use for all tools. Use AuthCredential or use helpers in `google.adk.tools.openapi_tool.auth.auth_helpers` + tool_filter: The filter used to filter the tools in the toolset. It can be + either a tool predicate or a list of tool names of the tools to expose """ if not spec_dict: spec_dict = self._load_spec(spec_str, spec_str_type) self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict)) if auth_scheme or auth_credential: self._configure_auth_all(auth_scheme, auth_credential) + self.tool_filter = tool_filter def _configure_auth_all( self, auth_scheme: AuthScheme, auth_credential: AuthCredential @@ -112,9 +121,21 @@ class OpenAPIToolset: if auth_credential: tool.configure_auth_credential(auth_credential) - def get_tools(self) -> List[RestApiTool]: + @override + async def get_tools( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> List[RestApiTool]: """Get all tools in the toolset.""" - return self.tools + return [ + tool + for tool in self.tools + if self.tool_filter is None + or ( + self.tool_filter(tool, readonly_context) + if isinstance(self.tool_filter, ToolPredicate) + else tool.name in self.tool_filter + ) + ] def get_tool(self, tool_name: str) -> Optional[RestApiTool]: """Get a tool by name.""" @@ -142,3 +163,7 @@ class OpenAPIToolset: logger.info("Parsed tool: %s", tool.name) tools.append(tool) return tools + + @override + async def close(self): + pass diff --git a/tests/unittests/tools/apihub_tool/test_apihub_toolset.py b/tests/unittests/tools/apihub_tool/test_apihub_toolset.py index 9ec68fa..139c0c4 100644 --- a/tests/unittests/tools/apihub_tool/test_apihub_toolset.py +++ b/tests/unittests/tools/apihub_tool/test_apihub_toolset.py @@ -77,22 +77,26 @@ def mock_auth_credential(): # Test cases -def test_apihub_toolset_initialization(basic_apihub_toolset): +@pytest.mark.asyncio +async def test_apihub_toolset_initialization(basic_apihub_toolset): assert basic_apihub_toolset.name == 'mock_api' assert basic_apihub_toolset.description == 'Mock API Description' assert basic_apihub_toolset.apihub_resource_name == 'test_resource' assert not basic_apihub_toolset.lazy_load_spec - assert len(basic_apihub_toolset.generated_tools) == 1 - assert 'test_get' in basic_apihub_toolset.generated_tools + generated_tools = await basic_apihub_toolset.get_tools() + assert len(generated_tools) == 1 + assert 'test_get' == generated_tools[0].name -def test_apihub_toolset_lazy_loading(lazy_apihub_toolset): +@pytest.mark.asyncio +async def test_apihub_toolset_lazy_loading(lazy_apihub_toolset): assert lazy_apihub_toolset.lazy_load_spec - assert not lazy_apihub_toolset.generated_tools + generated_tools = await lazy_apihub_toolset.get_tools() + assert generated_tools - tools = lazy_apihub_toolset.get_tools() + tools = await lazy_apihub_toolset.get_tools() assert len(tools) == 1 - assert lazy_apihub_toolset.get_tool('test_get') == tools[0] + 'test_get' == tools[0].name def test_apihub_toolset_no_title_in_spec(basic_apihub_toolset): @@ -155,7 +159,8 @@ paths: assert toolset.description == '' -def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential): +@pytest.mark.asyncio +async def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential): apihub_client = MockAPIHubClient() tool = APIHubToolset( apihub_resource_name='test_resource', @@ -163,11 +168,12 @@ def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential): auth_scheme=mock_auth_scheme, auth_credential=mock_auth_credential, ) - tools = tool.get_tools() + tools = await tool.get_tools() assert len(tools) == 1 -def test_apihub_toolset_get_tools_lazy_load_empty_spec(): +@pytest.mark.asyncio +async def test_apihub_toolset_get_tools_lazy_load_empty_spec(): class MockAPIHubClientEmptySpec(BaseAPIHubClient): @@ -180,11 +186,12 @@ def test_apihub_toolset_get_tools_lazy_load_empty_spec(): apihub_client=apihub_client, lazy_load_spec=True, ) - tools = tool.get_tools() + tools = await tool.get_tools() assert not tools -def test_apihub_toolset_get_tools_invalid_yaml(): +@pytest.mark.asyncio +async def test_apihub_toolset_get_tools_invalid_yaml(): class MockAPIHubClientInvalidYAML(BaseAPIHubClient): @@ -197,7 +204,7 @@ def test_apihub_toolset_get_tools_invalid_yaml(): apihub_resource_name='test_resource', apihub_client=apihub_client, ) - tool.get_tools() + await tool.get_tools() if __name__ == '__main__':