refactor: refactor google api toolset to hide non-public field

PiperOrigin-RevId: 758469484
This commit is contained in:
Xiang (Sean) Zhou 2025-05-13 19:25:32 -07:00 committed by Copybara-Service
parent 30947b48b8
commit 00e0035c8a
4 changed files with 69 additions and 65 deletions

View File

@ -35,22 +35,22 @@ class GoogleApiTool(BaseTool):
description=rest_api_tool.description, description=rest_api_tool.description,
is_long_running=rest_api_tool.is_long_running, is_long_running=rest_api_tool.is_long_running,
) )
self.rest_api_tool = rest_api_tool self._rest_api_tool = rest_api_tool
@override @override
def _get_declaration(self) -> FunctionDeclaration: def _get_declaration(self) -> FunctionDeclaration:
return self.rest_api_tool._get_declaration() return self._rest_api_tool._get_declaration()
@override @override
async def run_async( async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext] self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return await self.rest_api_tool.run_async( return await self._rest_api_tool.run_async(
args=args, tool_context=tool_context args=args, tool_context=tool_context
) )
def configure_auth(self, client_id: str, client_secret: str): def configure_auth(self, client_id: str, client_secret: str):
self.rest_api_tool.auth_credential = AuthCredential( self._rest_api_tool.auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
oauth2=OAuth2Auth( oauth2=OAuth2Auth(
client_id=client_id, client_id=client_id,

View File

@ -48,10 +48,10 @@ class GoogleApiToolset(BaseToolset):
client_secret: Optional[str] = None, client_secret: Optional[str] = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
): ):
self.openapi_toolset = openapi_toolset self._openapi_toolset = openapi_toolset
self.tool_filter = tool_filter self.tool_filter = tool_filter
self.client_id = client_id self._client_id = client_id
self.client_secret = client_secret self._client_secret = client_secret
@override @override
async def get_tools( async def get_tools(
@ -60,7 +60,7 @@ class GoogleApiToolset(BaseToolset):
"""Get all tools in the toolset.""" """Get all tools in the toolset."""
tools = [] tools = []
for tool in await self.openapi_toolset.get_tools(readonly_context): for tool in await self._openapi_toolset.get_tools(readonly_context):
if self.tool_filter and ( if self.tool_filter and (
isinstance(self.tool_filter, ToolPredicate) isinstance(self.tool_filter, ToolPredicate)
and not self.tool_filter(tool, readonly_context) and not self.tool_filter(tool, readonly_context)
@ -69,7 +69,7 @@ class GoogleApiToolset(BaseToolset):
): ):
continue continue
google_api_tool = GoogleApiTool(tool) google_api_tool = GoogleApiTool(tool)
google_api_tool.configure_auth(self.client_id, self.client_secret) google_api_tool.configure_auth(self._client_id, self._client_secret)
tools.append(google_api_tool) tools.append(google_api_tool)
return tools return tools
@ -119,8 +119,8 @@ class GoogleApiToolset(BaseToolset):
return toolset return toolset
def configure_auth(self, client_id: str, client_secret: str): def configure_auth(self, client_id: str, client_secret: str):
self.client_id = client_id self._client_id = client_id
self.client_secret = client_secret self._client_secret = client_secret
@classmethod @classmethod
def load_toolset( def load_toolset(
@ -140,5 +140,5 @@ class GoogleApiToolset(BaseToolset):
@override @override
async def close(self): async def close(self):
if self.openapi_toolset: if self._openapi_toolset:
await self.openapi_toolset.close() await self._openapi_toolset.close()

View File

@ -37,11 +37,11 @@ class GoogleApiToOpenApiConverter:
api_name: The name of the Google API (e.g., "calendar") api_name: The name of the Google API (e.g., "calendar")
api_version: The version of the API (e.g., "v3") api_version: The version of the API (e.g., "v3")
""" """
self.api_name = api_name self._api_name = api_name
self.api_version = api_version self._api_version = api_version
self.google_api_resource = None self._google_api_resource = None
self.google_api_spec = None self._google_api_spec = None
self.openapi_spec = { self._openapi_spec = {
"openapi": "3.0.0", "openapi": "3.0.0",
"info": {}, "info": {},
"servers": [], "servers": [],
@ -53,18 +53,20 @@ class GoogleApiToOpenApiConverter:
"""Fetches the Google API specification using discovery service.""" """Fetches the Google API specification using discovery service."""
try: try:
logger.info( logger.info(
"Fetching Google API spec for %s %s", self.api_name, self.api_version "Fetching Google API spec for %s %s",
self._api_name,
self._api_version,
) )
# Build a resource object for the specified API # Build a resource object for the specified API
self.google_api_resource = build(self.api_name, self.api_version) self._google_api_resource = build(self._api_name, self._api_version)
# Access the underlying API discovery document # Access the underlying API discovery document
self.google_api_spec = self.google_api_resource._rootDesc self._google_api_spec = self._google_api_resource._rootDesc
if not self.google_api_spec: if not self._google_api_spec:
raise ValueError("Failed to retrieve API specification") raise ValueError("Failed to retrieve API specification")
logger.info("Successfully fetched %s API specification", self.api_name) logger.info("Successfully fetched %s API specification", self._api_name)
except HttpError as e: except HttpError as e:
logger.error("HTTP Error: %s", e) logger.error("HTTP Error: %s", e)
raise raise
@ -78,7 +80,7 @@ class GoogleApiToOpenApiConverter:
Returns: Returns:
Dict containing the converted OpenAPI v3 specification Dict containing the converted OpenAPI v3 specification
""" """
if not self.google_api_spec: if not self._google_api_spec:
self.fetch_google_api_spec() self.fetch_google_api_spec()
# Convert basic API information # Convert basic API information
@ -94,49 +96,49 @@ class GoogleApiToOpenApiConverter:
self._convert_schemas() self._convert_schemas()
# Convert endpoints/paths # Convert endpoints/paths
self._convert_resources(self.google_api_spec.get("resources", {})) self._convert_resources(self._google_api_spec.get("resources", {}))
# Convert top-level methods, if any # Convert top-level methods, if any
self._convert_methods(self.google_api_spec.get("methods", {}), "/") self._convert_methods(self._google_api_spec.get("methods", {}), "/")
return self.openapi_spec return self._openapi_spec
def _convert_info(self) -> None: def _convert_info(self) -> None:
"""Convert basic API information.""" """Convert basic API information."""
self.openapi_spec["info"] = { self._openapi_spec["info"] = {
"title": self.google_api_spec.get("title", f"{self.api_name} API"), "title": self._google_api_spec.get("title", f"{self._api_name} API"),
"description": self.google_api_spec.get("description", ""), "description": self._google_api_spec.get("description", ""),
"version": self.google_api_spec.get("version", self.api_version), "version": self._google_api_spec.get("version", self._api_version),
"contact": {}, "contact": {},
"termsOfService": self.google_api_spec.get("documentationLink", ""), "termsOfService": self._google_api_spec.get("documentationLink", ""),
} }
# Add documentation links if available # Add documentation links if available
docs_link = self.google_api_spec.get("documentationLink") docs_link = self._google_api_spec.get("documentationLink")
if docs_link: if docs_link:
self.openapi_spec["externalDocs"] = { self._openapi_spec["externalDocs"] = {
"description": "API Documentation", "description": "API Documentation",
"url": docs_link, "url": docs_link,
} }
def _convert_servers(self) -> None: def _convert_servers(self) -> None:
"""Convert server information.""" """Convert server information."""
base_url = self.google_api_spec.get( base_url = self._google_api_spec.get(
"rootUrl", "" "rootUrl", ""
) + self.google_api_spec.get("servicePath", "") ) + self._google_api_spec.get("servicePath", "")
# Remove trailing slash if present # Remove trailing slash if present
if base_url.endswith("/"): if base_url.endswith("/"):
base_url = base_url[:-1] base_url = base_url[:-1]
self.openapi_spec["servers"] = [{ self._openapi_spec["servers"] = [{
"url": base_url, "url": base_url,
"description": f"{self.api_name} {self.api_version} API", "description": f"{self._api_name} {self._api_version} API",
}] }]
def _convert_security_schemes(self) -> None: def _convert_security_schemes(self) -> None:
"""Convert authentication and authorization schemes.""" """Convert authentication and authorization schemes."""
auth = self.google_api_spec.get("auth", {}) auth = self._google_api_spec.get("auth", {})
oauth2 = auth.get("oauth2", {}) oauth2 = auth.get("oauth2", {})
if oauth2: if oauth2:
@ -147,7 +149,7 @@ class GoogleApiToOpenApiConverter:
for scope, scope_info in scopes.items(): for scope, scope_info in scopes.items():
formatted_scopes[scope] = scope_info.get("description", "") formatted_scopes[scope] = scope_info.get("description", "")
self.openapi_spec["components"]["securitySchemes"]["oauth2"] = { self._openapi_spec["components"]["securitySchemes"]["oauth2"] = {
"type": "oauth2", "type": "oauth2",
"description": "OAuth 2.0 authentication", "description": "OAuth 2.0 authentication",
"flows": { "flows": {
@ -162,7 +164,7 @@ class GoogleApiToOpenApiConverter:
} }
# Add API key authentication (most Google APIs support this) # Add API key authentication (most Google APIs support this)
self.openapi_spec["components"]["securitySchemes"]["apiKey"] = { self._openapi_spec["components"]["securitySchemes"]["apiKey"] = {
"type": "apiKey", "type": "apiKey",
"in": "query", "in": "query",
"name": "key", "name": "key",
@ -170,18 +172,20 @@ class GoogleApiToOpenApiConverter:
} }
# Create global security requirement # Create global security requirement
self.openapi_spec["security"] = [ self._openapi_spec["security"] = [
{"oauth2": list(formatted_scopes.keys())} if oauth2 else {}, {"oauth2": list(formatted_scopes.keys())} if oauth2 else {},
{"apiKey": []}, {"apiKey": []},
] ]
def _convert_schemas(self) -> None: def _convert_schemas(self) -> None:
"""Convert schema definitions (models).""" """Convert schema definitions (models)."""
schemas = self.google_api_spec.get("schemas", {}) schemas = self._google_api_spec.get("schemas", {})
for schema_name, schema_def in schemas.items(): for schema_name, schema_def in schemas.items():
converted_schema = self._convert_schema_object(schema_def) converted_schema = self._convert_schema_object(schema_def)
self.openapi_spec["components"]["schemas"][schema_name] = converted_schema self._openapi_spec["components"]["schemas"][
schema_name
] = converted_schema
def _convert_schema_object( def _convert_schema_object(
self, schema_def: Dict[str, Any] self, schema_def: Dict[str, Any]
@ -314,11 +318,11 @@ class GoogleApiToOpenApiConverter:
path_params = self._extract_path_parameters(rest_path) path_params = self._extract_path_parameters(rest_path)
# Create path entry if it doesn't exist # Create path entry if it doesn't exist
if rest_path not in self.openapi_spec["paths"]: if rest_path not in self._openapi_spec["paths"]:
self.openapi_spec["paths"][rest_path] = {} self._openapi_spec["paths"][rest_path] = {}
# Add the operation for this method # Add the operation for this method
self.openapi_spec["paths"][rest_path][http_method] = ( self._openapi_spec["paths"][rest_path][http_method] = (
self._convert_operation(method_data, path_params) self._convert_operation(method_data, path_params)
) )
@ -472,7 +476,7 @@ class GoogleApiToOpenApiConverter:
output_path: Path where the OpenAPI spec should be saved output_path: Path where the OpenAPI spec should be saved
""" """
with open(output_path, "w", encoding="utf-8") as f: with open(output_path, "w", encoding="utf-8") as f:
json.dump(self.openapi_spec, f, indent=2) json.dump(self._openapi_spec, f, indent=2)
logger.info("OpenAPI specification saved to %s", output_path) logger.info("OpenAPI specification saved to %s", output_path)

View File

@ -214,7 +214,7 @@ def mock_api_resource(calendar_api_spec):
@pytest.fixture @pytest.fixture
def prepared_converter(converter, calendar_api_spec): def prepared_converter(converter, calendar_api_spec):
"""Fixture that provides a converter with the API spec already set.""" """Fixture that provides a converter with the API spec already set."""
converter.google_api_spec = calendar_api_spec converter._google_api_spec = calendar_api_spec
return converter return converter
@ -242,14 +242,14 @@ class TestGoogleApiToOpenApiConverter:
def test_init(self, converter): def test_init(self, converter):
"""Test converter initialization.""" """Test converter initialization."""
assert converter.api_name == "calendar" assert converter._api_name == "calendar"
assert converter.api_version == "v3" assert converter._api_version == "v3"
assert converter.google_api_resource is None assert converter._google_api_resource is None
assert converter.google_api_spec is None assert converter._google_api_spec is None
assert converter.openapi_spec["openapi"] == "3.0.0" assert converter._openapi_spec["openapi"] == "3.0.0"
assert "info" in converter.openapi_spec assert "info" in converter._openapi_spec
assert "paths" in converter.openapi_spec assert "paths" in converter._openapi_spec
assert "components" in converter.openapi_spec assert "components" in converter._openapi_spec
def test_fetch_google_api_spec( def test_fetch_google_api_spec(
self, converter_with_patched_build, calendar_api_spec self, converter_with_patched_build, calendar_api_spec
@ -259,7 +259,7 @@ class TestGoogleApiToOpenApiConverter:
converter_with_patched_build.fetch_google_api_spec() converter_with_patched_build.fetch_google_api_spec()
# Verify the results # Verify the results
assert converter_with_patched_build.google_api_spec == calendar_api_spec assert converter_with_patched_build._google_api_spec == calendar_api_spec
def test_fetch_google_api_spec_error(self, monkeypatch, converter): def test_fetch_google_api_spec_error(self, monkeypatch, converter):
"""Test error handling when fetching Google API specification.""" """Test error handling when fetching Google API specification."""
@ -282,14 +282,14 @@ class TestGoogleApiToOpenApiConverter:
prepared_converter._convert_info() prepared_converter._convert_info()
# Verify the results # Verify the results
info = prepared_converter.openapi_spec["info"] info = prepared_converter._openapi_spec["info"]
assert info["title"] == "Google Calendar API" assert info["title"] == "Google Calendar API"
assert info["description"] == "Accesses the Google Calendar API" assert info["description"] == "Accesses the Google Calendar API"
assert info["version"] == "v3" assert info["version"] == "v3"
assert info["termsOfService"] == "https://developers.google.com/calendar/" assert info["termsOfService"] == "https://developers.google.com/calendar/"
# Check external docs # Check external docs
external_docs = prepared_converter.openapi_spec["externalDocs"] external_docs = prepared_converter._openapi_spec["externalDocs"]
assert external_docs["url"] == "https://developers.google.com/calendar/" assert external_docs["url"] == "https://developers.google.com/calendar/"
def test_convert_servers(self, prepared_converter): def test_convert_servers(self, prepared_converter):
@ -298,7 +298,7 @@ class TestGoogleApiToOpenApiConverter:
prepared_converter._convert_servers() prepared_converter._convert_servers()
# Verify the results # Verify the results
servers = prepared_converter.openapi_spec["servers"] servers = prepared_converter._openapi_spec["servers"]
assert len(servers) == 1 assert len(servers) == 1
assert servers[0]["url"] == "https://www.googleapis.com/calendar/v3" assert servers[0]["url"] == "https://www.googleapis.com/calendar/v3"
assert servers[0]["description"] == "calendar v3 API" assert servers[0]["description"] == "calendar v3 API"
@ -309,7 +309,7 @@ class TestGoogleApiToOpenApiConverter:
prepared_converter._convert_security_schemes() prepared_converter._convert_security_schemes()
# Verify the results # Verify the results
security_schemes = prepared_converter.openapi_spec["components"][ security_schemes = prepared_converter._openapi_spec["components"][
"securitySchemes" "securitySchemes"
] ]
@ -335,7 +335,7 @@ class TestGoogleApiToOpenApiConverter:
prepared_converter._convert_schemas() prepared_converter._convert_schemas()
# Verify the results # Verify the results
schemas = prepared_converter.openapi_spec["components"]["schemas"] schemas = prepared_converter._openapi_spec["components"]["schemas"]
# Check Calendar schema # Check Calendar schema
assert "Calendar" in schemas assert "Calendar" in schemas
@ -524,7 +524,7 @@ class TestGoogleApiToOpenApiConverter:
prepared_converter._convert_methods(methods, "/calendars") prepared_converter._convert_methods(methods, "/calendars")
# Verify the results # Verify the results
paths = prepared_converter.openapi_spec["paths"] paths = prepared_converter._openapi_spec["paths"]
# Check GET method # Check GET method
assert "/calendars/{calendarId}" in paths assert "/calendars/{calendarId}" in paths
@ -565,7 +565,7 @@ class TestGoogleApiToOpenApiConverter:
prepared_converter._convert_resources(resources) prepared_converter._convert_resources(resources)
# Verify the results # Verify the results
paths = prepared_converter.openapi_spec["paths"] paths = prepared_converter._openapi_spec["paths"]
# Check top-level resource methods # Check top-level resource methods
assert "/calendars/{calendarId}" in paths assert "/calendars/{calendarId}" in paths