refactor: refactor google api toolset to expose class instead of instance

PiperOrigin-RevId: 759289358
This commit is contained in:
Xiang (Sean) Zhou 2025-05-15 13:59:19 -07:00 committed by Copybara-Service
parent f298d07579
commit bdd678db31
4 changed files with 124 additions and 197 deletions

View File

@ -12,75 +12,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
__all__ = [ __all__ = [
'bigquery_toolset', 'BigQueryToolset',
'calendar_toolset', 'CalendarToolset',
'gmail_toolset', 'GmailToolset',
'youtube_toolset', 'YoutubeToolset',
'slides_toolset', 'SlidesToolset',
'sheets_toolset', 'SheetsToolset',
'docs_toolset', 'DocsToolset',
'GoogleApiToolset',
'GoogleApiTool',
] ]
# Nothing is imported here automatically
# Each tool set will only be imported when accessed
_bigquery_toolset = None from .google_api_tool import GoogleApiTool
_calendar_toolset = None from .google_api_toolset import GoogleApiToolset
_gmail_toolset = None from .google_api_toolsets import BigQueryToolset
_youtube_toolset = None from .google_api_toolsets import CalendarToolset
_slides_toolset = None from .google_api_toolsets import DocsToolset
_sheets_toolset = None from .google_api_toolsets import GmailToolset
_docs_toolset = None from .google_api_toolsets import SheetsToolset
from .google_api_toolsets import SlidesToolset
from .google_api_toolsets import YoutubeToolset
def __getattr__(name):
global _bigquery_toolset, _calendar_toolset, _gmail_toolset, _youtube_toolset, _slides_toolset, _sheets_toolset, _docs_toolset
if name == 'bigquery_toolset':
if _bigquery_toolset is None:
from .google_api_toolsets import bigquery_toolset as bigquery
_bigquery_toolset = bigquery
return _bigquery_toolset
if name == 'calendar_toolset':
if _calendar_toolset is None:
from .google_api_toolsets import calendar_toolset as calendar
_calendar_toolset = calendar
return _calendar_toolset
if name == 'gmail_toolset':
if _gmail_toolset is None:
from .google_api_toolsets import gmail_toolset as gmail
_gmail_toolset = gmail
return _gmail_toolset
if name == 'youtube_toolset':
if _youtube_toolset is None:
from .google_api_toolsets import youtube_toolset as youtube
_youtube_toolset = youtube
return _youtube_toolset
if name == 'slides_toolset':
if _slides_toolset is None:
from .google_api_toolsets import slides_toolset as slides
_slides_toolset = slides
return _slides_toolset
if name == 'sheets_toolset':
if _sheets_toolset is None:
from .google_api_toolsets import sheets_toolset as sheets
_sheets_toolset = sheets
return _sheets_toolset
if name == 'docs_toolset':
if _docs_toolset is None:
from .google_api_toolsets import docs_toolset as docs
_docs_toolset = docs
return _docs_toolset

View File

@ -29,13 +29,19 @@ from ..tool_context import ToolContext
class GoogleApiTool(BaseTool): class GoogleApiTool(BaseTool):
def __init__(self, rest_api_tool: RestApiTool): def __init__(
self,
rest_api_tool: RestApiTool,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
):
super().__init__( super().__init__(
name=rest_api_tool.name, name=rest_api_tool.name,
description=rest_api_tool.description, description=rest_api_tool.description,
is_long_running=rest_api_tool.is_long_running, is_long_running=rest_api_tool.is_long_running,
) )
self._rest_api_tool = rest_api_tool self._rest_api_tool = rest_api_tool
self.configure_auth(client_id, client_secret)
@override @override
def _get_declaration(self) -> FunctionDeclaration: def _get_declaration(self) -> FunctionDeclaration:

View File

