mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-25 22:47:44 -06:00
Agent Development Kit(ADK)
An easy-to-use and powerful framework to build AI agents.
This commit is contained in:
19
src/google/adk/tools/apihub_tool/__init__.py
Normal file
19
src/google/adk/tools/apihub_tool/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .apihub_toolset import APIHubToolset
|
||||
|
||||
__all__ = [
|
||||
'APIHubToolset',
|
||||
]
|
||||
209
src/google/adk/tools/apihub_tool/apihub_toolset.py
Normal file
209
src/google/adk/tools/apihub_tool/apihub_toolset.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
from ..openapi_tool.common.common import to_snake_case
|
||||
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from .clients.apihub_client import APIHubClient
|
||||
|
||||
|
||||
class APIHubToolset:
|
||||
"""APIHubTool generates tools from a given API Hub resource.
|
||||
|
||||
Examples:
|
||||
|
||||
```
|
||||
apihub_toolset = APIHubToolset(
|
||||
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
|
||||
service_account_json="...",
|
||||
)
|
||||
|
||||
# Get all available tools
|
||||
agent = LlmAgent(tools=apihub_toolset.get_tools())
|
||||
|
||||
# Get a specific tool
|
||||
agent = LlmAgent(tools=[
|
||||
...
|
||||
apihub_toolset.get_tool('my_tool'),
|
||||
])
|
||||
```
|
||||
|
||||
**apihub_resource_name** is the resource name from API Hub. It must include
|
||||
API name, and can optionally include API version and spec name.
|
||||
- If apihub_resource_name includes a spec resource name, the content of that
|
||||
spec will be used for generating the tools.
|
||||
- If apihub_resource_name includes only an api or a version name, the
|
||||
first spec of the first version of that API will be used.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
# Parameters for fetching API Hub resource
|
||||
apihub_resource_name: str,
|
||||
access_token: Optional[str] = None,
|
||||
service_account_json: Optional[str] = None,
|
||||
# Parameters for the toolset itself
|
||||
name: str = '',
|
||||
description: str = '',
|
||||
# Parameters for generating tools
|
||||
lazy_load_spec=False,
|
||||
auth_scheme: Optional[AuthScheme] = None,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
# Optionally, you can provide a custom API Hub client
|
||||
apihub_client: Optional[APIHubClient] = None,
|
||||
):
|
||||
"""Initializes the APIHubTool with the given parameters.
|
||||
|
||||
Examples:
|
||||
```
|
||||
apihub_toolset = APIHubToolset(
|
||||
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
|
||||
service_account_json="...",
|
||||
)
|
||||
|
||||
# Get all available tools
|
||||
agent = LlmAgent(tools=apihub_toolset.get_tools())
|
||||
|
||||
# Get a specific tool
|
||||
agent = LlmAgent(tools=[
|
||||
...
|
||||
apihub_toolset.get_tool('my_tool'),
|
||||
])
|
||||
```
|
||||
|
||||
**apihub_resource_name** is the resource name from API Hub. It must include
|
||||
API name, and can optionally include API version and spec name.
|
||||
- If apihub_resource_name includes a spec resource name, the content of that
|
||||
spec will be used for generating the tools.
|
||||
- If apihub_resource_name includes only an api or a version name, the
|
||||
first spec of the first version of that API will be used.
|
||||
|
||||
Example:
|
||||
* projects/xxx/locations/us-central1/apis/apiname/...
|
||||
* https://console.cloud.google.com/apigee/api-hub/apis/apiname?project=xxx
|
||||
|
||||
Args:
|
||||
apihub_resource_name: The resource name of the API in API Hub.
|
||||
Example: `projects/test-project/locations/us-central1/apis/test-api`.
|
||||
access_token: Google Access token. Generate with gcloud cli `gcloud auth
|
||||
auth print-access-token`. Used for fetching API Specs from API Hub.
|
||||
service_account_json: The service account config as a json string.
|
||||
Required if not using default service credential. It is used for
|
||||
creating the API Hub client and fetching the API Specs from API Hub.
|
||||
apihub_client: Optional custom API Hub client.
|
||||
name: Name of the toolset. Optional.
|
||||
description: Description of the toolset. Optional.
|
||||
auth_scheme: Auth scheme that applies to all the tool in the toolset.
|
||||
auth_credential: Auth credential that applies to all the tool in the
|
||||
toolset.
|
||||
lazy_load_spec: If True, the spec will be loaded lazily when needed.
|
||||
Otherwise, the spec will be loaded immediately and the tools will be
|
||||
generated during initialization.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.apihub_resource_name = apihub_resource_name
|
||||
self.lazy_load_spec = lazy_load_spec
|
||||
self.apihub_client = apihub_client or APIHubClient(
|
||||
access_token=access_token,
|
||||
service_account_json=service_account_json,
|
||||
)
|
||||
|
||||
self.generated_tools: Dict[str, RestApiTool] = {}
|
||||
self.auth_scheme = auth_scheme
|
||||
self.auth_credential = auth_credential
|
||||
|
||||
if not self.lazy_load_spec:
|
||||
self._prepare_tools()
|
||||
|
||||
def get_tool(self, name: str) -> Optional[RestApiTool]:
|
||||
"""Retrieves a specific tool by its name.
|
||||
|
||||
Example:
|
||||
```
|
||||
apihub_tool = apihub_toolset.get_tool('my_tool')
|
||||
```
|
||||
|
||||
Args:
|
||||
name: The name of the tool to retrieve.
|
||||
|
||||
Returns:
|
||||
The tool with the given name, or None if no such tool exists.
|
||||
"""
|
||||
if not self._are_tools_ready():
|
||||
self._prepare_tools()
|
||||
|
||||
return self.generated_tools[name] if name in self.generated_tools else None
|
||||
|
||||
def get_tools(self) -> List[RestApiTool]:
|
||||
"""Retrieves all available tools.
|
||||
|
||||
Returns:
|
||||
A list of all available RestApiTool objects.
|
||||
"""
|
||||
if not self._are_tools_ready():
|
||||
self._prepare_tools()
|
||||
|
||||
return list(self.generated_tools.values())
|
||||
|
||||
def _are_tools_ready(self) -> bool:
|
||||
return not self.lazy_load_spec or self.generated_tools
|
||||
|
||||
def _prepare_tools(self) -> str:
|
||||
"""Fetches the spec from API Hub and generates the tools.
|
||||
|
||||
Returns:
|
||||
True if the tools are ready, False otherwise.
|
||||
"""
|
||||
# For each API, get the first version and the first spec of that version.
|
||||
spec = self.apihub_client.get_spec_content(self.apihub_resource_name)
|
||||
self.generated_tools: Dict[str, RestApiTool] = {}
|
||||
|
||||
tools = self._parse_spec_to_tools(spec)
|
||||
for tool in tools:
|
||||
self.generated_tools[tool.name] = tool
|
||||
|
||||
def _parse_spec_to_tools(self, spec_str: str) -> List[RestApiTool]:
|
||||
"""Parses the spec string to a list of RestApiTool.
|
||||
|
||||
Args:
|
||||
spec_str: The spec string to parse.
|
||||
|
||||
Returns:
|
||||
A list of RestApiTool objects.
|
||||
"""
|
||||
spec_dict = yaml.safe_load(spec_str)
|
||||
if not spec_dict:
|
||||
return []
|
||||
|
||||
self.name = self.name or to_snake_case(
|
||||
spec_dict.get('info', {}).get('title', 'unnamed')
|
||||
)
|
||||
self.description = self.description or spec_dict.get('info', {}).get(
|
||||
'description', ''
|
||||
)
|
||||
tools = OpenAPIToolset(
|
||||
spec_dict=spec_dict,
|
||||
auth_credential=self.auth_credential,
|
||||
auth_scheme=self.auth_scheme,
|
||||
).get_tools()
|
||||
return tools
|
||||
13
src/google/adk/tools/apihub_tool/clients/__init__.py
Normal file
13
src/google/adk/tools/apihub_tool/clients/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
332
src/google/adk/tools/apihub_tool/clients/apihub_client.py
Normal file
332
src/google/adk/tools/apihub_tool/clients/apihub_client.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import base64
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from google.auth import default as default_service_credential
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import requests
|
||||
|
||||
|
||||
class BaseAPIHubClient(ABC):
|
||||
"""Base class for API Hub clients."""
|
||||
|
||||
@abstractmethod
|
||||
def get_spec_content(self, resource_name: str) -> str:
|
||||
"""From a given resource name, get the soec in the API Hub."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class APIHubClient(BaseAPIHubClient):
|
||||
"""Client for interacting with the API Hub service."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
access_token: Optional[str] = None,
|
||||
service_account_json: Optional[str] = None,
|
||||
):
|
||||
"""Initializes the APIHubClient.
|
||||
|
||||
You must set either access_token or service_account_json. This
|
||||
credential is used for sending request to API Hub API.
|
||||
|
||||
Args:
|
||||
access_token: Google Access token. Generate with gcloud cli `gcloud auth
|
||||
print-access-token`. Useful for local testing.
|
||||
service_account_json: The service account configuration as a dictionary.
|
||||
Required if not using default service credential.
|
||||
"""
|
||||
self.root_url = "https://apihub.googleapis.com/v1"
|
||||
self.credential_cache = None
|
||||
self.access_token, self.service_account = None, None
|
||||
|
||||
if access_token:
|
||||
self.access_token = access_token
|
||||
elif service_account_json:
|
||||
self.service_account = service_account_json
|
||||
|
||||
def get_spec_content(self, path: str) -> str:
|
||||
"""From a given path, get the first spec available in the API Hub.
|
||||
|
||||
- If path includes /apis/apiname, get the first spec of that API
|
||||
- If path includes /apis/apiname/versions/versionname, get the first spec
|
||||
of that API Version
|
||||
- If path includes /apis/apiname/versions/versionname/specs/specname, return
|
||||
that spec
|
||||
|
||||
Path can be resource name (projects/xxx/locations/us-central1/apis/apiname),
|
||||
and URL from the UI
|
||||
(https://console.cloud.google.com/apigee/api-hub/apis/apiname?project=xxx)
|
||||
|
||||
Args:
|
||||
path: The path to the API, API Version, or API Spec.
|
||||
|
||||
Returns:
|
||||
The content of the first spec available in the API Hub.
|
||||
"""
|
||||
apihub_resource_name, api_version_resource_name, api_spec_resource_name = (
|
||||
self._extract_resource_name(path)
|
||||
)
|
||||
|
||||
if apihub_resource_name and not api_version_resource_name:
|
||||
api = self.get_api(apihub_resource_name)
|
||||
versions = api.get("versions", [])
|
||||
if not versions:
|
||||
raise ValueError(
|
||||
f"No versions found in API Hub resource: {apihub_resource_name}"
|
||||
)
|
||||
api_version_resource_name = versions[0]
|
||||
|
||||
if api_version_resource_name and not api_spec_resource_name:
|
||||
api_version = self.get_api_version(api_version_resource_name)
|
||||
spec_resource_names = api_version.get("specs", [])
|
||||
if not spec_resource_names:
|
||||
raise ValueError(
|
||||
f"No specs found in API Hub version: {api_version_resource_name}"
|
||||
)
|
||||
api_spec_resource_name = spec_resource_names[0]
|
||||
|
||||
if api_spec_resource_name:
|
||||
spec_content = self._fetch_spec(api_spec_resource_name)
|
||||
return spec_content
|
||||
|
||||
raise ValueError("No API Hub resource found in path: {path}")
|
||||
|
||||
def list_apis(self, project: str, location: str) -> List[Dict[str, Any]]:
|
||||
"""Lists all APIs in the specified project and location.
|
||||
|
||||
Args:
|
||||
project: The Google Cloud project name.
|
||||
location: The location of the API Hub resources (e.g., 'us-central1').
|
||||
|
||||
Returns:
|
||||
A list of API dictionaries, or an empty list if an error occurs.
|
||||
"""
|
||||
url = f"{self.root_url}/projects/{project}/locations/{location}/apis"
|
||||
headers = {
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": f"Bearer {self._get_access_token()}",
|
||||
}
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
apis = response.json().get("apis", [])
|
||||
return apis
|
||||
|
||||
def get_api(self, api_resource_name: str) -> Dict[str, Any]:
|
||||
"""Get API detail by API name.
|
||||
|
||||
Args:
|
||||
api_resource_name: Resource name of this API, like
|
||||
projects/xxx/locations/us-central1/apis/apiname
|
||||
|
||||
Returns:
|
||||
An API and details in a dict.
|
||||
"""
|
||||
url = f"{self.root_url}/{api_resource_name}"
|
||||
headers = {
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": f"Bearer {self._get_access_token()}",
|
||||
}
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
apis = response.json()
|
||||
return apis
|
||||
|
||||
def get_api_version(self, api_version_name: str) -> Dict[str, Any]:
|
||||
"""Gets details of a specific API version.
|
||||
|
||||
Args:
|
||||
api_version_name: The resource name of the API version.
|
||||
|
||||
Returns:
|
||||
The API version details as a dictionary, or an empty dictionary if an
|
||||
error occurs.
|
||||
"""
|
||||
url = f"{self.root_url}/{api_version_name}"
|
||||
headers = {
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": f"Bearer {self._get_access_token()}",
|
||||
}
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _fetch_spec(self, api_spec_resource_name: str) -> str:
|
||||
"""Retrieves the content of a specific API specification.
|
||||
|
||||
Args:
|
||||
api_spec_resource_name: The resource name of the API spec.
|
||||
|
||||
Returns:
|
||||
The decoded content of the specification as a string, or an empty string
|
||||
if an error occurs.
|
||||
"""
|
||||
url = f"{self.root_url}/{api_spec_resource_name}:contents"
|
||||
headers = {
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": f"Bearer {self._get_access_token()}",
|
||||
}
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
content_base64 = response.json().get("contents", "")
|
||||
if content_base64:
|
||||
content_decoded = base64.b64decode(content_base64).decode("utf-8")
|
||||
return content_decoded
|
||||
else:
|
||||
return ""
|
||||
|
||||
def _extract_resource_name(self, url_or_path: str) -> Tuple[str, str, str]:
|
||||
"""Extracts the resource names of an API, API Version, and API Spec from a given URL or path.
|
||||
|
||||
Args:
|
||||
url_or_path: The URL (UI or resource) or path string.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the resource names:
|
||||
{
|
||||
"api_resource_name": "projects/*/locations/*/apis/*",
|
||||
"api_version_resource_name":
|
||||
"projects/*/locations/*/apis/*/versions/*",
|
||||
"api_spec_resource_name":
|
||||
"projects/*/locations/*/apis/*/versions/*/specs/*"
|
||||
}
|
||||
or raises ValueError if extraction fails.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL or path is invalid or if required components
|
||||
(project, location, api) are missing.
|
||||
"""
|
||||
|
||||
query_params = None
|
||||
try:
|
||||
parsed_url = urlparse(url_or_path)
|
||||
path = parsed_url.path
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
|
||||
# This is a path from UI. Remove unnecessary prefix.
|
||||
if "api-hub/" in path:
|
||||
path = path.split("api-hub")[1]
|
||||
except Exception:
|
||||
path = url_or_path
|
||||
|
||||
path_segments = [segment for segment in path.split("/") if segment]
|
||||
|
||||
project = None
|
||||
location = None
|
||||
api_id = None
|
||||
version_id = None
|
||||
spec_id = None
|
||||
|
||||
if "projects" in path_segments:
|
||||
project_index = path_segments.index("projects")
|
||||
if project_index + 1 < len(path_segments):
|
||||
project = path_segments[project_index + 1]
|
||||
elif query_params and "project" in query_params:
|
||||
project = query_params["project"][0]
|
||||
|
||||
if not project:
|
||||
raise ValueError(
|
||||
"Project ID not found in URL or path in APIHubClient. Input path is"
|
||||
f" '{url_or_path}'. Please make sure there is either"
|
||||
" '/projects/PROJECT_ID' in the path or 'project=PROJECT_ID' query"
|
||||
" param in the input."
|
||||
)
|
||||
|
||||
if "locations" in path_segments:
|
||||
location_index = path_segments.index("locations")
|
||||
if location_index + 1 < len(path_segments):
|
||||
location = path_segments[location_index + 1]
|
||||
if not location:
|
||||
raise ValueError(
|
||||
"Location not found in URL or path in APIHubClient. Input path is"
|
||||
f" '{url_or_path}'. Please make sure there is either"
|
||||
" '/location/LOCATION_ID' in the path."
|
||||
)
|
||||
|
||||
if "apis" in path_segments:
|
||||
api_index = path_segments.index("apis")
|
||||
if api_index + 1 < len(path_segments):
|
||||
api_id = path_segments[api_index + 1]
|
||||
if not api_id:
|
||||
raise ValueError(
|
||||
"API id not found in URL or path in APIHubClient. Input path is"
|
||||
f" '{url_or_path}'. Please make sure there is either"
|
||||
" '/apis/API_ID' in the path."
|
||||
)
|
||||
if "versions" in path_segments:
|
||||
version_index = path_segments.index("versions")
|
||||
if version_index + 1 < len(path_segments):
|
||||
version_id = path_segments[version_index + 1]
|
||||
|
||||
if "specs" in path_segments:
|
||||
spec_index = path_segments.index("specs")
|
||||
if spec_index + 1 < len(path_segments):
|
||||
spec_id = path_segments[spec_index + 1]
|
||||
|
||||
api_resource_name = f"projects/{project}/locations/{location}/apis/{api_id}"
|
||||
api_version_resource_name = (
|
||||
f"{api_resource_name}/versions/{version_id}" if version_id else None
|
||||
)
|
||||
api_spec_resource_name = (
|
||||
f"{api_version_resource_name}/specs/{spec_id}"
|
||||
if version_id and spec_id
|
||||
else None
|
||||
)
|
||||
|
||||
return (
|
||||
api_resource_name,
|
||||
api_version_resource_name,
|
||||
api_spec_resource_name,
|
||||
)
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
"""Gets the access token for the service account.
|
||||
|
||||
Returns:
|
||||
The access token.
|
||||
"""
|
||||
if self.access_token:
|
||||
return self.access_token
|
||||
|
||||
if self.credential_cache and not self.credential_cache.expired:
|
||||
return self.credential_cache.token
|
||||
|
||||
if self.service_account:
|
||||
try:
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(self.service_account),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid service account JSON: {e}") from e
|
||||
else:
|
||||
try:
|
||||
credentials, _ = default_service_credential()
|
||||
except:
|
||||
credentials = None
|
||||
|
||||
if not credentials:
|
||||
raise ValueError(
|
||||
"Please provide a service account or an access token to API Hub"
|
||||
" client."
|
||||
)
|
||||
|
||||
credentials.refresh(Request())
|
||||
self.credential_cache = credentials
|
||||
return credentials.token
|
||||
115
src/google/adk/tools/apihub_tool/clients/secret_client.py
Normal file
115
src/google/adk/tools/apihub_tool/clients/secret_client.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
import google.auth
|
||||
from google.auth import default as default_service_credential
|
||||
import google.auth.transport.requests
|
||||
from google.cloud import secretmanager
|
||||
from google.oauth2 import service_account
|
||||
|
||||
|
||||
class SecretManagerClient:
|
||||
"""A client for interacting with Google Cloud Secret Manager.
|
||||
|
||||
This class provides a simplified interface for retrieving secrets from
|
||||
Secret Manager, handling authentication using either a service account
|
||||
JSON keyfile (passed as a string) or a pre-existing authorization token.
|
||||
|
||||
Attributes:
|
||||
_credentials: Google Cloud credentials object (ServiceAccountCredentials
|
||||
or Credentials).
|
||||
_client: Secret Manager client instance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service_account_json: Optional[str] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
):
|
||||
"""Initializes the SecretManagerClient.
|
||||
|
||||
Args:
|
||||
service_account_json: The content of a service account JSON keyfile (as
|
||||
a string), not the file path. Must be valid JSON.
|
||||
auth_token: An existing Google Cloud authorization token.
|
||||
|
||||
Raises:
|
||||
ValueError: If neither `service_account_json` nor `auth_token` is
|
||||
provided,
|
||||
or if both are provided. Also raised if the service_account_json
|
||||
is not valid JSON.
|
||||
google.auth.exceptions.GoogleAuthError: If authentication fails.
|
||||
"""
|
||||
if service_account_json:
|
||||
try:
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(service_account_json)
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid service account JSON: {e}") from e
|
||||
elif auth_token:
|
||||
credentials = google.auth.credentials.Credentials(
|
||||
token=auth_token,
|
||||
refresh_token=None,
|
||||
token_uri=None,
|
||||
client_id=None,
|
||||
client_secret=None,
|
||||
)
|
||||
request = google.auth.transport.requests.Request()
|
||||
credentials.refresh(request)
|
||||
else:
|
||||
try:
|
||||
credentials, _ = default_service_credential()
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"'service_account_json' or 'auth_token' are both missing, and"
|
||||
f" error occurred while trying to use default credentials: {e}"
|
||||
) from e
|
||||
|
||||
if not credentials:
|
||||
raise ValueError(
|
||||
"Must provide either 'service_account_json' or 'auth_token', not both"
|
||||
" or neither."
|
||||
)
|
||||
|
||||
self._credentials = credentials
|
||||
self._client = secretmanager.SecretManagerServiceClient(
|
||||
credentials=self._credentials
|
||||
)
|
||||
|
||||
def get_secret(self, resource_name: str) -> str:
|
||||
"""Retrieves a secret from Google Cloud Secret Manager.
|
||||
|
||||
Args:
|
||||
resource_name: The full resource name of the secret, in the format
|
||||
"projects/*/secrets/*/versions/*". Usually you want the "latest"
|
||||
version, e.g.,
|
||||
"projects/my-project/secrets/my-secret/versions/latest".
|
||||
|
||||
Returns:
|
||||
The secret payload as a string.
|
||||
|
||||
Raises:
|
||||
google.api_core.exceptions.GoogleAPIError: If the Secret Manager API
|
||||
returns an error (e.g., secret not found, permission denied).
|
||||
Exception: For other unexpected errors.
|
||||
"""
|
||||
try:
|
||||
response = self._client.access_secret_version(name=resource_name)
|
||||
return response.payload.data.decode("UTF-8")
|
||||
except Exception as e:
|
||||
raise e # Re-raise the exception to allow for handling by the caller
|
||||
# Consider logging the exception here before re-raising.
|
||||
Reference in New Issue
Block a user