mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-22 13:22:19 -06:00
Agent Development Kit(ADK)
An easy-to-use and powerful framework to build AI agents.
This commit is contained in:
21
src/google/adk/tools/openapi_tool/__init__.py
Normal file
21
src/google/adk/tools/openapi_tool/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .openapi_spec_parser import OpenAPIToolset
|
||||
from .openapi_spec_parser import RestApiTool
|
||||
|
||||
__all__ = [
|
||||
'OpenAPIToolset',
|
||||
'RestApiTool',
|
||||
]
|
||||
19
src/google/adk/tools/openapi_tool/auth/__init__.py
Normal file
19
src/google/adk/tools/openapi_tool/auth/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import auth_helpers
|
||||
|
||||
__all__ = [
|
||||
'auth_helpers',
|
||||
]
|
||||
498
src/google/adk/tools/openapi_tool/auth/auth_helpers.py
Normal file
498
src/google/adk/tools/openapi_tool/auth/auth_helpers.py
Normal file
@@ -0,0 +1,498 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from fastapi.openapi.models import HTTPBase
|
||||
from fastapi.openapi.models import HTTPBearer
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OpenIdConnect
|
||||
from fastapi.openapi.models import Schema
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
import requests
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_credential import AuthCredentialTypes
|
||||
from ....auth.auth_credential import HttpAuth
|
||||
from ....auth.auth_credential import HttpCredentials
|
||||
from ....auth.auth_credential import OAuth2Auth
|
||||
from ....auth.auth_credential import ServiceAccount
|
||||
from ....auth.auth_credential import ServiceAccountCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ....auth.auth_schemes import AuthSchemeType
|
||||
from ....auth.auth_schemes import OpenIdConnectWithConfig
|
||||
from ..common.common import ApiParameter
|
||||
|
||||
|
||||
class OpenIdConfig(BaseModel):
|
||||
"""Represents OpenID Connect configuration.
|
||||
|
||||
Attributes:
|
||||
client_id: The client ID.
|
||||
auth_uri: The authorization URI.
|
||||
token_uri: The token URI.
|
||||
client_secret: The client secret.
|
||||
|
||||
Example:
|
||||
config = OpenIdConfig(
|
||||
client_id="your_client_id",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_secret="your_client_secret",
|
||||
redirect
|
||||
)
|
||||
"""
|
||||
|
||||
client_id: str
|
||||
auth_uri: str
|
||||
token_uri: str
|
||||
client_secret: str
|
||||
redirect_uri: Optional[str]
|
||||
|
||||
|
||||
def token_to_scheme_credential(
|
||||
token_type: Literal["apikey", "oauth2Token"],
|
||||
location: Optional[Literal["header", "query", "cookie"]] = None,
|
||||
name: Optional[str] = None,
|
||||
credential_value: Optional[str] = None,
|
||||
) -> Tuple[AuthScheme, AuthCredential]:
|
||||
"""Creates a AuthScheme and AuthCredential for API key or bearer token.
|
||||
|
||||
Examples:
|
||||
```
|
||||
# API Key in header
|
||||
auth_scheme, auth_credential = token_to_scheme_credential("apikey", "header",
|
||||
"X-API-Key", "your_api_key_value")
|
||||
|
||||
# API Key in query parameter
|
||||
auth_scheme, auth_credential = token_to_scheme_credential("apikey", "query",
|
||||
"api_key", "your_api_key_value")
|
||||
|
||||
# OAuth2 Bearer Token in Authorization header
|
||||
auth_scheme, auth_credential = token_to_scheme_credential("oauth2Token",
|
||||
"header", "Authorization", "your_bearer_token_value")
|
||||
```
|
||||
|
||||
Args:
|
||||
type: 'apikey' or 'oauth2Token'.
|
||||
location: 'header', 'query', or 'cookie' (only 'header' for oauth2Token).
|
||||
name: The name of the header, query parameter, or cookie.
|
||||
credential_value: The value of the API Key/ Token.
|
||||
|
||||
Returns:
|
||||
Tuple: (AuthScheme, AuthCredential)
|
||||
|
||||
Raises:
|
||||
ValueError: For invalid type or location.
|
||||
"""
|
||||
if token_type == "apikey":
|
||||
in_: APIKeyIn
|
||||
if location == "header":
|
||||
in_ = APIKeyIn.header
|
||||
elif location == "query":
|
||||
in_ = APIKeyIn.query
|
||||
elif location == "cookie":
|
||||
in_ = APIKeyIn.cookie
|
||||
else:
|
||||
raise ValueError(f"Invalid location for apiKey: {location}")
|
||||
auth_scheme = APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": in_,
|
||||
"name": name,
|
||||
})
|
||||
if credential_value:
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key=credential_value
|
||||
)
|
||||
else:
|
||||
auth_credential = None
|
||||
|
||||
return auth_scheme, auth_credential
|
||||
|
||||
elif token_type == "oauth2Token":
|
||||
# ignore location. OAuth2 Bearer Token is always in Authorization header.
|
||||
auth_scheme = HTTPBearer(
|
||||
bearerFormat="JWT"
|
||||
) # Common format, can be omitted.
|
||||
if credential_value:
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer",
|
||||
credentials=HttpCredentials(token=credential_value),
|
||||
),
|
||||
)
|
||||
else:
|
||||
auth_credential = None
|
||||
|
||||
return auth_scheme, auth_credential
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid security scheme type: {type}")
|
||||
|
||||
|
||||
def service_account_dict_to_scheme_credential(
|
||||
config: Dict[str, Any],
|
||||
scopes: List[str],
|
||||
) -> Tuple[AuthScheme, AuthCredential]:
|
||||
"""Creates AuthScheme and AuthCredential for Google Service Account.
|
||||
|
||||
Returns a bearer token scheme, and a service account credential.
|
||||
|
||||
Args:
|
||||
config: A ServiceAccount object containing the Google Service Account
|
||||
configuration.
|
||||
scopes: A list of scopes to be used.
|
||||
|
||||
Returns:
|
||||
Tuple: (AuthScheme, AuthCredential)
|
||||
"""
|
||||
auth_scheme = HTTPBearer(bearerFormat="JWT")
|
||||
service_account = ServiceAccount(
|
||||
service_account_credential=ServiceAccountCredential.model_construct(
|
||||
**config
|
||||
),
|
||||
scopes=scopes,
|
||||
)
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=service_account,
|
||||
)
|
||||
return auth_scheme, auth_credential
|
||||
|
||||
|
||||
def service_account_scheme_credential(
|
||||
config: ServiceAccount,
|
||||
) -> Tuple[AuthScheme, AuthCredential]:
|
||||
"""Creates AuthScheme and AuthCredential for Google Service Account.
|
||||
|
||||
Returns a bearer token scheme, and a service account credential.
|
||||
|
||||
Args:
|
||||
config: A ServiceAccount object containing the Google Service Account
|
||||
configuration.
|
||||
|
||||
Returns:
|
||||
Tuple: (AuthScheme, AuthCredential)
|
||||
"""
|
||||
auth_scheme = HTTPBearer(bearerFormat="JWT")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=config
|
||||
)
|
||||
return auth_scheme, auth_credential
|
||||
|
||||
|
||||
def openid_dict_to_scheme_credential(
|
||||
config_dict: Dict[str, Any],
|
||||
scopes: List[str],
|
||||
credential_dict: Dict[str, Any],
|
||||
) -> Tuple[OpenIdConnectWithConfig, AuthCredential]:
|
||||
"""Constructs OpenID scheme and credential from configuration and credential dictionaries.
|
||||
|
||||
Args:
|
||||
config_dict: Dictionary containing OpenID Connect configuration, must
|
||||
include at least 'authorization_endpoint' and 'token_endpoint'.
|
||||
scopes: List of scopes to be used.
|
||||
credential_dict: Dictionary containing credential information, must
|
||||
include 'client_id', 'client_secret', and 'scopes'. May optionally
|
||||
include 'redirect_uri'.
|
||||
|
||||
Returns:
|
||||
Tuple: (OpenIdConnectWithConfig, AuthCredential)
|
||||
|
||||
Raises:
|
||||
ValueError: If required fields are missing in the input dictionaries.
|
||||
"""
|
||||
|
||||
# Validate and create the OpenIdConnectWithConfig scheme
|
||||
try:
|
||||
config_dict["scopes"] = scopes
|
||||
# If user provides the OpenID Config as a static dict, it may not contain
|
||||
# openIdConnect URL.
|
||||
if "openIdConnectUrl" not in config_dict:
|
||||
config_dict["openIdConnectUrl"] = ""
|
||||
openid_scheme = OpenIdConnectWithConfig.model_validate(config_dict)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid OpenID Connect configuration: {e}") from e
|
||||
|
||||
# Attempt to adjust credential_dict if this is a key downloaded from Google
|
||||
# OAuth config
|
||||
if len(list(credential_dict.values())) == 1:
|
||||
credential_value = list(credential_dict.values())[0]
|
||||
if "client_id" in credential_value and "client_secret" in credential_value:
|
||||
credential_dict = credential_value
|
||||
|
||||
# Validate credential_dict
|
||||
required_credential_fields = ["client_id", "client_secret"]
|
||||
missing_fields = [
|
||||
field
|
||||
for field in required_credential_fields
|
||||
if field not in credential_dict
|
||||
]
|
||||
if missing_fields:
|
||||
raise ValueError(
|
||||
"Missing required fields in credential_dict:"
|
||||
f" {', '.join(missing_fields)}"
|
||||
)
|
||||
|
||||
# Construct AuthCredential
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id=credential_dict["client_id"],
|
||||
client_secret=credential_dict["client_secret"],
|
||||
redirect_uri=credential_dict.get("redirect_uri", None),
|
||||
),
|
||||
)
|
||||
|
||||
return openid_scheme, auth_credential
|
||||
|
||||
|
||||
def openid_url_to_scheme_credential(
|
||||
openid_url: str, scopes: List[str], credential_dict: Dict[str, Any]
|
||||
) -> Tuple[OpenIdConnectWithConfig, AuthCredential]:
|
||||
"""Constructs OpenID scheme and credential from OpenID URL, scopes, and credential dictionary.
|
||||
|
||||
Fetches OpenID configuration from the provided URL.
|
||||
|
||||
Args:
|
||||
openid_url: The OpenID Connect discovery URL.
|
||||
scopes: List of scopes to be used.
|
||||
credential_dict: Dictionary containing credential information, must
|
||||
include at least "client_id" and "client_secret", may optionally include
|
||||
"redirect_uri" and "scope"
|
||||
|
||||
Returns:
|
||||
Tuple: (AuthScheme, AuthCredential)
|
||||
|
||||
Raises:
|
||||
ValueError: If the OpenID URL is invalid, fetching fails, or required
|
||||
fields are missing.
|
||||
requests.exceptions.RequestException: If there's an error during the
|
||||
HTTP request.
|
||||
"""
|
||||
try:
|
||||
response = requests.get(openid_url, timeout=10)
|
||||
response.raise_for_status()
|
||||
config_dict = response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(
|
||||
f"Failed to fetch OpenID configuration from {openid_url}: {e}"
|
||||
) from e
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
"Invalid JSON response from OpenID configuration endpoint"
|
||||
f" {openid_url}: {e}"
|
||||
) from e
|
||||
|
||||
# Add openIdConnectUrl to config dict
|
||||
config_dict["openIdConnectUrl"] = openid_url
|
||||
|
||||
return openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
|
||||
|
||||
|
||||
INTERNAL_AUTH_PREFIX = "_auth_prefix_vaf_"
|
||||
|
||||
|
||||
def credential_to_param(
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: AuthCredential,
|
||||
) -> Tuple[Optional[ApiParameter], Optional[Dict[str, Any]]]:
|
||||
"""Converts AuthCredential and AuthScheme to a Parameter and a dictionary for additional kwargs.
|
||||
|
||||
This function now supports all credential types returned by the exchangers:
|
||||
- API Key
|
||||
- HTTP Bearer (for Bearer tokens, OAuth2, Service Account, OpenID Connect)
|
||||
- OAuth2 and OpenID Connect (returns None, None, as the token is now a Bearer
|
||||
token)
|
||||
- Service Account (returns None, None, as the token is now a Bearer token)
|
||||
|
||||
Args:
|
||||
auth_scheme: The AuthScheme object.
|
||||
auth_credential: The AuthCredential object.
|
||||
|
||||
Returns:
|
||||
Tuple: (ApiParameter, Dict[str, Any])
|
||||
"""
|
||||
if not auth_credential:
|
||||
return None, None
|
||||
|
||||
if (
|
||||
auth_scheme.type_ == AuthSchemeType.apiKey
|
||||
and auth_credential
|
||||
and auth_credential.api_key
|
||||
):
|
||||
param_name = auth_scheme.name or ""
|
||||
python_name = INTERNAL_AUTH_PREFIX + param_name
|
||||
if auth_scheme.in_ == APIKeyIn.header:
|
||||
param_location = "header"
|
||||
elif auth_scheme.in_ == APIKeyIn.query:
|
||||
param_location = "query"
|
||||
elif auth_scheme.in_ == APIKeyIn.cookie:
|
||||
param_location = "cookie"
|
||||
else:
|
||||
raise ValueError(f"Invalid API Key location: {auth_scheme.in_}")
|
||||
|
||||
param = ApiParameter(
|
||||
original_name=param_name,
|
||||
param_location=param_location,
|
||||
param_schema=Schema(type="string"),
|
||||
description=auth_scheme.description or "",
|
||||
py_name=python_name,
|
||||
)
|
||||
kwargs = {param.py_name: auth_credential.api_key}
|
||||
return param, kwargs
|
||||
|
||||
# TODO(cheliu): Split handling for OpenIDConnect scheme and native HTTPBearer
|
||||
# Scheme
|
||||
elif (
|
||||
auth_credential and auth_credential.auth_type == AuthCredentialTypes.HTTP
|
||||
):
|
||||
if (
|
||||
auth_credential
|
||||
and auth_credential.http
|
||||
and auth_credential.http.credentials
|
||||
and auth_credential.http.credentials.token
|
||||
):
|
||||
param = ApiParameter(
|
||||
original_name="Authorization",
|
||||
param_location="header",
|
||||
param_schema=Schema(type="string"),
|
||||
description=auth_scheme.description or "Bearer token",
|
||||
py_name=INTERNAL_AUTH_PREFIX + "Authorization",
|
||||
)
|
||||
kwargs = {
|
||||
param.py_name: f"Bearer {auth_credential.http.credentials.token}"
|
||||
}
|
||||
return param, kwargs
|
||||
elif (
|
||||
auth_credential
|
||||
and auth_credential.http
|
||||
and auth_credential.http.credentials
|
||||
and (
|
||||
auth_credential.http.credentials.username
|
||||
or auth_credential.http.credentials.password
|
||||
)
|
||||
):
|
||||
# Basic Auth is explicitly NOT supported
|
||||
raise NotImplementedError("Basic Authentication is not supported.")
|
||||
else:
|
||||
raise ValueError("Invalid HTTP auth credentials")
|
||||
|
||||
# Service Account tokens, OAuth2 Tokens and OpenID Tokens are now handled as
|
||||
# Bearer tokens.
|
||||
elif (auth_scheme.type_ == AuthSchemeType.oauth2 and auth_credential) or (
|
||||
auth_scheme.type_ == AuthSchemeType.openIdConnect and auth_credential
|
||||
):
|
||||
if (
|
||||
auth_credential.http
|
||||
and auth_credential.http.credentials
|
||||
and auth_credential.http.credentials.token
|
||||
):
|
||||
param = ApiParameter(
|
||||
original_name="Authorization",
|
||||
param_location="header",
|
||||
param_schema=Schema(type="string"),
|
||||
description=auth_scheme.description or "Bearer token",
|
||||
py_name=INTERNAL_AUTH_PREFIX + "Authorization",
|
||||
)
|
||||
kwargs = {
|
||||
param.py_name: f"Bearer {auth_credential.http.credentials.token}"
|
||||
}
|
||||
return param, kwargs
|
||||
return None, None
|
||||
else:
|
||||
raise ValueError("Invalid security scheme and credential combination")
|
||||
|
||||
|
||||
def dict_to_auth_scheme(data: Dict[str, Any]) -> AuthScheme:
|
||||
"""Converts a dictionary to a FastAPI AuthScheme object.
|
||||
|
||||
Args:
|
||||
data: The dictionary representing the security scheme.
|
||||
|
||||
Returns:
|
||||
A AuthScheme object (APIKey, HTTPBase, OAuth2, OpenIdConnect, or
|
||||
HTTPBearer).
|
||||
|
||||
Raises:
|
||||
ValueError: If the 'type' field is missing or invalid, or if the
|
||||
dictionary cannot be converted to the corresponding Pydantic model.
|
||||
|
||||
Example:
|
||||
```python
|
||||
api_key_data = {
|
||||
"type": "apiKey",
|
||||
"in": "header",
|
||||
"name": "X-API-Key",
|
||||
}
|
||||
api_key_scheme = dict_to_auth_scheme(api_key_data)
|
||||
|
||||
bearer_data = {
|
||||
"type": "http",
|
||||
"scheme": "bearer",
|
||||
"bearerFormat": "JWT",
|
||||
}
|
||||
bearer_scheme = dict_to_auth_scheme(bearer_data)
|
||||
|
||||
|
||||
oauth2_data = {
|
||||
"type": "oauth2",
|
||||
"flows": {
|
||||
"authorizationCode": {
|
||||
"authorizationUrl": "https://example.com/auth",
|
||||
"tokenUrl": "https://example.com/token",
|
||||
}
|
||||
}
|
||||
}
|
||||
oauth2_scheme = dict_to_auth_scheme(oauth2_data)
|
||||
|
||||
openid_data = {
|
||||
"type": "openIdConnect",
|
||||
"openIdConnectUrl": "https://example.com/.well-known/openid-configuration"
|
||||
}
|
||||
openid_scheme = dict_to_auth_scheme(openid_data)
|
||||
|
||||
|
||||
```
|
||||
"""
|
||||
if "type" not in data:
|
||||
raise ValueError("Missing 'type' field in security scheme dictionary.")
|
||||
|
||||
security_type = data["type"]
|
||||
try:
|
||||
if security_type == "apiKey":
|
||||
return APIKey.model_validate(data)
|
||||
elif security_type == "http":
|
||||
if data.get("scheme") == "bearer":
|
||||
return HTTPBearer.model_validate(data)
|
||||
else:
|
||||
return HTTPBase.model_validate(data) # Generic HTTP
|
||||
elif security_type == "oauth2":
|
||||
return OAuth2.model_validate(data)
|
||||
elif security_type == "openIdConnect":
|
||||
return OpenIdConnect.model_validate(data)
|
||||
else:
|
||||
raise ValueError(f"Invalid security scheme type: {security_type}")
|
||||
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid security scheme data: {e}") from e
|
||||
@@ -0,0 +1,25 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from .base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
from .oauth2_exchanger import OAuth2CredentialExchanger
|
||||
from .service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
|
||||
__all__ = [
|
||||
'AutoAuthCredentialExchanger',
|
||||
'BaseAuthCredentialExchanger',
|
||||
'OAuth2CredentialExchanger',
|
||||
'ServiceAccountCredentialExchanger',
|
||||
]
|
||||
@@ -0,0 +1,105 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
|
||||
from .....auth.auth_credential import AuthCredential
|
||||
from .....auth.auth_credential import AuthCredentialTypes
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
from .base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
from .oauth2_exchanger import OAuth2CredentialExchanger
|
||||
from .service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
|
||||
|
||||
class AutoAuthCredentialExchanger(BaseAuthCredentialExchanger):
|
||||
"""Automatically selects the appropriate credential exchanger based on the auth scheme.
|
||||
|
||||
Optionally, an override can be provided to use a specific exchanger for a
|
||||
given auth scheme.
|
||||
|
||||
Example (common case):
|
||||
```
|
||||
exchanger = AutoAuthCredentialExchanger()
|
||||
auth_credential = exchanger.exchange_credential(
|
||||
auth_scheme=service_account_scheme,
|
||||
auth_credential=service_account_credential,
|
||||
)
|
||||
# Returns an oauth token in the form of a bearer token.
|
||||
```
|
||||
|
||||
Example (use CustomAuthExchanger for OAuth2):
|
||||
```
|
||||
exchanger = AutoAuthCredentialExchanger(
|
||||
custom_exchangers={
|
||||
AuthScheme.OAUTH2: CustomAuthExchanger,
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
Attributes:
|
||||
exchangers: A dictionary mapping auth scheme to credential exchanger class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
custom_exchangers: Optional[
|
||||
Dict[str, Type[BaseAuthCredentialExchanger]]
|
||||
] = None,
|
||||
):
|
||||
"""Initializes the AutoAuthCredentialExchanger.
|
||||
|
||||
Args:
|
||||
custom_exchangers: Optional dictionary for adding or overriding auth
|
||||
exchangers. The key is the auth scheme, and the value is the credential
|
||||
exchanger class.
|
||||
"""
|
||||
self.exchangers = {
|
||||
AuthCredentialTypes.OAUTH2: OAuth2CredentialExchanger,
|
||||
AuthCredentialTypes.OPEN_ID_CONNECT: OAuth2CredentialExchanger,
|
||||
AuthCredentialTypes.SERVICE_ACCOUNT: ServiceAccountCredentialExchanger,
|
||||
}
|
||||
|
||||
if custom_exchangers:
|
||||
self.exchangers.update(custom_exchangers)
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> Optional[AuthCredential]:
|
||||
"""Automatically exchanges for the credential uses the appropriate credential exchanger.
|
||||
|
||||
Args:
|
||||
auth_scheme (AuthScheme): The security scheme.
|
||||
auth_credential (AuthCredential): Optional. The authentication
|
||||
credential.
|
||||
|
||||
Returns: (AuthCredential)
|
||||
A new AuthCredential object containing the exchanged credential.
|
||||
|
||||
"""
|
||||
if not auth_credential:
|
||||
return None
|
||||
|
||||
exchanger_class = self.exchangers.get(
|
||||
auth_credential.auth_type if auth_credential else None
|
||||
)
|
||||
|
||||
if not exchanger_class:
|
||||
return auth_credential
|
||||
|
||||
exchanger = exchanger_class()
|
||||
return exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
@@ -0,0 +1,55 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import Optional
|
||||
|
||||
from .....auth.auth_credential import (
|
||||
AuthCredential,
|
||||
)
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
|
||||
|
||||
class AuthCredentialMissingError(Exception):
|
||||
"""Exception raised when required authentication credentials are missing."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
|
||||
class BaseAuthCredentialExchanger:
|
||||
"""Base class for authentication credential exchangers."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
"""Exchanges the provided authentication credential for a usable token/credential.
|
||||
|
||||
Args:
|
||||
auth_scheme: The security scheme.
|
||||
auth_credential: The authentication credential.
|
||||
|
||||
Returns:
|
||||
An updated AuthCredential object containing the fetched credential.
|
||||
For simple schemes like API key, it may return the original credential
|
||||
if no exchange is needed.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the method is not implemented by a subclass.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement exchange_credential.")
|
||||
@@ -0,0 +1,117 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Credential fetcher for OpenID Connect."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .....auth.auth_credential import AuthCredential
|
||||
from .....auth.auth_credential import AuthCredentialTypes
|
||||
from .....auth.auth_credential import HttpAuth
|
||||
from .....auth.auth_credential import HttpCredentials
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
from .....auth.auth_schemes import AuthSchemeType
|
||||
from .base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
|
||||
|
||||
class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
|
||||
"""Fetches credentials for OAuth2 and OpenID Connect."""
|
||||
|
||||
def _check_scheme_credential_type(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
):
|
||||
if not auth_credential:
|
||||
raise ValueError(
|
||||
"auth_credential is empty. Please create AuthCredential using"
|
||||
" OAuth2Auth."
|
||||
)
|
||||
|
||||
if auth_scheme.type_ not in (
|
||||
AuthSchemeType.openIdConnect,
|
||||
AuthSchemeType.oauth2,
|
||||
):
|
||||
raise ValueError(
|
||||
"Invalid security scheme, expect AuthSchemeType.openIdConnect or "
|
||||
f"AuthSchemeType.oauth2 auth scheme, but got {auth_scheme.type_}"
|
||||
)
|
||||
|
||||
if not auth_credential.oauth2 and not auth_credential.http:
|
||||
raise ValueError(
|
||||
"auth_credential is not configured with oauth2. Please"
|
||||
" create AuthCredential and set OAuth2Auth."
|
||||
)
|
||||
|
||||
def generate_auth_token(
|
||||
self,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
"""Generates an auth token from the authorization response.
|
||||
|
||||
Args:
|
||||
auth_scheme: The OpenID Connect or OAuth2 auth scheme.
|
||||
auth_credential: The auth credential.
|
||||
|
||||
Returns:
|
||||
An AuthCredential object containing the HTTP bearer access token. If the
|
||||
HTTO bearer token cannot be generated, return the origianl credential
|
||||
"""
|
||||
|
||||
if "access_token" not in auth_credential.oauth2.token:
|
||||
return auth_credential
|
||||
|
||||
# Return the access token as a bearer token.
|
||||
updated_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
|
||||
http=HttpAuth(
|
||||
scheme="bearer",
|
||||
credentials=HttpCredentials(
|
||||
token=auth_credential.oauth2.token["access_token"]
|
||||
),
|
||||
),
|
||||
)
|
||||
return updated_credential
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
"""Exchanges the OpenID Connect auth credential for an access token or an auth URI.
|
||||
|
||||
Args:
|
||||
auth_scheme: The auth scheme.
|
||||
auth_credential: The auth credential.
|
||||
|
||||
Returns:
|
||||
An AuthCredential object containing the HTTP Bearer access token.
|
||||
|
||||
Raises:
|
||||
ValueError: If the auth scheme or auth credential is invalid.
|
||||
"""
|
||||
# TODO(cheliu): Implement token refresh flow
|
||||
|
||||
self._check_scheme_credential_type(auth_scheme, auth_credential)
|
||||
|
||||
# If token is already HTTPBearer token, do nothing assuming that this token
|
||||
# is valid.
|
||||
if auth_credential.http:
|
||||
return auth_credential
|
||||
|
||||
# If access token is exchanged, exchange a HTTPBearer token.
|
||||
if auth_credential.oauth2.token:
|
||||
return self.generate_auth_token(auth_credential)
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,97 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Credential fetcher for Google Service Account."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import google.auth
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import google.oauth2.credentials
|
||||
|
||||
from .....auth.auth_credential import (
|
||||
AuthCredential,
|
||||
AuthCredentialTypes,
|
||||
HttpAuth,
|
||||
HttpCredentials,
|
||||
)
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
from .base_credential_exchanger import AuthCredentialMissingError, BaseAuthCredentialExchanger
|
||||
|
||||
|
||||
class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
|
||||
"""Fetches credentials for Google Service Account.
|
||||
|
||||
Uses the default service credential if `use_default_credential = True`.
|
||||
Otherwise, uses the service account credential provided in the auth
|
||||
credential.
|
||||
"""
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
"""Exchanges the service account auth credential for an access token.
|
||||
|
||||
If auth_credential contains a service account credential, it will be used
|
||||
to fetch an access token. Otherwise, the default service credential will be
|
||||
used for fetching an access token.
|
||||
|
||||
Args:
|
||||
auth_scheme: The auth scheme.
|
||||
auth_credential: The auth credential.
|
||||
|
||||
Returns:
|
||||
An AuthCredential in HTTPBearer format, containing the access token.
|
||||
"""
|
||||
if (
|
||||
auth_credential is None
|
||||
or auth_credential.service_account is None
|
||||
or (
|
||||
auth_credential.service_account.service_account_credential is None
|
||||
and not auth_credential.service_account.use_default_credential
|
||||
)
|
||||
):
|
||||
raise AuthCredentialMissingError(
|
||||
"Service account credentials are missing. Please provide them, or set"
|
||||
" `use_default_credential = True` to use application default"
|
||||
" credential in a hosted service like Cloud Run."
|
||||
)
|
||||
|
||||
try:
|
||||
if auth_credential.service_account.use_default_credential:
|
||||
credentials, _ = google.auth.default()
|
||||
else:
|
||||
config = auth_credential.service_account
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
config.service_account_credential.model_dump(), scopes=config.scopes
|
||||
)
|
||||
|
||||
credentials.refresh(Request())
|
||||
|
||||
updated_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
|
||||
http=HttpAuth(
|
||||
scheme="bearer",
|
||||
credentials=HttpCredentials(token=credentials.token),
|
||||
),
|
||||
)
|
||||
return updated_credential
|
||||
|
||||
except Exception as e:
|
||||
raise AuthCredentialMissingError(
|
||||
f"Failed to exchange service account token: {e}"
|
||||
) from e
|
||||
19
src/google/adk/tools/openapi_tool/common/__init__.py
Normal file
19
src/google/adk/tools/openapi_tool/common/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import common
|
||||
|
||||
__all__ = [
|
||||
'common',
|
||||
]
|
||||
300
src/google/adk/tools/openapi_tool/common/common.py
Normal file
300
src/google/adk/tools/openapi_tool/common/common.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import keyword
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from fastapi.openapi.models import Response
|
||||
from fastapi.openapi.models import Schema
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_serializer
|
||||
|
||||
|
||||
def to_snake_case(text: str) -> str:
|
||||
"""Converts a string into snake_case.
|
||||
|
||||
Handles lowerCamelCase, UpperCamelCase, or space-separated case, acronyms
|
||||
(e.g., "REST API") and consecutive uppercase letters correctly. Also handles
|
||||
mixed cases with and without spaces.
|
||||
|
||||
Examples:
|
||||
```
|
||||
to_snake_case('camelCase') -> 'camel_case'
|
||||
to_snake_case('UpperCamelCase') -> 'upper_camel_case'
|
||||
to_snake_case('space separated') -> 'space_separated'
|
||||
```
|
||||
|
||||
Args:
|
||||
text: The input string.
|
||||
|
||||
Returns:
|
||||
The snake_case version of the string.
|
||||
"""
|
||||
|
||||
# Handle spaces and non-alphanumeric characters (replace with underscores)
|
||||
text = re.sub(r'[^a-zA-Z0-9]+', '_', text)
|
||||
|
||||
# Insert underscores before uppercase letters (handling both CamelCases)
|
||||
text = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', text) # lowerCamelCase
|
||||
text = re.sub(
|
||||
r'([A-Z]+)([A-Z][a-z])', r'\1_\2', text
|
||||
) # UpperCamelCase and acronyms
|
||||
|
||||
# Convert to lowercase
|
||||
text = text.lower()
|
||||
|
||||
# Remove consecutive underscores (clean up extra underscores)
|
||||
text = re.sub(r'_+', '_', text)
|
||||
|
||||
# Remove leading and trailing underscores
|
||||
text = text.strip('_')
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def rename_python_keywords(s: str, prefix: str = 'param_') -> str:
|
||||
"""Renames Python keywords by adding a prefix.
|
||||
|
||||
Example:
|
||||
```
|
||||
rename_python_keywords('if') -> 'param_if'
|
||||
rename_python_keywords('for') -> 'param_for'
|
||||
```
|
||||
|
||||
Args:
|
||||
s: The input string.
|
||||
prefix: The prefix to add to the keyword.
|
||||
|
||||
Returns:
|
||||
The renamed string.
|
||||
"""
|
||||
if keyword.iskeyword(s):
|
||||
return prefix + s
|
||||
return s
|
||||
|
||||
|
||||
class ApiParameter(BaseModel):
|
||||
"""Data class representing a function parameter."""
|
||||
|
||||
original_name: str
|
||||
param_location: str
|
||||
param_schema: Union[str, Schema]
|
||||
description: Optional[str] = ''
|
||||
py_name: Optional[str] = ''
|
||||
type_value: type[Any] = Field(default=None, init_var=False)
|
||||
type_hint: str = Field(default=None, init_var=False)
|
||||
|
||||
def model_post_init(self, _: Any):
|
||||
self.py_name = (
|
||||
self.py_name
|
||||
if self.py_name
|
||||
else rename_python_keywords(to_snake_case(self.original_name))
|
||||
)
|
||||
if isinstance(self.param_schema, str):
|
||||
self.param_schema = Schema.model_validate_json(self.param_schema)
|
||||
|
||||
self.description = self.description or self.param_schema.description or ''
|
||||
self.type_value = TypeHintHelper.get_type_value(self.param_schema)
|
||||
self.type_hint = TypeHintHelper.get_type_hint(self.param_schema)
|
||||
return self
|
||||
|
||||
@model_serializer
|
||||
def _serialize(self):
|
||||
return {
|
||||
'original_name': self.original_name,
|
||||
'param_location': self.param_location,
|
||||
'param_schema': self.param_schema,
|
||||
'description': self.description,
|
||||
'py_name': self.py_name,
|
||||
}
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.py_name}: {self.type_hint}'
|
||||
|
||||
def to_arg_string(self):
|
||||
"""Converts the parameter to an argument string for function call."""
|
||||
return f'{self.py_name}={self.py_name}'
|
||||
|
||||
def to_dict_property(self):
|
||||
"""Converts the parameter to a key:value string for dict property."""
|
||||
return f'"{self.py_name}": {self.py_name}'
|
||||
|
||||
def to_pydoc_string(self):
|
||||
"""Converts the parameter to a PyDoc parameter docstr."""
|
||||
return PydocHelper.generate_param_doc(self)
|
||||
|
||||
|
||||
class TypeHintHelper:
|
||||
"""Helper class for generating type hints."""
|
||||
|
||||
@staticmethod
|
||||
def get_type_value(schema: Schema) -> Any:
|
||||
"""Generates the Python type value for a given parameter."""
|
||||
param_type = schema.type if schema.type else Any
|
||||
|
||||
if param_type == 'integer':
|
||||
return int
|
||||
elif param_type == 'number':
|
||||
return float
|
||||
elif param_type == 'boolean':
|
||||
return bool
|
||||
elif param_type == 'string':
|
||||
return str
|
||||
elif param_type == 'array':
|
||||
items_type = Any
|
||||
if schema.items and schema.items.type:
|
||||
items_type = schema.items.type
|
||||
|
||||
if items_type == 'object':
|
||||
return List[Dict[str, Any]]
|
||||
else:
|
||||
type_map = {
|
||||
'integer': int,
|
||||
'number': float,
|
||||
'boolean': bool,
|
||||
'string': str,
|
||||
'object': Dict[str, Any],
|
||||
'array': List[Any],
|
||||
}
|
||||
return List[type_map.get(items_type, 'Any')]
|
||||
elif param_type == 'object':
|
||||
return Dict[str, Any]
|
||||
else:
|
||||
return Any
|
||||
|
||||
@staticmethod
|
||||
def get_type_hint(schema: Schema) -> str:
|
||||
"""Generates the Python type in string for a given parameter."""
|
||||
param_type = schema.type if schema.type else 'Any'
|
||||
|
||||
if param_type == 'integer':
|
||||
return 'int'
|
||||
elif param_type == 'number':
|
||||
return 'float'
|
||||
elif param_type == 'boolean':
|
||||
return 'bool'
|
||||
elif param_type == 'string':
|
||||
return 'str'
|
||||
elif param_type == 'array':
|
||||
items_type = 'Any'
|
||||
if schema.items and schema.items.type:
|
||||
items_type = schema.items.type
|
||||
|
||||
if items_type == 'object':
|
||||
return 'List[Dict[str, Any]]'
|
||||
else:
|
||||
type_map = {
|
||||
'integer': 'int',
|
||||
'number': 'float',
|
||||
'boolean': 'bool',
|
||||
'string': 'str',
|
||||
}
|
||||
return f"List[{type_map.get(items_type, 'Any')}]"
|
||||
elif param_type == 'object':
|
||||
return 'Dict[str, Any]'
|
||||
else:
|
||||
return 'Any'
|
||||
|
||||
|
||||
class PydocHelper:
|
||||
"""Helper class for generating PyDoc strings."""
|
||||
|
||||
@staticmethod
|
||||
def generate_param_doc(
|
||||
param: ApiParameter,
|
||||
) -> str:
|
||||
"""Generates a parameter documentation string.
|
||||
|
||||
Args:
|
||||
param: ApiParameter - The parameter to generate the documentation for.
|
||||
|
||||
Returns:
|
||||
str: The generated parameter Python documentation string.
|
||||
"""
|
||||
description = param.description.strip() if param.description else ''
|
||||
param_doc = f'{param.py_name} ({param.type_hint}): {description}'
|
||||
|
||||
if param.param_schema.type == 'object':
|
||||
properties = param.param_schema.properties
|
||||
if properties:
|
||||
param_doc += ' Object properties:\n'
|
||||
for prop_name, prop_details in properties.items():
|
||||
prop_desc = prop_details.description or ''
|
||||
prop_type = TypeHintHelper.get_type_hint(prop_details)
|
||||
param_doc += f' {prop_name} ({prop_type}): {prop_desc}\n'
|
||||
|
||||
return param_doc
|
||||
|
||||
@staticmethod
|
||||
def generate_return_doc(responses: Dict[str, Response]) -> str:
|
||||
"""Generates a return value documentation string.
|
||||
|
||||
Args:
|
||||
responses: Dict[str, TypedDict[Response]] - Response in an OpenAPI
|
||||
Operation
|
||||
|
||||
Returns:
|
||||
str: The generated return value Python documentation string.
|
||||
"""
|
||||
return_doc = ''
|
||||
|
||||
# Only consider 2xx responses for return type hinting.
|
||||
# Returns the 2xx response with the smallest status code number and with
|
||||
# content defined.
|
||||
sorted_responses = sorted(responses.items(), key=lambda item: int(item[0]))
|
||||
qualified_response = next(
|
||||
filter(
|
||||
lambda r: r[0].startswith('2') and r[1].content,
|
||||
sorted_responses,
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not qualified_response:
|
||||
return ''
|
||||
response_details = qualified_response[1]
|
||||
|
||||
description = (response_details.description or '').strip()
|
||||
content = response_details.content or {}
|
||||
|
||||
# Generate return type hint and properties for the first response type.
|
||||
# TODO(cheliu): Handle multiple content types.
|
||||
for _, schema_details in content.items():
|
||||
schema = schema_details.schema_ or {}
|
||||
|
||||
# Use a dummy Parameter object for return type hinting.
|
||||
dummy_param = ApiParameter(
|
||||
original_name='', param_location='', param_schema=schema
|
||||
)
|
||||
return_doc = f'Returns ({dummy_param.type_hint}): {description}'
|
||||
|
||||
response_type = schema.type or 'Any'
|
||||
if response_type != 'object':
|
||||
break
|
||||
properties = schema.properties
|
||||
if not properties:
|
||||
break
|
||||
return_doc += ' Object properties:\n'
|
||||
for prop_name, prop_details in properties.items():
|
||||
prop_desc = prop_details.description or ''
|
||||
prop_type = TypeHintHelper.get_type_hint(prop_details)
|
||||
return_doc += f' {prop_name} ({prop_type}): {prop_desc}\n'
|
||||
break
|
||||
|
||||
return return_doc
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .openapi_spec_parser import OpenApiSpecParser, OperationEndpoint, ParsedOperation
|
||||
from .openapi_toolset import OpenAPIToolset
|
||||
from .operation_parser import OperationParser
|
||||
from .rest_api_tool import AuthPreparationState, RestApiTool, snake_to_lower_camel, to_gemini_schema
|
||||
from .tool_auth_handler import ToolAuthHandler
|
||||
|
||||
__all__ = [
|
||||
'OpenApiSpecParser',
|
||||
'OperationEndpoint',
|
||||
'ParsedOperation',
|
||||
'OpenAPIToolset',
|
||||
'OperationParser',
|
||||
'RestApiTool',
|
||||
'to_gemini_schema',
|
||||
'snake_to_lower_camel',
|
||||
'AuthPreparationState',
|
||||
'ToolAuthHandler',
|
||||
]
|
||||
@@ -0,0 +1,231 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.openapi.models import Operation
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ..common.common import ApiParameter
|
||||
from ..common.common import to_snake_case
|
||||
from .operation_parser import OperationParser
|
||||
|
||||
|
||||
class OperationEndpoint(BaseModel):
|
||||
base_url: str
|
||||
path: str
|
||||
method: str
|
||||
|
||||
|
||||
class ParsedOperation(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
endpoint: OperationEndpoint
|
||||
operation: Operation
|
||||
parameters: List[ApiParameter]
|
||||
return_value: ApiParameter
|
||||
auth_scheme: Optional[AuthScheme] = None
|
||||
auth_credential: Optional[AuthCredential] = None
|
||||
additional_context: Optional[Any] = None
|
||||
|
||||
|
||||
class OpenApiSpecParser:
|
||||
"""Generates Python code, JSON schema, and callables for an OpenAPI operation.
|
||||
|
||||
This class takes an OpenApiOperation object and provides methods to generate:
|
||||
1. A string representation of a Python function that handles the operation.
|
||||
2. A JSON schema representing the input parameters of the operation.
|
||||
3. A callable Python object (a function) that can execute the operation.
|
||||
"""
|
||||
|
||||
def parse(self, openapi_spec_dict: Dict[str, Any]) -> List[ParsedOperation]:
|
||||
"""Extracts an OpenAPI spec dict into a list of ParsedOperation objects.
|
||||
|
||||
ParsedOperation objects are further used for generating RestApiTool.
|
||||
|
||||
Args:
|
||||
openapi_spec_dict: A dictionary representing the OpenAPI specification.
|
||||
|
||||
Returns:
|
||||
A list of ParsedOperation objects.
|
||||
"""
|
||||
|
||||
openapi_spec_dict = self._resolve_references(openapi_spec_dict)
|
||||
operations = self._collect_operations(openapi_spec_dict)
|
||||
return operations
|
||||
|
||||
def _collect_operations(
|
||||
self, openapi_spec: Dict[str, Any]
|
||||
) -> List[ParsedOperation]:
|
||||
"""Collects operations from an OpenAPI spec."""
|
||||
operations = []
|
||||
|
||||
# Taking first server url, or default to empty string if not present
|
||||
base_url = ""
|
||||
if openapi_spec.get("servers"):
|
||||
base_url = openapi_spec["servers"][0].get("url", "")
|
||||
|
||||
# Get global security scheme (if any)
|
||||
global_scheme_name = None
|
||||
if openapi_spec.get("security"):
|
||||
# Use first scheme by default.
|
||||
scheme_names = list(openapi_spec["security"][0].keys())
|
||||
global_scheme_name = scheme_names[0] if scheme_names else None
|
||||
|
||||
auth_schemes = openapi_spec.get("components", {}).get("securitySchemes", {})
|
||||
|
||||
for path, path_item in openapi_spec.get("paths", {}).items():
|
||||
if path_item is None:
|
||||
continue
|
||||
|
||||
for method in (
|
||||
"get",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"head",
|
||||
"options",
|
||||
"trace",
|
||||
):
|
||||
operation_dict = path_item.get(method)
|
||||
if operation_dict is None:
|
||||
continue
|
||||
|
||||
# If operation ID is missing, assign an operation id based on path
|
||||
# and method
|
||||
if "operationId" not in operation_dict:
|
||||
temp_id = to_snake_case(f"{path}_{method}")
|
||||
operation_dict["operationId"] = temp_id
|
||||
|
||||
url = OperationEndpoint(base_url=base_url, path=path, method=method)
|
||||
operation = Operation.model_validate(operation_dict)
|
||||
operation_parser = OperationParser(operation)
|
||||
|
||||
# Check for operation-specific auth scheme
|
||||
auth_scheme_name = operation_parser.get_auth_scheme_name()
|
||||
auth_scheme_name = (
|
||||
auth_scheme_name if auth_scheme_name else global_scheme_name
|
||||
)
|
||||
auth_scheme = (
|
||||
auth_schemes.get(auth_scheme_name) if auth_scheme_name else None
|
||||
)
|
||||
|
||||
parsed_op = ParsedOperation(
|
||||
name=operation_parser.get_function_name(),
|
||||
description=operation.description or operation.summary or "",
|
||||
endpoint=url,
|
||||
operation=operation,
|
||||
parameters=operation_parser.get_parameters(),
|
||||
return_value=operation_parser.get_return_value(),
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=None, # Placeholder
|
||||
additional_context={},
|
||||
)
|
||||
operations.append(parsed_op)
|
||||
|
||||
return operations
|
||||
|
||||
def _resolve_references(self, openapi_spec: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Recursively resolves all $ref references in an OpenAPI specification.
|
||||
|
||||
Handles circular references correctly.
|
||||
|
||||
Args:
|
||||
openapi_spec: A dictionary representing the OpenAPI specification.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the OpenAPI specification with all references
|
||||
resolved.
|
||||
"""
|
||||
|
||||
openapi_spec = copy.deepcopy(openapi_spec) # Work on a copy
|
||||
resolved_cache = {} # Cache resolved references
|
||||
|
||||
def resolve_ref(ref_string, current_doc):
|
||||
"""Resolves a single $ref string."""
|
||||
parts = ref_string.split("/")
|
||||
if parts[0] != "#":
|
||||
raise ValueError(f"External references not supported: {ref_string}")
|
||||
|
||||
current = current_doc
|
||||
for part in parts[1:]:
|
||||
if part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return None # Reference not found
|
||||
return current
|
||||
|
||||
def recursive_resolve(obj, current_doc, seen_refs=None):
|
||||
"""Recursively resolves references, handling circularity.
|
||||
|
||||
Args:
|
||||
obj: The object to traverse.
|
||||
current_doc: Document to search for refs.
|
||||
seen_refs: A set to track already-visited references (for circularity
|
||||
detection).
|
||||
|
||||
Returns:
|
||||
The resolved object.
|
||||
"""
|
||||
if seen_refs is None:
|
||||
seen_refs = set() # Initialize the set if it's the first call
|
||||
|
||||
if isinstance(obj, dict):
|
||||
if "$ref" in obj and isinstance(obj["$ref"], str):
|
||||
ref_string = obj["$ref"]
|
||||
|
||||
# Check for circularity
|
||||
if ref_string in seen_refs and ref_string not in resolved_cache:
|
||||
# Circular reference detected! Return a *copy* of the object,
|
||||
# but *without* the $ref. This breaks the cycle while
|
||||
# still maintaining the overall structure.
|
||||
return {k: v for k, v in obj.items() if k != "$ref"}
|
||||
|
||||
seen_refs.add(ref_string) # Add the reference to the set
|
||||
|
||||
# Check if we have a cached resolved value
|
||||
if ref_string in resolved_cache:
|
||||
return copy.deepcopy(resolved_cache[ref_string])
|
||||
|
||||
resolved_value = resolve_ref(ref_string, current_doc)
|
||||
if resolved_value is not None:
|
||||
# Recursively resolve the *resolved* value,
|
||||
# passing along the 'seen_refs' set
|
||||
resolved_value = recursive_resolve(
|
||||
resolved_value, current_doc, seen_refs
|
||||
)
|
||||
resolved_cache[ref_string] = resolved_value
|
||||
return copy.deepcopy(resolved_value) # return the cached result
|
||||
else:
|
||||
return obj # return original if no resolved value.
|
||||
|
||||
else:
|
||||
new_dict = {}
|
||||
for key, value in obj.items():
|
||||
new_dict[key] = recursive_resolve(value, current_doc, seen_refs)
|
||||
return new_dict
|
||||
|
||||
elif isinstance(obj, list):
|
||||
return [recursive_resolve(item, current_doc, seen_refs) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
return recursive_resolve(openapi_spec, openapi_spec)
|
||||
@@ -0,0 +1,144 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Final
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from .openapi_spec_parser import OpenApiSpecParser
|
||||
from .rest_api_tool import RestApiTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAPIToolset:
|
||||
"""Class for parsing OpenAPI spec into a list of RestApiTool.
|
||||
|
||||
Usage:
|
||||
```
|
||||
# Initialize OpenAPI toolset from a spec string.
|
||||
openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
|
||||
spec_str_type="json")
|
||||
# Or, initialize OpenAPI toolset from a spec dictionary.
|
||||
openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)
|
||||
|
||||
# Add all tools to an agent.
|
||||
agent = Agent(
|
||||
tools=[*openapi_toolset.get_tools()]
|
||||
)
|
||||
# Or, add a single tool to an agent.
|
||||
agent = Agent(
|
||||
tools=[openapi_toolset.get_tool('tool_name')]
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
spec_dict: Optional[Dict[str, Any]] = None,
|
||||
spec_str: Optional[str] = None,
|
||||
spec_str_type: Literal["json", "yaml"] = "json",
|
||||
auth_scheme: Optional[AuthScheme] = None,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
):
|
||||
"""Initializes the OpenAPIToolset.
|
||||
|
||||
Usage:
|
||||
```
|
||||
# Initialize OpenAPI toolset from a spec string.
|
||||
openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
|
||||
spec_str_type="json")
|
||||
# Or, initialize OpenAPI toolset from a spec dictionary.
|
||||
openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)
|
||||
|
||||
# Add all tools to an agent.
|
||||
agent = Agent(
|
||||
tools=[*openapi_toolset.get_tools()]
|
||||
)
|
||||
# Or, add a single tool to an agent.
|
||||
agent = Agent(
|
||||
tools=[openapi_toolset.get_tool('tool_name')]
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
spec_dict: The OpenAPI spec dictionary. If provided, it will be used
|
||||
instead of loading the spec from a string.
|
||||
spec_str: The OpenAPI spec string in JSON or YAML format. It will be used
|
||||
when spec_dict is not provided.
|
||||
spec_str_type: The type of the OpenAPI spec string. Can be "json" or
|
||||
"yaml".
|
||||
auth_scheme: The auth scheme to use for all tools. Use AuthScheme or use
|
||||
helpers in `google.adk.tools.openapi_tool.auth.auth_helpers`
|
||||
auth_credential: The auth credential to use for all tools. Use
|
||||
AuthCredential or use helpers in
|
||||
`google.adk.tools.openapi_tool.auth.auth_helpers`
|
||||
"""
|
||||
if not spec_dict:
|
||||
spec_dict = self._load_spec(spec_str, spec_str_type)
|
||||
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
|
||||
if auth_scheme or auth_credential:
|
||||
self._configure_auth_all(auth_scheme, auth_credential)
|
||||
|
||||
def _configure_auth_all(
|
||||
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
|
||||
):
|
||||
"""Configure auth scheme and credential for all tools."""
|
||||
|
||||
for tool in self.tools:
|
||||
if auth_scheme:
|
||||
tool.configure_auth_scheme(auth_scheme)
|
||||
if auth_credential:
|
||||
tool.configure_auth_credential(auth_credential)
|
||||
|
||||
def get_tools(self) -> List[RestApiTool]:
|
||||
"""Get all tools in the toolset."""
|
||||
return self.tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
|
||||
"""Get a tool by name."""
|
||||
matching_tool = filter(lambda t: t.name == tool_name, self.tools)
|
||||
return next(matching_tool, None)
|
||||
|
||||
def _load_spec(
|
||||
self, spec_str: str, spec_type: Literal["json", "yaml"]
|
||||
) -> Dict[str, Any]:
|
||||
"""Loads the OpenAPI spec string into adictionary."""
|
||||
if spec_type == "json":
|
||||
return json.loads(spec_str)
|
||||
elif spec_type == "yaml":
|
||||
return yaml.safe_load(spec_str)
|
||||
else:
|
||||
raise ValueError(f"Unsupported spec type: {spec_type}")
|
||||
|
||||
def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]:
|
||||
"""Parse OpenAPI spec into a list of RestApiTool."""
|
||||
operations = OpenApiSpecParser().parse(openapi_spec_dict)
|
||||
|
||||
tools = []
|
||||
for o in operations:
|
||||
tool = RestApiTool.from_parsed_operation(o)
|
||||
logger.info("Parsed tool: %s", tool.name)
|
||||
tools.append(tool)
|
||||
return tools
|
||||
@@ -0,0 +1,260 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.openapi.models import Operation
|
||||
from fastapi.openapi.models import Parameter
|
||||
from fastapi.openapi.models import Schema
|
||||
|
||||
from ..common.common import ApiParameter
|
||||
from ..common.common import PydocHelper
|
||||
from ..common.common import to_snake_case
|
||||
|
||||
|
||||
class OperationParser:
|
||||
"""Generates parameters for Python functions from an OpenAPI operation.
|
||||
|
||||
This class processes an OpenApiOperation object and provides helper methods
|
||||
to extract information needed to generate Python function declarations,
|
||||
docstrings, signatures, and JSON schemas. It handles parameter processing,
|
||||
name deduplication, and type hint generation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, operation: Union[Operation, Dict[str, Any], str], should_parse=True
|
||||
):
|
||||
"""Initializes the OperationParser with an OpenApiOperation.
|
||||
|
||||
Args:
|
||||
operation: The OpenApiOperation object or a dictionary to process.
|
||||
should_parse: Whether to parse the operation during initialization.
|
||||
"""
|
||||
if isinstance(operation, dict):
|
||||
self.operation = Operation.model_validate(operation)
|
||||
elif isinstance(operation, str):
|
||||
self.operation = Operation.model_validate_json(operation)
|
||||
else:
|
||||
self.operation = operation
|
||||
|
||||
self.params: List[ApiParameter] = []
|
||||
self.return_value: Optional[ApiParameter] = None
|
||||
if should_parse:
|
||||
self._process_operation_parameters()
|
||||
self._process_request_body()
|
||||
self._process_return_value()
|
||||
self._dedupe_param_names()
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
operation: Union[Operation, Dict[str, Any]],
|
||||
params: List[ApiParameter],
|
||||
return_value: Optional[ApiParameter] = None,
|
||||
) -> 'OperationParser':
|
||||
parser = cls(operation, should_parse=False)
|
||||
parser.params = params
|
||||
parser.return_value = return_value
|
||||
return parser
|
||||
|
||||
def _process_operation_parameters(self):
|
||||
"""Processes parameters from the OpenAPI operation."""
|
||||
parameters = self.operation.parameters or []
|
||||
for param in parameters:
|
||||
if isinstance(param, Parameter):
|
||||
original_name = param.name
|
||||
description = param.description or ''
|
||||
location = param.in_ or ''
|
||||
schema = param.schema_ or {} # Use schema_ instead of .schema
|
||||
|
||||
self.params.append(
|
||||
ApiParameter(
|
||||
original_name=original_name,
|
||||
param_location=location,
|
||||
param_schema=schema,
|
||||
description=description,
|
||||
)
|
||||
)
|
||||
|
||||
def _process_request_body(self):
|
||||
"""Processes the request body from the OpenAPI operation."""
|
||||
request_body = self.operation.requestBody
|
||||
if not request_body:
|
||||
return
|
||||
|
||||
content = request_body.content or {}
|
||||
if not content:
|
||||
return
|
||||
|
||||
# If request body is an object, expand the properties as parameters
|
||||
for _, media_type_object in content.items():
|
||||
schema = media_type_object.schema_ or {}
|
||||
description = request_body.description or ''
|
||||
|
||||
if schema and schema.type == 'object':
|
||||
for prop_name, prop_details in schema.properties.items():
|
||||
self.params.append(
|
||||
ApiParameter(
|
||||
original_name=prop_name,
|
||||
param_location='body',
|
||||
param_schema=prop_details,
|
||||
description=prop_details.description,
|
||||
)
|
||||
)
|
||||
|
||||
elif schema and schema.type == 'array':
|
||||
self.params.append(
|
||||
ApiParameter(
|
||||
original_name='array',
|
||||
param_location='body',
|
||||
param_schema=schema,
|
||||
description=description,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.params.append(
|
||||
# Empty name for unnamed body param
|
||||
ApiParameter(
|
||||
original_name='',
|
||||
param_location='body',
|
||||
param_schema=schema,
|
||||
description=description,
|
||||
)
|
||||
)
|
||||
break # Process first mime type only
|
||||
|
||||
def _dedupe_param_names(self):
|
||||
"""Deduplicates parameter names to avoid conflicts."""
|
||||
params_cnt = {}
|
||||
for param in self.params:
|
||||
name = param.py_name
|
||||
if name not in params_cnt:
|
||||
params_cnt[name] = 0
|
||||
else:
|
||||
params_cnt[name] += 1
|
||||
param.py_name = f'{name}_{params_cnt[name] -1}'
|
||||
|
||||
def _process_return_value(self) -> Parameter:
|
||||
"""Returns a Parameter object representing the return type."""
|
||||
responses = self.operation.responses or {}
|
||||
# Default to Any if no 2xx response or if schema is missing
|
||||
return_schema = Schema(type='Any')
|
||||
|
||||
# Take the 20x response with the smallest response code.
|
||||
valid_codes = list(
|
||||
filter(lambda k: k.startswith('2'), list(responses.keys()))
|
||||
)
|
||||
min_20x_status_code = min(valid_codes) if valid_codes else None
|
||||
|
||||
if min_20x_status_code and responses[min_20x_status_code].content:
|
||||
content = responses[min_20x_status_code].content
|
||||
for mime_type in content:
|
||||
if content[mime_type].schema_:
|
||||
return_schema = content[mime_type].schema_
|
||||
break
|
||||
|
||||
self.return_value = ApiParameter(
|
||||
original_name='',
|
||||
param_location='',
|
||||
param_schema=return_schema,
|
||||
)
|
||||
|
||||
def get_function_name(self) -> str:
|
||||
"""Returns the generated function name."""
|
||||
operation_id = self.operation.operationId
|
||||
if not operation_id:
|
||||
raise ValueError('Operation ID is missing')
|
||||
return to_snake_case(operation_id)[:60]
|
||||
|
||||
def get_return_type_hint(self) -> str:
|
||||
"""Returns the return type hint string (like 'str', 'int', etc.)."""
|
||||
return self.return_value.type_hint
|
||||
|
||||
def get_return_type_value(self) -> Any:
|
||||
"""Returns the return type value (like str, int, List[str], etc.)."""
|
||||
return self.return_value.type_value
|
||||
|
||||
def get_parameters(self) -> List[ApiParameter]:
|
||||
"""Returns the list of Parameter objects."""
|
||||
return self.params
|
||||
|
||||
def get_return_value(self) -> ApiParameter:
|
||||
"""Returns the list of Parameter objects."""
|
||||
return self.return_value
|
||||
|
||||
def get_auth_scheme_name(self) -> str:
|
||||
"""Returns the name of the auth scheme for this operation from the spec."""
|
||||
if self.operation.security:
|
||||
scheme_name = list(self.operation.security[0].keys())[0]
|
||||
return scheme_name
|
||||
return ''
|
||||
|
||||
def get_pydoc_string(self) -> str:
|
||||
"""Returns the generated PyDoc string."""
|
||||
pydoc_params = [param.to_pydoc_string() for param in self.params]
|
||||
pydoc_description = (
|
||||
self.operation.summary or self.operation.description or ''
|
||||
)
|
||||
pydoc_return = PydocHelper.generate_return_doc(
|
||||
self.operation.responses or {}
|
||||
)
|
||||
pydoc_arg_list = chr(10).join(
|
||||
f' {param_doc}' for param_doc in pydoc_params
|
||||
)
|
||||
return dedent(f"""
|
||||
\"\"\"{pydoc_description}
|
||||
|
||||
Args:
|
||||
{pydoc_arg_list}
|
||||
|
||||
{pydoc_return}
|
||||
\"\"\"
|
||||
""").strip()
|
||||
|
||||
def get_json_schema(self) -> Dict[str, Any]:
|
||||
"""Returns the JSON schema for the function arguments."""
|
||||
properties = {
|
||||
p.py_name: jsonable_encoder(p.param_schema, exclude_none=True)
|
||||
for p in self.params
|
||||
}
|
||||
return {
|
||||
'properties': properties,
|
||||
'required': [p.py_name for p in self.params],
|
||||
'title': f"{self.operation.operationId or 'unnamed'}_Arguments",
|
||||
'type': 'object',
|
||||
}
|
||||
|
||||
def get_signature_parameters(self) -> List[inspect.Parameter]:
|
||||
"""Returns a list of inspect.Parameter objects for the function."""
|
||||
return [
|
||||
inspect.Parameter(
|
||||
param.py_name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=param.type_value,
|
||||
)
|
||||
for param in self.params
|
||||
]
|
||||
|
||||
def get_annotations(self) -> Dict[str, Any]:
|
||||
"""Returns a dictionary of parameter annotations for the function."""
|
||||
annotations = {p.py_name: p.type_value for p in self.params}
|
||||
annotations['return'] = self.get_return_type_value()
|
||||
return annotations
|
||||
@@ -0,0 +1,496 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from fastapi.openapi.models import Operation
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from google.genai.types import Schema
|
||||
import requests
|
||||
from typing_extensions import override
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ....tools import BaseTool
|
||||
from ...tool_context import ToolContext
|
||||
from ..auth.auth_helpers import credential_to_param
|
||||
from ..auth.auth_helpers import dict_to_auth_scheme
|
||||
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from ..common.common import ApiParameter
|
||||
from ..common.common import to_snake_case
|
||||
from .openapi_spec_parser import OperationEndpoint
|
||||
from .openapi_spec_parser import ParsedOperation
|
||||
from .operation_parser import OperationParser
|
||||
from .tool_auth_handler import ToolAuthHandler
|
||||
|
||||
|
||||
def snake_to_lower_camel(snake_case_string: str):
|
||||
"""Converts a snake_case string to a lower_camel_case string.
|
||||
|
||||
Args:
|
||||
snake_case_string: The input snake_case string.
|
||||
|
||||
Returns:
|
||||
The lower_camel_case string.
|
||||
"""
|
||||
if "_" not in snake_case_string:
|
||||
return snake_case_string
|
||||
|
||||
return "".join([
|
||||
s.lower() if i == 0 else s.capitalize()
|
||||
for i, s in enumerate(snake_case_string.split("_"))
|
||||
])
|
||||
|
||||
|
||||
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
|
||||
"""Converts an OpenAPI schema dictionary to a Gemini Schema object.
|
||||
|
||||
Args:
|
||||
openapi_schema: The OpenAPI schema dictionary.
|
||||
|
||||
Returns:
|
||||
A Pydantic Schema object. Returns None if input is None.
|
||||
Raises TypeError if input is not a dict.
|
||||
"""
|
||||
if openapi_schema is None:
|
||||
return None
|
||||
|
||||
if not isinstance(openapi_schema, dict):
|
||||
raise TypeError("openapi_schema must be a dictionary")
|
||||
|
||||
pydantic_schema_data = {}
|
||||
|
||||
# Adding this to force adding a type to an empty dict
|
||||
# This avoid "... one_of or any_of must specify a type" error
|
||||
if not openapi_schema.get("type"):
|
||||
openapi_schema["type"] = "object"
|
||||
|
||||
# Adding this to avoid "properties: should be non-empty for OBJECT type" error
|
||||
# See b/385165182
|
||||
if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
|
||||
"properties"
|
||||
):
|
||||
openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}
|
||||
|
||||
for key, value in openapi_schema.items():
|
||||
snake_case_key = to_snake_case(key)
|
||||
# Check if the snake_case_key exists in the Schema model's fields.
|
||||
if snake_case_key in Schema.model_fields:
|
||||
if snake_case_key in ["title", "default", "format"]:
|
||||
# Ignore these fields as Gemini backend doesn't recognize them, and will
|
||||
# throw exception if they appear in the schema.
|
||||
# Format: properties[expiration].format: only 'enum' and 'date-time' are
|
||||
# supported for STRING type
|
||||
continue
|
||||
if snake_case_key == "properties" and isinstance(value, dict):
|
||||
pydantic_schema_data[snake_case_key] = {
|
||||
k: to_gemini_schema(v) for k, v in value.items()
|
||||
}
|
||||
elif snake_case_key == "items" and isinstance(value, dict):
|
||||
pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
|
||||
elif snake_case_key == "any_of" and isinstance(value, list):
|
||||
pydantic_schema_data[snake_case_key] = [
|
||||
to_gemini_schema(item) for item in value
|
||||
]
|
||||
# Important: Handle cases where the OpenAPI schema might contain lists
|
||||
# or other structures that need to be recursively processed.
|
||||
elif isinstance(value, list) and snake_case_key not in (
|
||||
"enum",
|
||||
"required",
|
||||
"property_ordering",
|
||||
):
|
||||
new_list = []
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
new_list.append(to_gemini_schema(item))
|
||||
else:
|
||||
new_list.append(item)
|
||||
pydantic_schema_data[snake_case_key] = new_list
|
||||
elif isinstance(value, dict) and snake_case_key not in ("properties"):
|
||||
# Handle dictionary which is neither properties or items
|
||||
pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
|
||||
else:
|
||||
# Simple value assignment (int, str, bool, etc.)
|
||||
pydantic_schema_data[snake_case_key] = value
|
||||
|
||||
return Schema(**pydantic_schema_data)
|
||||
|
||||
|
||||
AuthPreparationState = Literal["pending", "done"]
|
||||
|
||||
|
||||
class RestApiTool(BaseTool):
|
||||
"""A generic tool that interacts with a REST API.
|
||||
|
||||
* Generates request params and body
|
||||
* Attaches auth credentials to API call.
|
||||
|
||||
Example:
|
||||
```
|
||||
# Each API operation in the spec will be turned into its own tool
|
||||
# Name of the tool is the operationId of that operation, in snake case
|
||||
operations = OperationGenerator().parse(openapi_spec_dict)
|
||||
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
endpoint: Union[OperationEndpoint, str],
|
||||
operation: Union[Operation, str],
|
||||
auth_scheme: Optional[Union[AuthScheme, str]] = None,
|
||||
auth_credential: Optional[Union[AuthCredential, str]] = None,
|
||||
should_parse_operation=True,
|
||||
):
|
||||
"""Initializes the RestApiTool with the given parameters.
|
||||
|
||||
To generate RestApiTool from OpenAPI Specs, use OperationGenerator.
|
||||
Example:
|
||||
```
|
||||
# Each API operation in the spec will be turned into its own tool
|
||||
# Name of the tool is the operationId of that operation, in snake case
|
||||
operations = OperationGenerator().parse(openapi_spec_dict)
|
||||
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
|
||||
```
|
||||
|
||||
Hint: Use google.adk.tools.openapi_tool.auth.auth_helpers to construct
|
||||
auth_scheme and auth_credential.
|
||||
|
||||
Args:
|
||||
name: The name of the tool.
|
||||
description: The description of the tool.
|
||||
endpoint: Include the base_url, path, and method of the tool.
|
||||
operation: Pydantic object or a dict. Representing the OpenAPI Operation
|
||||
object
|
||||
(https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#operation-object)
|
||||
auth_scheme: The auth scheme of the tool. Representing the OpenAPI
|
||||
SecurityScheme object
|
||||
(https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#security-scheme-object)
|
||||
auth_credential: The authentication credential of the tool.
|
||||
should_parse_operation: Whether to parse the operation.
|
||||
"""
|
||||
# Gemini restrict the length of function name to be less than 64 characters
|
||||
self.name = name[:60]
|
||||
self.description = description
|
||||
self.endpoint = (
|
||||
OperationEndpoint.model_validate_json(endpoint)
|
||||
if isinstance(endpoint, str)
|
||||
else endpoint
|
||||
)
|
||||
self.operation = (
|
||||
Operation.model_validate_json(operation)
|
||||
if isinstance(operation, str)
|
||||
else operation
|
||||
)
|
||||
self.auth_credential, self.auth_scheme = None, None
|
||||
|
||||
self.configure_auth_credential(auth_credential)
|
||||
self.configure_auth_scheme(auth_scheme)
|
||||
|
||||
# Private properties
|
||||
self.credential_exchanger = AutoAuthCredentialExchanger()
|
||||
if should_parse_operation:
|
||||
self._operation_parser = OperationParser(self.operation)
|
||||
|
||||
@classmethod
|
||||
def from_parsed_operation(cls, parsed: ParsedOperation) -> "RestApiTool":
|
||||
"""Initializes the RestApiTool from a ParsedOperation object.
|
||||
|
||||
Args:
|
||||
parsed: A ParsedOperation object.
|
||||
|
||||
Returns:
|
||||
A RestApiTool object.
|
||||
"""
|
||||
operation_parser = OperationParser.load(
|
||||
parsed.operation, parsed.parameters, parsed.return_value
|
||||
)
|
||||
|
||||
tool_name = to_snake_case(operation_parser.get_function_name())
|
||||
generated = cls(
|
||||
name=tool_name,
|
||||
description=parsed.operation.description
|
||||
or parsed.operation.summary
|
||||
or "",
|
||||
endpoint=parsed.endpoint,
|
||||
operation=parsed.operation,
|
||||
auth_scheme=parsed.auth_scheme,
|
||||
auth_credential=parsed.auth_credential,
|
||||
)
|
||||
generated._operation_parser = operation_parser
|
||||
return generated
|
||||
|
||||
@classmethod
|
||||
def from_parsed_operation_str(
|
||||
cls, parsed_operation_str: str
|
||||
) -> "RestApiTool":
|
||||
"""Initializes the RestApiTool from a dict.
|
||||
|
||||
Args:
|
||||
parsed: A dict representation of a ParsedOperation object.
|
||||
|
||||
Returns:
|
||||
A RestApiTool object.
|
||||
"""
|
||||
operation = ParsedOperation.model_validate_json(parsed_operation_str)
|
||||
return RestApiTool.from_parsed_operation(operation)
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> FunctionDeclaration:
|
||||
"""Returns the function declaration in the Gemini Schema format."""
|
||||
schema_dict = self._operation_parser.get_json_schema()
|
||||
parameters = to_gemini_schema(schema_dict)
|
||||
function_decl = FunctionDeclaration(
|
||||
name=self.name, description=self.description, parameters=parameters
|
||||
)
|
||||
return function_decl
|
||||
|
||||
def configure_auth_scheme(
|
||||
self, auth_scheme: Union[AuthScheme, Dict[str, Any]]
|
||||
):
|
||||
"""Configures the authentication scheme for the API call.
|
||||
|
||||
Args:
|
||||
auth_scheme: AuthScheme|dict -: The authentication scheme. The dict is
|
||||
converted to a AuthScheme object.
|
||||
"""
|
||||
if isinstance(auth_scheme, dict):
|
||||
auth_scheme = dict_to_auth_scheme(auth_scheme)
|
||||
self.auth_scheme = auth_scheme
|
||||
|
||||
def configure_auth_credential(
|
||||
self, auth_credential: Optional[Union[AuthCredential, str]] = None
|
||||
):
|
||||
"""Configures the authentication credential for the API call.
|
||||
|
||||
Args:
|
||||
auth_credential: AuthCredential|dict - The authentication credential.
|
||||
The dict is converted to an AuthCredential object.
|
||||
"""
|
||||
if isinstance(auth_credential, str):
|
||||
auth_credential = AuthCredential.model_validate_json(auth_credential)
|
||||
self.auth_credential = auth_credential
|
||||
|
||||
def _prepare_auth_request_params(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: AuthCredential,
|
||||
) -> Tuple[List[ApiParameter], Dict[str, Any]]:
|
||||
# Handle Authentication
|
||||
if not auth_scheme or not auth_credential:
|
||||
return
|
||||
|
||||
return credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
def _prepare_request_params(
|
||||
self, parameters: List[ApiParameter], kwargs: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Prepares the request parameters for the API call.
|
||||
|
||||
Args:
|
||||
parameters: A list of ApiParameter objects representing the parameters
|
||||
for the API call.
|
||||
kwargs: The keyword arguments passed to the call function from the Tool
|
||||
caller.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the request parameters for the API call. This
|
||||
initializes a requests.request() call.
|
||||
|
||||
Example:
|
||||
self._prepare_request_params({"input_id": "test-id"})
|
||||
"""
|
||||
method = self.endpoint.method.lower()
|
||||
if not method:
|
||||
raise ValueError("Operation method not found.")
|
||||
|
||||
path_params: Dict[str, Any] = {}
|
||||
query_params: Dict[str, Any] = {}
|
||||
header_params: Dict[str, Any] = {}
|
||||
cookie_params: Dict[str, Any] = {}
|
||||
|
||||
params_map: Dict[str, ApiParameter] = {p.py_name: p for p in parameters}
|
||||
|
||||
# Fill in path, query, header and cookie parameters to the request
|
||||
for param_k, v in kwargs.items():
|
||||
param_obj = params_map.get(param_k)
|
||||
if not param_obj:
|
||||
continue # If input arg not in the ApiParameter list, ignore it.
|
||||
|
||||
original_k = param_obj.original_name
|
||||
param_location = param_obj.param_location
|
||||
|
||||
if param_location == "path":
|
||||
path_params[original_k] = v
|
||||
elif param_location == "query":
|
||||
if v:
|
||||
query_params[original_k] = v
|
||||
elif param_location == "header":
|
||||
header_params[original_k] = v
|
||||
elif param_location == "cookie":
|
||||
cookie_params[original_k] = v
|
||||
|
||||
# Construct URL
|
||||
base_url = self.endpoint.base_url or ""
|
||||
base_url = base_url[:-1] if base_url.endswith("/") else base_url
|
||||
url = f"{base_url}{self.endpoint.path.format(**path_params)}"
|
||||
|
||||
# Construct body
|
||||
body_kwargs: Dict[str, Any] = {}
|
||||
request_body = self.operation.requestBody
|
||||
if request_body:
|
||||
for mime_type, media_type_object in request_body.content.items():
|
||||
schema = media_type_object.schema_
|
||||
body_data = None
|
||||
|
||||
if schema.type == "object":
|
||||
body_data = {}
|
||||
for param in parameters:
|
||||
if param.param_location == "body" and param.py_name in kwargs:
|
||||
body_data[param.original_name] = kwargs[param.py_name]
|
||||
|
||||
elif schema.type == "array":
|
||||
for param in parameters:
|
||||
if param.param_location == "body" and param.py_name == "array":
|
||||
body_data = kwargs.get("array")
|
||||
break
|
||||
else: # like string
|
||||
for param in parameters:
|
||||
# original_name = '' indicating this param applies to the full body.
|
||||
if param.param_location == "body" and not param.original_name:
|
||||
body_data = (
|
||||
kwargs.get(param.py_name) if param.py_name in kwargs else None
|
||||
)
|
||||
break
|
||||
|
||||
if mime_type == "application/json" or mime_type.endswith("+json"):
|
||||
if body_data is not None:
|
||||
body_kwargs["json"] = body_data
|
||||
elif mime_type == "application/x-www-form-urlencoded":
|
||||
body_kwargs["data"] = body_data
|
||||
elif mime_type == "multipart/form-data":
|
||||
body_kwargs["files"] = body_data
|
||||
elif mime_type == "application/octet-stream":
|
||||
body_kwargs["data"] = body_data
|
||||
elif mime_type == "text/plain":
|
||||
body_kwargs["data"] = body_data
|
||||
|
||||
if mime_type:
|
||||
header_params["Content-Type"] = mime_type
|
||||
break # Process only the first mime_type
|
||||
|
||||
filtered_query_params: Dict[str, Any] = {
|
||||
k: v for k, v in query_params.items() if v is not None
|
||||
}
|
||||
|
||||
request_params: Dict[str, Any] = {
|
||||
"method": method,
|
||||
"url": url,
|
||||
"params": filtered_query_params,
|
||||
"headers": header_params,
|
||||
"cookies": cookie_params,
|
||||
**body_kwargs,
|
||||
}
|
||||
|
||||
return request_params
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
|
||||
) -> Dict[str, Any]:
|
||||
return self.call(args=args, tool_context=tool_context)
|
||||
|
||||
def call(
|
||||
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
|
||||
) -> Dict[str, Any]:
|
||||
"""Executes the REST API call.
|
||||
|
||||
Args:
|
||||
args: Keyword arguments representing the operation parameters.
|
||||
tool_context: The tool context (not used here, but required by the
|
||||
interface).
|
||||
|
||||
Returns:
|
||||
The API response as a dictionary.
|
||||
"""
|
||||
# Prepare auth credentials for the API call
|
||||
tool_auth_handler = ToolAuthHandler.from_tool_context(
|
||||
tool_context, self.auth_scheme, self.auth_credential
|
||||
)
|
||||
auth_result = tool_auth_handler.prepare_auth_credentials()
|
||||
auth_state, auth_scheme, auth_credential = (
|
||||
auth_result.state,
|
||||
auth_result.auth_scheme,
|
||||
auth_result.auth_credential,
|
||||
)
|
||||
|
||||
if auth_state == "pending":
|
||||
return {
|
||||
"pending": True,
|
||||
"message": "Needs your authorization to access your data.",
|
||||
}
|
||||
|
||||
# Attach parameters from auth into main parameters list
|
||||
api_params, api_args = self._operation_parser.get_parameters().copy(), args
|
||||
if auth_credential:
|
||||
# Attach parameters from auth into main parameters list
|
||||
auth_param, auth_args = self._prepare_auth_request_params(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
if auth_param and auth_args:
|
||||
api_params = [auth_param] + api_params
|
||||
api_args.update(auth_args)
|
||||
|
||||
# Got all parameters. Call the API.
|
||||
request_params = self._prepare_request_params(api_params, api_args)
|
||||
response = requests.request(**request_params)
|
||||
|
||||
# Parse API response
|
||||
try:
|
||||
response.raise_for_status() # Raise HTTPError for bad responses
|
||||
return response.json() # Try to decode JSON
|
||||
except requests.exceptions.HTTPError:
|
||||
error_details = response.content.decode("utf-8")
|
||||
return {
|
||||
"error": (
|
||||
f"Tool {self.name} execution failed. Analyze this execution error"
|
||||
" and your inputs. Retry with adjustments if applicable. But"
|
||||
" make sure don't retry more than 3 times. Execution Error:"
|
||||
f" {error_details}"
|
||||
)
|
||||
}
|
||||
except ValueError:
|
||||
return {"text": response.text} # Return text if not JSON
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f'RestApiTool(name="{self.name}", description="{self.description}",'
|
||||
f' endpoint="{self.endpoint}")'
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'RestApiTool(name="{self.name}", description="{self.description}",'
|
||||
f' endpoint="{self.endpoint}", operation="{self.operation}",'
|
||||
f' auth_scheme="{self.auth_scheme}",'
|
||||
f' auth_credential="{self.auth_credential}")'
|
||||
)
|
||||
@@ -0,0 +1,268 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_credential import AuthCredentialTypes
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ....auth.auth_schemes import AuthSchemeType
|
||||
from ....auth.auth_tool import AuthConfig
|
||||
from ...tool_context import ToolContext
|
||||
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from ..auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
from ..auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AuthPreparationState = Literal["pending", "done"]
|
||||
|
||||
|
||||
class AuthPreparationResult(BaseModel):
|
||||
"""Result of the credential preparation process."""
|
||||
|
||||
state: AuthPreparationState
|
||||
auth_scheme: Optional[AuthScheme] = None
|
||||
auth_credential: Optional[AuthCredential] = None
|
||||
|
||||
|
||||
class ToolContextCredentialStore:
|
||||
"""Handles storage and retrieval of credentials within a ToolContext."""
|
||||
|
||||
def __init__(self, tool_context: ToolContext):
|
||||
self.tool_context = tool_context
|
||||
|
||||
def get_credential_key(
|
||||
self,
|
||||
auth_scheme: Optional[AuthScheme],
|
||||
auth_credential: Optional[AuthCredential],
|
||||
) -> str:
|
||||
"""Generates a unique key for the given auth scheme and credential."""
|
||||
scheme_name = (
|
||||
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
|
||||
if auth_scheme
|
||||
else ""
|
||||
)
|
||||
credential_name = (
|
||||
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
|
||||
if auth_credential
|
||||
else ""
|
||||
)
|
||||
# no need to prepend temp: namespace, session state is a copy, changes to
|
||||
# it won't be persisted , only changes in event_action.state_delta will be
|
||||
# persisted. temp: namespace will be cleared after current run. but tool
|
||||
# want access token to be there stored across runs
|
||||
|
||||
return f"{scheme_name}_{credential_name}_existing_exchanged_credential"
|
||||
|
||||
def get_credential(
|
||||
self,
|
||||
auth_scheme: Optional[AuthScheme],
|
||||
auth_credential: Optional[AuthCredential],
|
||||
) -> Optional[AuthCredential]:
|
||||
if not self.tool_context:
|
||||
return None
|
||||
|
||||
token_key = self.get_credential_key(auth_scheme, auth_credential)
|
||||
# TODO try not to use session state, this looks a hacky way, depend on
|
||||
# session implementation, we don't want session to persist the token,
|
||||
# meanwhile we want the token shared across runs.
|
||||
serialized_credential = self.tool_context.state.get(token_key)
|
||||
if not serialized_credential:
|
||||
return None
|
||||
return AuthCredential.model_validate(serialized_credential)
|
||||
|
||||
def store_credential(
|
||||
self,
|
||||
key: str,
|
||||
auth_credential: Optional[AuthCredential],
|
||||
):
|
||||
if self.tool_context:
|
||||
serializable_credential = jsonable_encoder(
|
||||
auth_credential, exclude_none=True
|
||||
)
|
||||
self.tool_context.state[key] = serializable_credential
|
||||
|
||||
def remove_credential(self, key: str):
|
||||
del self.tool_context.state[key]
|
||||
|
||||
|
||||
class ToolAuthHandler:
|
||||
"""Handles the preparation and exchange of authentication credentials for tools."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_context: ToolContext,
|
||||
auth_scheme: Optional[AuthScheme],
|
||||
auth_credential: Optional[AuthCredential],
|
||||
credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
|
||||
credential_store: Optional["ToolContextCredentialStore"] = None,
|
||||
):
|
||||
self.tool_context = tool_context
|
||||
self.auth_scheme = (
|
||||
auth_scheme.model_copy(deep=True) if auth_scheme else None
|
||||
)
|
||||
self.auth_credential = (
|
||||
auth_credential.model_copy(deep=True) if auth_credential else None
|
||||
)
|
||||
self.credential_exchanger = (
|
||||
credential_exchanger or AutoAuthCredentialExchanger()
|
||||
)
|
||||
self.credential_store = credential_store
|
||||
self.should_store_credential = True
|
||||
|
||||
@classmethod
|
||||
def from_tool_context(
|
||||
cls,
|
||||
tool_context: ToolContext,
|
||||
auth_scheme: Optional[AuthScheme],
|
||||
auth_credential: Optional[AuthCredential],
|
||||
credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
|
||||
) -> "ToolAuthHandler":
|
||||
"""Creates a ToolAuthHandler instance from a ToolContext."""
|
||||
credential_store = ToolContextCredentialStore(tool_context)
|
||||
return cls(
|
||||
tool_context,
|
||||
auth_scheme,
|
||||
auth_credential,
|
||||
credential_exchanger,
|
||||
credential_store,
|
||||
)
|
||||
|
||||
def _handle_existing_credential(
|
||||
self,
|
||||
) -> Optional[AuthPreparationResult]:
|
||||
"""Checks for and returns an existing, exchanged credential."""
|
||||
if self.credential_store:
|
||||
existing_credential = self.credential_store.get_credential(
|
||||
self.auth_scheme, self.auth_credential
|
||||
)
|
||||
if existing_credential:
|
||||
return AuthPreparationResult(
|
||||
state="done",
|
||||
auth_scheme=self.auth_scheme,
|
||||
auth_credential=existing_credential,
|
||||
)
|
||||
return None
|
||||
|
||||
def _exchange_credential(
|
||||
self, auth_credential: AuthCredential
|
||||
) -> Optional[AuthPreparationResult]:
|
||||
"""Handles an OpenID Connect authorization response."""
|
||||
|
||||
exchanged_credential = None
|
||||
try:
|
||||
exchanged_credential = self.credential_exchanger.exchange_credential(
|
||||
self.auth_scheme, auth_credential
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to exchange credential: %s", e)
|
||||
return exchanged_credential
|
||||
|
||||
def _store_credential(self, auth_credential: AuthCredential) -> None:
|
||||
"""stores the auth_credential."""
|
||||
|
||||
if self.credential_store:
|
||||
key = self.credential_store.get_credential_key(
|
||||
self.auth_scheme, self.auth_credential
|
||||
)
|
||||
self.credential_store.store_credential(key, auth_credential)
|
||||
|
||||
def _reqeust_credential(self) -> None:
|
||||
"""Handles the case where an OpenID Connect or OAuth2 authentication request is needed."""
|
||||
if self.auth_scheme.type_ in (
|
||||
AuthSchemeType.openIdConnect,
|
||||
AuthSchemeType.oauth2,
|
||||
):
|
||||
if not self.auth_credential or not self.auth_credential.oauth2:
|
||||
raise ValueError(
|
||||
f"auth_credential is empty for scheme {self.auth_scheme.type_}."
|
||||
"Please create AuthCredential using OAuth2Auth."
|
||||
)
|
||||
|
||||
if not self.auth_credential.oauth2.client_id:
|
||||
raise AuthCredentialMissingError(
|
||||
"OAuth2 credentials client_id is missing."
|
||||
)
|
||||
|
||||
if not self.auth_credential.oauth2.client_secret:
|
||||
raise AuthCredentialMissingError(
|
||||
"OAuth2 credentials client_secret is missing."
|
||||
)
|
||||
|
||||
self.tool_context.request_credential(
|
||||
AuthConfig(
|
||||
auth_scheme=self.auth_scheme,
|
||||
raw_auth_credential=self.auth_credential,
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_auth_response(self) -> AuthCredential:
|
||||
return self.tool_context.get_auth_response(
|
||||
AuthConfig(
|
||||
auth_scheme=self.auth_scheme,
|
||||
raw_auth_credential=self.auth_credential,
|
||||
)
|
||||
)
|
||||
|
||||
def _request_credential(self, auth_config: AuthConfig):
|
||||
if not self.tool_context:
|
||||
return
|
||||
self.tool_context.request_credential(auth_config)
|
||||
|
||||
def prepare_auth_credentials(
|
||||
self,
|
||||
) -> AuthPreparationResult:
|
||||
"""Prepares authentication credentials, handling exchange and user interaction."""
|
||||
|
||||
# no auth is needed
|
||||
if not self.auth_scheme:
|
||||
return AuthPreparationResult(state="done")
|
||||
|
||||
# Check for existing credential.
|
||||
existing_result = self._handle_existing_credential()
|
||||
if existing_result:
|
||||
return existing_result
|
||||
|
||||
# fetch credential from adk framework
|
||||
# Some auth scheme like OAuth2 AuthCode & OpenIDConnect may require
|
||||
# multi-step exchange:
|
||||
# client_id , client_secret -> auth_uri -> auth_code -> access_token
|
||||
# -> bearer token
|
||||
# adk framework supports exchange access_token already
|
||||
fetched_credential = self._get_auth_response() or self.auth_credential
|
||||
|
||||
exchanged_credential = self._exchange_credential(fetched_credential)
|
||||
|
||||
if exchanged_credential:
|
||||
self._store_credential(exchanged_credential)
|
||||
return AuthPreparationResult(
|
||||
state="done",
|
||||
auth_scheme=self.auth_scheme,
|
||||
auth_credential=exchanged_credential,
|
||||
)
|
||||
else:
|
||||
self._reqeust_credential()
|
||||
return AuthPreparationResult(
|
||||
state="pending",
|
||||
auth_scheme=self.auth_scheme,
|
||||
auth_credential=self.auth_credential,
|
||||
)
|
||||
Reference in New Issue
Block a user