@ -36,22 +36,39 @@ from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
class GoogleApiToolset(BaseToolset): class GoogleApiToolset(BaseToolset):
"""Google API Toolset contains tools for interacting with Google APIs. """Google API Toolset contains tools for interacting with Google APIs.
Usually one toolsets will contains tools only replated to one Google API, e.g. Usually one toolsets will contains tools only related to one Google API, e.g.
Google Bigquery API toolset will contains tools only related to Google Google Bigquery API toolset will contains tools only related to Google
Bigquery API, like list dataset tool, list table tool etc. Bigquery API, like list dataset tool, list table tool etc.
""" """
def __init__( def __init__(
self, self,
openapi_toolset: OpenAPIToolset, api_name: str,
api_version: str,
client_id: Optional[str] = None, client_id: Optional[str] = None,
client_secret: Optional[str] = None, client_secret: Optional[str] = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
): ):
self._openapi_toolset = openapi_toolset self.api_name = api_name
self.tool_filter = tool_filter self.api_version = api_version
self._client_id = client_id self._client_id = client_id
self._client_secret = client_secret self._client_secret = client_secret
self._openapi_toolset = self._load_toolset_with_oidc_auth()
self.tool_filter = tool_filter
def _is_tool_selected(
self, tool: GoogleApiTool, readonly_context: ReadonlyContext
) -> bool:
if not self.tool_filter:
return True
if isinstance(self.tool_filter, ToolPredicate):
return self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, list):
return tool.name in self.tool_filter
return False
@override @override
async def get_tools( async def get_tools(
@ -60,44 +77,26 @@ class GoogleApiToolset(BaseToolset):
"""Get all tools in the toolset.""" """Get all tools in the toolset."""
tools = [] tools = []
for tool in await self._openapi_toolset.get_tools(readonly_context): return [
if self.tool_filter and ( GoogleApiTool(tool, self._client_id, self._client_secret)
isinstance(self.tool_filter, ToolPredicate) for tool in await self._openapi_toolset.get_tools(readonly_context)
and not self.tool_filter(tool, readonly_context) if self._is_tool_selected(tool, readonly_context)
or isinstance(self.tool_filter, list) ]
and tool.name not in self.tool_filter
):
continue
google_api_tool = GoogleApiTool(tool)
google_api_tool.configure_auth(self._client_id, self._client_secret)
tools.append(google_api_tool)
return tools
def set_tool_filter(self, tool_filter: Union[ToolPredicate, List[str]]): def set_tool_filter(self, tool_filter: Union[ToolPredicate, List[str]]):
self.tool_filter = tool_filter self.tool_filter = tool_filter
@staticmethod def _load_toolset_with_oidc_auth(self) -> OpenAPIToolset:
def _load_toolset_with_oidc_auth( spec_dict = GoogleApiToOpenApiConverter(
spec_file: Optional[str] = None, self.api_name, self.api_version
spec_dict: Optional[dict[str, Any]] = None, ).convert()
scopes: Optional[list[str]] = None, scope = list(
) -> OpenAPIToolset: spec_dict['components']['securitySchemes']['oauth2']['flows'][
spec_str = None 'authorizationCode'
if spec_file: ]['scopes'].keys()
# Get the frame of the caller )[0]
caller_frame = inspect.stack()[1] return OpenAPIToolset(
# Get the filename of the caller
caller_filename = caller_frame.filename
# Get the directory of the caller
caller_dir = os.path.dirname(os.path.abspath(caller_filename))
# Join the directory path with the filename
yaml_path = os.path.join(caller_dir, spec_file)
with open(yaml_path, 'r', encoding='utf-8') as file:
spec_str = file.read()
toolset = OpenAPIToolset(
spec_dict=spec_dict, spec_dict=spec_dict,
spec_str=spec_str,
spec_str_type='yaml', spec_str_type='yaml',
auth_scheme=OpenIdConnectWithConfig( auth_scheme=OpenIdConnectWithConfig(
authorization_endpoint=( authorization_endpoint=(
@ -113,31 +112,14 @@ class GoogleApiToolset(BaseToolset):
'client_secret_basic', 'client_secret_basic',
], ],
grant_types_supported=['authorization_code'], grant_types_supported=['authorization_code'],
scopes=scopes, scopes=[scope],
), ),
) )
return toolset
def configure_auth(self, client_id: str, client_secret: str): def configure_auth(self, client_id: str, client_secret: str):
self._client_id = client_id self._client_id = client_id
self._client_secret = client_secret self._client_secret = client_secret
@classmethod
def load_toolset(
cls: Type[GoogleApiToolset],
api_name: str,
api_version: str,
) -> GoogleApiToolset:
spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
scope = list(
spec_dict['components']['securitySchemes']['oauth2']['flows'][
'authorizationCode'
]['scopes'].keys()
)[0]
return cls(
cls._load_toolset_with_oidc_auth(spec_dict=spec_dict, scopes=[scope])
)
@override @override
async def close(self): async def close(self):
if self._openapi_toolset: if self._openapi_toolset:

View File

@ -14,98 +14,88 @@
import logging import logging
from typing import List, Optional, Union
from google.adk.tools.base_toolset import ToolPredicate
from .google_api_toolset import GoogleApiToolset from .google_api_toolset import GoogleApiToolset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_bigquery_toolset = None
_calendar_toolset = None class BigQueryToolset(GoogleApiToolset):
_gmail_toolset = None
_youtube_toolset = None def __init__(
_slides_toolset = None self,
_sheets_toolset = None client_id: str = None,
_docs_toolset = None client_secret: str = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
):
super().__init__("bigquery", "v2", client_id, client_secret, tool_filter)
def __getattr__(name): class CalendarToolset(GoogleApiToolset):
"""This method dynamically loads and returns GoogleApiToolSet instances for
various Google APIs. It uses a lazy loading approach, initializing each def __init__(
tool set only when it is first requested. This avoids unnecessary loading self,
of tool sets that are not used in a given session. client_id: str = None,
client_secret: str = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
):
super().__init__("calendar", "v3", client_id, client_secret, tool_filter)
Args:
name (str): The name of the tool set to retrieve (e.g.,
"bigquery_toolset").
Returns: class GmailToolset(GoogleApiToolset):
GoogleApiToolSet: The requested tool set instance.
Raises: def __init__(
AttributeError: If the requested tool set name is not recognized. self,
""" client_id: str = None,
global _bigquery_toolset, _calendar_toolset, _gmail_toolset, _youtube_toolset, _slides_toolset, _sheets_toolset, _docs_toolset client_secret: str = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
):
super().__init__("gmail", "v1", client_id, client_secret, tool_filter)
if name == "bigquery_toolset":
if _bigquery_toolset is None:
_bigquery_toolset = GoogleApiToolset.load_toolset(
api_name="bigquery",
api_version="v2",
)
return _bigquery_toolset class YoutubeToolset(GoogleApiToolset):
if name == "calendar_toolset": def __init__(
if _calendar_toolset is None: self,
_calendar_toolset = GoogleApiToolset.load_toolset( client_id: str = None,
api_name="calendar", client_secret: str = None,
api_version="v3", tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
) ):
super().__init__("youtube", "v3", client_id, client_secret, tool_filter)
return _calendar_toolset
if name == "gmail_toolset": class SlidesToolset(GoogleApiToolset):
if _gmail_toolset is None:
_gmail_toolset = GoogleApiToolset.load_toolset(
api_name="gmail",
api_version="v1",
)
return _gmail_toolset def __init__(
self,
client_id: str = None,
client_secret: str = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
):
super().__init__("slides", "v1", client_id, client_secret, tool_filter)
if name == "youtube_toolset":
if _youtube_toolset is None:
_youtube_toolset = GoogleApiToolset.load_toolset(
api_name="youtube",
api_version="v3",
)
return _youtube_toolset class SheetsToolset(GoogleApiToolset):
if name == "slides_toolset": def __init__(
if _slides_toolset is None: self,
_slides_toolset = GoogleApiToolset.load_toolset( client_id: str = None,
api_name="slides", client_secret: str = None,
api_version="v1", tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
) ):
super().__init__("sheets", "v4", client_id, client_secret, tool_filter)
return _slides_toolset
if name == "sheets_toolset": class DocsToolset(GoogleApiToolset):
if _sheets_toolset is None:
_sheets_toolset = GoogleApiToolset.load_toolset(
api_name="sheets",
api_version="v4",
)
return _sheets_toolset def __init__(
self,
if name == "docs_toolset": client_id: str = None,
if _docs_toolset is None: client_secret: str = None,
_docs_toolset = GoogleApiToolset.load_toolset( tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
api_name="docs", ):
api_version="v1", super().__init__("docs", "v1", client_id, client_secret, tool_filter)
)
return _docs_toolset