Agent Development Kit(ADK)

An easy-to-use and powerful framework to build AI agents.
This commit is contained in:
hangfei
2025-04-08 17:22:09 +00:00
parent f92478bd5c
commit 9827820143
299 changed files with 44398 additions and 2 deletions

View File

@@ -0,0 +1,21 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .openapi_spec_parser import OpenAPIToolset
from .openapi_spec_parser import RestApiTool
__all__ = [
'OpenAPIToolset',
'RestApiTool',
]

View File

@@ -0,0 +1,19 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import auth_helpers
__all__ = [
'auth_helpers',
]

View File

@@ -0,0 +1,498 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from fastapi.openapi.models import HTTPBase
from fastapi.openapi.models import HTTPBearer
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OpenIdConnect
from fastapi.openapi.models import Schema
from pydantic import BaseModel
from pydantic import ValidationError
import requests
from ....auth.auth_credential import AuthCredential
from ....auth.auth_credential import AuthCredentialTypes
from ....auth.auth_credential import HttpAuth
from ....auth.auth_credential import HttpCredentials
from ....auth.auth_credential import OAuth2Auth
from ....auth.auth_credential import ServiceAccount
from ....auth.auth_credential import ServiceAccountCredential
from ....auth.auth_schemes import AuthScheme
from ....auth.auth_schemes import AuthSchemeType
from ....auth.auth_schemes import OpenIdConnectWithConfig
from ..common.common import ApiParameter
class OpenIdConfig(BaseModel):
"""Represents OpenID Connect configuration.
Attributes:
client_id: The client ID.
auth_uri: The authorization URI.
token_uri: The token URI.
client_secret: The client secret.
Example:
config = OpenIdConfig(
client_id="your_client_id",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
client_secret="your_client_secret",
redirect
)
"""
client_id: str
auth_uri: str
token_uri: str
client_secret: str
redirect_uri: Optional[str]
def token_to_scheme_credential(
token_type: Literal["apikey", "oauth2Token"],
location: Optional[Literal["header", "query", "cookie"]] = None,
name: Optional[str] = None,
credential_value: Optional[str] = None,
) -> Tuple[AuthScheme, AuthCredential]:
"""Creates a AuthScheme and AuthCredential for API key or bearer token.
Examples:
```
# API Key in header
auth_scheme, auth_credential = token_to_scheme_credential("apikey", "header",
"X-API-Key", "your_api_key_value")
# API Key in query parameter
auth_scheme, auth_credential = token_to_scheme_credential("apikey", "query",
"api_key", "your_api_key_value")
# OAuth2 Bearer Token in Authorization header
auth_scheme, auth_credential = token_to_scheme_credential("oauth2Token",
"header", "Authorization", "your_bearer_token_value")
```
Args:
type: 'apikey' or 'oauth2Token'.
location: 'header', 'query', or 'cookie' (only 'header' for oauth2Token).
name: The name of the header, query parameter, or cookie.
credential_value: The value of the API Key/ Token.
Returns:
Tuple: (AuthScheme, AuthCredential)
Raises:
ValueError: For invalid type or location.
"""
if token_type == "apikey":
in_: APIKeyIn
if location == "header":
in_ = APIKeyIn.header
elif location == "query":
in_ = APIKeyIn.query
elif location == "cookie":
in_ = APIKeyIn.cookie
else:
raise ValueError(f"Invalid location for apiKey: {location}")
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": in_,
"name": name,
})
if credential_value:
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key=credential_value
)
else:
auth_credential = None
return auth_scheme, auth_credential
elif token_type == "oauth2Token":
# ignore location. OAuth2 Bearer Token is always in Authorization header.
auth_scheme = HTTPBearer(
bearerFormat="JWT"
) # Common format, can be omitted.
if credential_value:
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token=credential_value),
),
)
else:
auth_credential = None
return auth_scheme, auth_credential
else:
raise ValueError(f"Invalid security scheme type: {type}")
def service_account_dict_to_scheme_credential(
config: Dict[str, Any],
scopes: List[str],
) -> Tuple[AuthScheme, AuthCredential]:
"""Creates AuthScheme and AuthCredential for Google Service Account.
Returns a bearer token scheme, and a service account credential.
Args:
config: A ServiceAccount object containing the Google Service Account
configuration.
scopes: A list of scopes to be used.
Returns:
Tuple: (AuthScheme, AuthCredential)
"""
auth_scheme = HTTPBearer(bearerFormat="JWT")
service_account = ServiceAccount(
service_account_credential=ServiceAccountCredential.model_construct(
**config
),
scopes=scopes,
)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=service_account,
)
return auth_scheme, auth_credential
def service_account_scheme_credential(
config: ServiceAccount,
) -> Tuple[AuthScheme, AuthCredential]:
"""Creates AuthScheme and AuthCredential for Google Service Account.
Returns a bearer token scheme, and a service account credential.
Args:
config: A ServiceAccount object containing the Google Service Account
configuration.
Returns:
Tuple: (AuthScheme, AuthCredential)
"""
auth_scheme = HTTPBearer(bearerFormat="JWT")
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=config
)
return auth_scheme, auth_credential
def openid_dict_to_scheme_credential(
config_dict: Dict[str, Any],
scopes: List[str],
credential_dict: Dict[str, Any],
) -> Tuple[OpenIdConnectWithConfig, AuthCredential]:
"""Constructs OpenID scheme and credential from configuration and credential dictionaries.
Args:
config_dict: Dictionary containing OpenID Connect configuration, must
include at least 'authorization_endpoint' and 'token_endpoint'.
scopes: List of scopes to be used.
credential_dict: Dictionary containing credential information, must
include 'client_id', 'client_secret', and 'scopes'. May optionally
include 'redirect_uri'.
Returns:
Tuple: (OpenIdConnectWithConfig, AuthCredential)
Raises:
ValueError: If required fields are missing in the input dictionaries.
"""
# Validate and create the OpenIdConnectWithConfig scheme
try:
config_dict["scopes"] = scopes
# If user provides the OpenID Config as a static dict, it may not contain
# openIdConnect URL.
if "openIdConnectUrl" not in config_dict:
config_dict["openIdConnectUrl"] = ""
openid_scheme = OpenIdConnectWithConfig.model_validate(config_dict)
except ValidationError as e:
raise ValueError(f"Invalid OpenID Connect configuration: {e}") from e
# Attempt to adjust credential_dict if this is a key downloaded from Google
# OAuth config
if len(list(credential_dict.values())) == 1:
credential_value = list(credential_dict.values())[0]
if "client_id" in credential_value and "client_secret" in credential_value:
credential_dict = credential_value
# Validate credential_dict
required_credential_fields = ["client_id", "client_secret"]
missing_fields = [
field
for field in required_credential_fields
if field not in credential_dict
]
if missing_fields:
raise ValueError(
"Missing required fields in credential_dict:"
f" {', '.join(missing_fields)}"
)
# Construct AuthCredential
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
oauth2=OAuth2Auth(
client_id=credential_dict["client_id"],
client_secret=credential_dict["client_secret"],
redirect_uri=credential_dict.get("redirect_uri", None),
),
)
return openid_scheme, auth_credential
def openid_url_to_scheme_credential(
openid_url: str, scopes: List[str], credential_dict: Dict[str, Any]
) -> Tuple[OpenIdConnectWithConfig, AuthCredential]:
"""Constructs OpenID scheme and credential from OpenID URL, scopes, and credential dictionary.
Fetches OpenID configuration from the provided URL.
Args:
openid_url: The OpenID Connect discovery URL.
scopes: List of scopes to be used.
credential_dict: Dictionary containing credential information, must
include at least "client_id" and "client_secret", may optionally include
"redirect_uri" and "scope"
Returns:
Tuple: (AuthScheme, AuthCredential)
Raises:
ValueError: If the OpenID URL is invalid, fetching fails, or required
fields are missing.
requests.exceptions.RequestException: If there's an error during the
HTTP request.
"""
try:
response = requests.get(openid_url, timeout=10)
response.raise_for_status()
config_dict = response.json()
except requests.exceptions.RequestException as e:
raise ValueError(
f"Failed to fetch OpenID configuration from {openid_url}: {e}"
) from e
except ValueError as e:
raise ValueError(
"Invalid JSON response from OpenID configuration endpoint"
f" {openid_url}: {e}"
) from e
# Add openIdConnectUrl to config dict
config_dict["openIdConnectUrl"] = openid_url
return openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
INTERNAL_AUTH_PREFIX = "_auth_prefix_vaf_"
def credential_to_param(
auth_scheme: AuthScheme,
auth_credential: AuthCredential,
) -> Tuple[Optional[ApiParameter], Optional[Dict[str, Any]]]:
"""Converts AuthCredential and AuthScheme to a Parameter and a dictionary for additional kwargs.
This function now supports all credential types returned by the exchangers:
- API Key
- HTTP Bearer (for Bearer tokens, OAuth2, Service Account, OpenID Connect)
- OAuth2 and OpenID Connect (returns None, None, as the token is now a Bearer
token)
- Service Account (returns None, None, as the token is now a Bearer token)
Args:
auth_scheme: The AuthScheme object.
auth_credential: The AuthCredential object.
Returns:
Tuple: (ApiParameter, Dict[str, Any])
"""
if not auth_credential:
return None, None
if (
auth_scheme.type_ == AuthSchemeType.apiKey
and auth_credential
and auth_credential.api_key
):
param_name = auth_scheme.name or ""
python_name = INTERNAL_AUTH_PREFIX + param_name
if auth_scheme.in_ == APIKeyIn.header:
param_location = "header"
elif auth_scheme.in_ == APIKeyIn.query:
param_location = "query"
elif auth_scheme.in_ == APIKeyIn.cookie:
param_location = "cookie"
else:
raise ValueError(f"Invalid API Key location: {auth_scheme.in_}")
param = ApiParameter(
original_name=param_name,
param_location=param_location,
param_schema=Schema(type="string"),
description=auth_scheme.description or "",
py_name=python_name,
)
kwargs = {param.py_name: auth_credential.api_key}
return param, kwargs
# TODO(cheliu): Split handling for OpenIDConnect scheme and native HTTPBearer
# Scheme
elif (
auth_credential and auth_credential.auth_type == AuthCredentialTypes.HTTP
):
if (
auth_credential
and auth_credential.http
and auth_credential.http.credentials
and auth_credential.http.credentials.token
):
param = ApiParameter(
original_name="Authorization",
param_location="header",
param_schema=Schema(type="string"),
description=auth_scheme.description or "Bearer token",
py_name=INTERNAL_AUTH_PREFIX + "Authorization",
)
kwargs = {
param.py_name: f"Bearer {auth_credential.http.credentials.token}"
}
return param, kwargs
elif (
auth_credential
and auth_credential.http
and auth_credential.http.credentials
and (
auth_credential.http.credentials.username
or auth_credential.http.credentials.password
)
):
# Basic Auth is explicitly NOT supported
raise NotImplementedError("Basic Authentication is not supported.")
else:
raise ValueError("Invalid HTTP auth credentials")
# Service Account tokens, OAuth2 Tokens and OpenID Tokens are now handled as
# Bearer tokens.
elif (auth_scheme.type_ == AuthSchemeType.oauth2 and auth_credential) or (
auth_scheme.type_ == AuthSchemeType.openIdConnect and auth_credential
):
if (
auth_credential.http
and auth_credential.http.credentials
and auth_credential.http.credentials.token
):
param = ApiParameter(
original_name="Authorization",
param_location="header",
param_schema=Schema(type="string"),
description=auth_scheme.description or "Bearer token",
py_name=INTERNAL_AUTH_PREFIX + "Authorization",
)
kwargs = {
param.py_name: f"Bearer {auth_credential.http.credentials.token}"
}
return param, kwargs
return None, None
else:
raise ValueError("Invalid security scheme and credential combination")
def dict_to_auth_scheme(data: Dict[str, Any]) -> AuthScheme:
"""Converts a dictionary to a FastAPI AuthScheme object.
Args:
data: The dictionary representing the security scheme.
Returns:
A AuthScheme object (APIKey, HTTPBase, OAuth2, OpenIdConnect, or
HTTPBearer).
Raises:
ValueError: If the 'type' field is missing or invalid, or if the
dictionary cannot be converted to the corresponding Pydantic model.
Example:
```python
api_key_data = {
"type": "apiKey",
"in": "header",
"name": "X-API-Key",
}
api_key_scheme = dict_to_auth_scheme(api_key_data)
bearer_data = {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT",
}
bearer_scheme = dict_to_auth_scheme(bearer_data)
oauth2_data = {
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "https://example.com/auth",
"tokenUrl": "https://example.com/token",
}
}
}
oauth2_scheme = dict_to_auth_scheme(oauth2_data)
openid_data = {
"type": "openIdConnect",
"openIdConnectUrl": "https://example.com/.well-known/openid-configuration"
}
openid_scheme = dict_to_auth_scheme(openid_data)
```
"""
if "type" not in data:
raise ValueError("Missing 'type' field in security scheme dictionary.")
security_type = data["type"]
try:
if security_type == "apiKey":
return APIKey.model_validate(data)
elif security_type == "http":
if data.get("scheme") == "bearer":
return HTTPBearer.model_validate(data)
else:
return HTTPBase.model_validate(data) # Generic HTTP
elif security_type == "oauth2":
return OAuth2.model_validate(data)
elif security_type == "openIdConnect":
return OpenIdConnect.model_validate(data)
else:
raise ValueError(f"Invalid security scheme type: {security_type}")
except ValidationError as e:
raise ValueError(f"Invalid security scheme data: {e}") from e

View File

@@ -0,0 +1,25 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .auto_auth_credential_exchanger import AutoAuthCredentialExchanger
from .base_credential_exchanger import BaseAuthCredentialExchanger
from .oauth2_exchanger import OAuth2CredentialExchanger
from .service_account_exchanger import ServiceAccountCredentialExchanger
__all__ = [
'AutoAuthCredentialExchanger',
'BaseAuthCredentialExchanger',
'OAuth2CredentialExchanger',
'ServiceAccountCredentialExchanger',
]

View File

@@ -0,0 +1,105 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import Optional
from typing import Type
from .....auth.auth_credential import AuthCredential
from .....auth.auth_credential import AuthCredentialTypes
from .....auth.auth_schemes import AuthScheme
from .base_credential_exchanger import BaseAuthCredentialExchanger
from .oauth2_exchanger import OAuth2CredentialExchanger
from .service_account_exchanger import ServiceAccountCredentialExchanger
class AutoAuthCredentialExchanger(BaseAuthCredentialExchanger):
"""Automatically selects the appropriate credential exchanger based on the auth scheme.
Optionally, an override can be provided to use a specific exchanger for a
given auth scheme.
Example (common case):
```
exchanger = AutoAuthCredentialExchanger()
auth_credential = exchanger.exchange_credential(
auth_scheme=service_account_scheme,
auth_credential=service_account_credential,
)
# Returns an oauth token in the form of a bearer token.
```
Example (use CustomAuthExchanger for OAuth2):
```
exchanger = AutoAuthCredentialExchanger(
custom_exchangers={
AuthScheme.OAUTH2: CustomAuthExchanger,
}
)
```
Attributes:
exchangers: A dictionary mapping auth scheme to credential exchanger class.
"""
def __init__(
self,
custom_exchangers: Optional[
Dict[str, Type[BaseAuthCredentialExchanger]]
] = None,
):
"""Initializes the AutoAuthCredentialExchanger.
Args:
custom_exchangers: Optional dictionary for adding or overriding auth
exchangers. The key is the auth scheme, and the value is the credential
exchanger class.
"""
self.exchangers = {
AuthCredentialTypes.OAUTH2: OAuth2CredentialExchanger,
AuthCredentialTypes.OPEN_ID_CONNECT: OAuth2CredentialExchanger,
AuthCredentialTypes.SERVICE_ACCOUNT: ServiceAccountCredentialExchanger,
}
if custom_exchangers:
self.exchangers.update(custom_exchangers)
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> Optional[AuthCredential]:
"""Automatically exchanges for the credential uses the appropriate credential exchanger.
Args:
auth_scheme (AuthScheme): The security scheme.
auth_credential (AuthCredential): Optional. The authentication
credential.
Returns: (AuthCredential)
A new AuthCredential object containing the exchanged credential.
"""
if not auth_credential:
return None
exchanger_class = self.exchangers.get(
auth_credential.auth_type if auth_credential else None
)
if not exchanger_class:
return auth_credential
exchanger = exchanger_class()
return exchanger.exchange_credential(auth_scheme, auth_credential)

View File

@@ -0,0 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Optional
from .....auth.auth_credential import (
AuthCredential,
)
from .....auth.auth_schemes import AuthScheme
class AuthCredentialMissingError(Exception):
"""Exception raised when required authentication credentials are missing."""
def __init__(self, message: str):
super().__init__(message)
self.message = message
class BaseAuthCredentialExchanger:
"""Base class for authentication credential exchangers."""
@abc.abstractmethod
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
"""Exchanges the provided authentication credential for a usable token/credential.
Args:
auth_scheme: The security scheme.
auth_credential: The authentication credential.
Returns:
An updated AuthCredential object containing the fetched credential.
For simple schemes like API key, it may return the original credential
if no exchange is needed.
Raises:
NotImplementedError: If the method is not implemented by a subclass.
"""
raise NotImplementedError("Subclasses must implement exchange_credential.")

View File

@@ -0,0 +1,117 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Credential fetcher for OpenID Connect."""
from typing import Optional
from .....auth.auth_credential import AuthCredential
from .....auth.auth_credential import AuthCredentialTypes
from .....auth.auth_credential import HttpAuth
from .....auth.auth_credential import HttpCredentials
from .....auth.auth_schemes import AuthScheme
from .....auth.auth_schemes import AuthSchemeType
from .base_credential_exchanger import BaseAuthCredentialExchanger
class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
"""Fetches credentials for OAuth2 and OpenID Connect."""
def _check_scheme_credential_type(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
):
if not auth_credential:
raise ValueError(
"auth_credential is empty. Please create AuthCredential using"
" OAuth2Auth."
)
if auth_scheme.type_ not in (
AuthSchemeType.openIdConnect,
AuthSchemeType.oauth2,
):
raise ValueError(
"Invalid security scheme, expect AuthSchemeType.openIdConnect or "
f"AuthSchemeType.oauth2 auth scheme, but got {auth_scheme.type_}"
)
if not auth_credential.oauth2 and not auth_credential.http:
raise ValueError(
"auth_credential is not configured with oauth2. Please"
" create AuthCredential and set OAuth2Auth."
)
def generate_auth_token(
self,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
"""Generates an auth token from the authorization response.
Args:
auth_scheme: The OpenID Connect or OAuth2 auth scheme.
auth_credential: The auth credential.
Returns:
An AuthCredential object containing the HTTP bearer access token. If the
HTTO bearer token cannot be generated, return the origianl credential
"""
if "access_token" not in auth_credential.oauth2.token:
return auth_credential
# Return the access token as a bearer token.
updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(
token=auth_credential.oauth2.token["access_token"]
),
),
)
return updated_credential
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
"""Exchanges the OpenID Connect auth credential for an access token or an auth URI.
Args:
auth_scheme: The auth scheme.
auth_credential: The auth credential.
Returns:
An AuthCredential object containing the HTTP Bearer access token.
Raises:
ValueError: If the auth scheme or auth credential is invalid.
"""
# TODO(cheliu): Implement token refresh flow
self._check_scheme_credential_type(auth_scheme, auth_credential)
# If token is already HTTPBearer token, do nothing assuming that this token
# is valid.
if auth_credential.http:
return auth_credential
# If access token is exchanged, exchange a HTTPBearer token.
if auth_credential.oauth2.token:
return self.generate_auth_token(auth_credential)
return None

View File

@@ -0,0 +1,97 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Credential fetcher for Google Service Account."""
from typing import Optional
import google.auth
from google.auth.transport.requests import Request
from google.oauth2 import service_account
import google.oauth2.credentials
from .....auth.auth_credential import (
AuthCredential,
AuthCredentialTypes,
HttpAuth,
HttpCredentials,
)
from .....auth.auth_schemes import AuthScheme
from .base_credential_exchanger import AuthCredentialMissingError, BaseAuthCredentialExchanger
class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
"""Fetches credentials for Google Service Account.
Uses the default service credential if `use_default_credential = True`.
Otherwise, uses the service account credential provided in the auth
credential.
"""
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
"""Exchanges the service account auth credential for an access token.
If auth_credential contains a service account credential, it will be used
to fetch an access token. Otherwise, the default service credential will be
used for fetching an access token.
Args:
auth_scheme: The auth scheme.
auth_credential: The auth credential.
Returns:
An AuthCredential in HTTPBearer format, containing the access token.
"""
if (
auth_credential is None
or auth_credential.service_account is None
or (
auth_credential.service_account.service_account_credential is None
and not auth_credential.service_account.use_default_credential
)
):
raise AuthCredentialMissingError(
"Service account credentials are missing. Please provide them, or set"
" `use_default_credential = True` to use application default"
" credential in a hosted service like Cloud Run."
)
try:
if auth_credential.service_account.use_default_credential:
credentials, _ = google.auth.default()
else:
config = auth_credential.service_account
credentials = service_account.Credentials.from_service_account_info(
config.service_account_credential.model_dump(), scopes=config.scopes
)
credentials.refresh(Request())
updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token=credentials.token),
),
)
return updated_credential
except Exception as e:
raise AuthCredentialMissingError(
f"Failed to exchange service account token: {e}"
) from e

View File

@@ -0,0 +1,19 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import common
__all__ = [
'common',
]

View File

@@ -0,0 +1,300 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keyword
import re
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from fastapi.openapi.models import Response
from fastapi.openapi.models import Schema
from pydantic import BaseModel
from pydantic import Field
from pydantic import model_serializer
def to_snake_case(text: str) -> str:
"""Converts a string into snake_case.
Handles lowerCamelCase, UpperCamelCase, or space-separated case, acronyms
(e.g., "REST API") and consecutive uppercase letters correctly. Also handles
mixed cases with and without spaces.
Examples:
```
to_snake_case('camelCase') -> 'camel_case'
to_snake_case('UpperCamelCase') -> 'upper_camel_case'
to_snake_case('space separated') -> 'space_separated'
```
Args:
text: The input string.
Returns:
The snake_case version of the string.
"""
# Handle spaces and non-alphanumeric characters (replace with underscores)
text = re.sub(r'[^a-zA-Z0-9]+', '_', text)
# Insert underscores before uppercase letters (handling both CamelCases)
text = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', text) # lowerCamelCase
text = re.sub(
r'([A-Z]+)([A-Z][a-z])', r'\1_\2', text
) # UpperCamelCase and acronyms
# Convert to lowercase
text = text.lower()
# Remove consecutive underscores (clean up extra underscores)
text = re.sub(r'_+', '_', text)
# Remove leading and trailing underscores
text = text.strip('_')
return text
def rename_python_keywords(s: str, prefix: str = 'param_') -> str:
"""Renames Python keywords by adding a prefix.
Example:
```
rename_python_keywords('if') -> 'param_if'
rename_python_keywords('for') -> 'param_for'
```
Args:
s: The input string.
prefix: The prefix to add to the keyword.
Returns:
The renamed string.
"""
if keyword.iskeyword(s):
return prefix + s
return s
class ApiParameter(BaseModel):
"""Data class representing a function parameter."""
original_name: str
param_location: str
param_schema: Union[str, Schema]
description: Optional[str] = ''
py_name: Optional[str] = ''
type_value: type[Any] = Field(default=None, init_var=False)
type_hint: str = Field(default=None, init_var=False)
def model_post_init(self, _: Any):
self.py_name = (
self.py_name
if self.py_name
else rename_python_keywords(to_snake_case(self.original_name))
)
if isinstance(self.param_schema, str):
self.param_schema = Schema.model_validate_json(self.param_schema)
self.description = self.description or self.param_schema.description or ''
self.type_value = TypeHintHelper.get_type_value(self.param_schema)
self.type_hint = TypeHintHelper.get_type_hint(self.param_schema)
return self
@model_serializer
def _serialize(self):
return {
'original_name': self.original_name,
'param_location': self.param_location,
'param_schema': self.param_schema,
'description': self.description,
'py_name': self.py_name,
}
def __str__(self):
return f'{self.py_name}: {self.type_hint}'
def to_arg_string(self):
"""Converts the parameter to an argument string for function call."""
return f'{self.py_name}={self.py_name}'
def to_dict_property(self):
"""Converts the parameter to a key:value string for dict property."""
return f'"{self.py_name}": {self.py_name}'
def to_pydoc_string(self):
"""Converts the parameter to a PyDoc parameter docstr."""
return PydocHelper.generate_param_doc(self)
class TypeHintHelper:
"""Helper class for generating type hints."""
@staticmethod
def get_type_value(schema: Schema) -> Any:
"""Generates the Python type value for a given parameter."""
param_type = schema.type if schema.type else Any
if param_type == 'integer':
return int
elif param_type == 'number':
return float
elif param_type == 'boolean':
return bool
elif param_type == 'string':
return str
elif param_type == 'array':
items_type = Any
if schema.items and schema.items.type:
items_type = schema.items.type
if items_type == 'object':
return List[Dict[str, Any]]
else:
type_map = {
'integer': int,
'number': float,
'boolean': bool,
'string': str,
'object': Dict[str, Any],
'array': List[Any],
}
return List[type_map.get(items_type, 'Any')]
elif param_type == 'object':
return Dict[str, Any]
else:
return Any
@staticmethod
def get_type_hint(schema: Schema) -> str:
"""Generates the Python type in string for a given parameter."""
param_type = schema.type if schema.type else 'Any'
if param_type == 'integer':
return 'int'
elif param_type == 'number':
return 'float'
elif param_type == 'boolean':
return 'bool'
elif param_type == 'string':
return 'str'
elif param_type == 'array':
items_type = 'Any'
if schema.items and schema.items.type:
items_type = schema.items.type
if items_type == 'object':
return 'List[Dict[str, Any]]'
else:
type_map = {
'integer': 'int',
'number': 'float',
'boolean': 'bool',
'string': 'str',
}
return f"List[{type_map.get(items_type, 'Any')}]"
elif param_type == 'object':
return 'Dict[str, Any]'
else:
return 'Any'
class PydocHelper:
"""Helper class for generating PyDoc strings."""
@staticmethod
def generate_param_doc(
param: ApiParameter,
) -> str:
"""Generates a parameter documentation string.
Args:
param: ApiParameter - The parameter to generate the documentation for.
Returns:
str: The generated parameter Python documentation string.
"""
description = param.description.strip() if param.description else ''
param_doc = f'{param.py_name} ({param.type_hint}): {description}'
if param.param_schema.type == 'object':
properties = param.param_schema.properties
if properties:
param_doc += ' Object properties:\n'
for prop_name, prop_details in properties.items():
prop_desc = prop_details.description or ''
prop_type = TypeHintHelper.get_type_hint(prop_details)
param_doc += f' {prop_name} ({prop_type}): {prop_desc}\n'
return param_doc
@staticmethod
def generate_return_doc(responses: Dict[str, Response]) -> str:
"""Generates a return value documentation string.
Args:
responses: Dict[str, TypedDict[Response]] - Response in an OpenAPI
Operation
Returns:
str: The generated return value Python documentation string.
"""
return_doc = ''
# Only consider 2xx responses for return type hinting.
# Returns the 2xx response with the smallest status code number and with
# content defined.
sorted_responses = sorted(responses.items(), key=lambda item: int(item[0]))
qualified_response = next(
filter(
lambda r: r[0].startswith('2') and r[1].content,
sorted_responses,
),
None,
)
if not qualified_response:
return ''
response_details = qualified_response[1]
description = (response_details.description or '').strip()
content = response_details.content or {}
# Generate return type hint and properties for the first response type.
# TODO(cheliu): Handle multiple content types.
for _, schema_details in content.items():
schema = schema_details.schema_ or {}
# Use a dummy Parameter object for return type hinting.
dummy_param = ApiParameter(
original_name='', param_location='', param_schema=schema
)
return_doc = f'Returns ({dummy_param.type_hint}): {description}'
response_type = schema.type or 'Any'
if response_type != 'object':
break
properties = schema.properties
if not properties:
break
return_doc += ' Object properties:\n'
for prop_name, prop_details in properties.items():
prop_desc = prop_details.description or ''
prop_type = TypeHintHelper.get_type_hint(prop_details)
return_doc += f' {prop_name} ({prop_type}): {prop_desc}\n'
break
return return_doc

View File

@@ -0,0 +1,32 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .openapi_spec_parser import OpenApiSpecParser, OperationEndpoint, ParsedOperation
from .openapi_toolset import OpenAPIToolset
from .operation_parser import OperationParser
from .rest_api_tool import AuthPreparationState, RestApiTool, snake_to_lower_camel, to_gemini_schema
from .tool_auth_handler import ToolAuthHandler
__all__ = [
'OpenApiSpecParser',
'OperationEndpoint',
'ParsedOperation',
'OpenAPIToolset',
'OperationParser',
'RestApiTool',
'to_gemini_schema',
'snake_to_lower_camel',
'AuthPreparationState',
'ToolAuthHandler',
]

View File

@@ -0,0 +1,231 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from fastapi.openapi.models import Operation
from pydantic import BaseModel
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ..common.common import ApiParameter
from ..common.common import to_snake_case
from .operation_parser import OperationParser
class OperationEndpoint(BaseModel):
base_url: str
path: str
method: str
class ParsedOperation(BaseModel):
name: str
description: str
endpoint: OperationEndpoint
operation: Operation
parameters: List[ApiParameter]
return_value: ApiParameter
auth_scheme: Optional[AuthScheme] = None
auth_credential: Optional[AuthCredential] = None
additional_context: Optional[Any] = None
class OpenApiSpecParser:
"""Generates Python code, JSON schema, and callables for an OpenAPI operation.
This class takes an OpenApiOperation object and provides methods to generate:
1. A string representation of a Python function that handles the operation.
2. A JSON schema representing the input parameters of the operation.
3. A callable Python object (a function) that can execute the operation.
"""
def parse(self, openapi_spec_dict: Dict[str, Any]) -> List[ParsedOperation]:
"""Extracts an OpenAPI spec dict into a list of ParsedOperation objects.
ParsedOperation objects are further used for generating RestApiTool.
Args:
openapi_spec_dict: A dictionary representing the OpenAPI specification.
Returns:
A list of ParsedOperation objects.
"""
openapi_spec_dict = self._resolve_references(openapi_spec_dict)
operations = self._collect_operations(openapi_spec_dict)
return operations
def _collect_operations(
self, openapi_spec: Dict[str, Any]
) -> List[ParsedOperation]:
"""Collects operations from an OpenAPI spec."""
operations = []
# Taking first server url, or default to empty string if not present
base_url = ""
if openapi_spec.get("servers"):
base_url = openapi_spec["servers"][0].get("url", "")
# Get global security scheme (if any)
global_scheme_name = None
if openapi_spec.get("security"):
# Use first scheme by default.
scheme_names = list(openapi_spec["security"][0].keys())
global_scheme_name = scheme_names[0] if scheme_names else None
auth_schemes = openapi_spec.get("components", {}).get("securitySchemes", {})
for path, path_item in openapi_spec.get("paths", {}).items():
if path_item is None:
continue
for method in (
"get",
"post",
"put",
"delete",
"patch",
"head",
"options",
"trace",
):
operation_dict = path_item.get(method)
if operation_dict is None:
continue
# If operation ID is missing, assign an operation id based on path
# and method
if "operationId" not in operation_dict:
temp_id = to_snake_case(f"{path}_{method}")
operation_dict["operationId"] = temp_id
url = OperationEndpoint(base_url=base_url, path=path, method=method)
operation = Operation.model_validate(operation_dict)
operation_parser = OperationParser(operation)
# Check for operation-specific auth scheme
auth_scheme_name = operation_parser.get_auth_scheme_name()
auth_scheme_name = (
auth_scheme_name if auth_scheme_name else global_scheme_name
)
auth_scheme = (
auth_schemes.get(auth_scheme_name) if auth_scheme_name else None
)
parsed_op = ParsedOperation(
name=operation_parser.get_function_name(),
description=operation.description or operation.summary or "",
endpoint=url,
operation=operation,
parameters=operation_parser.get_parameters(),
return_value=operation_parser.get_return_value(),
auth_scheme=auth_scheme,
auth_credential=None, # Placeholder
additional_context={},
)
operations.append(parsed_op)
return operations
def _resolve_references(self, openapi_spec: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively resolves all $ref references in an OpenAPI specification.
Handles circular references correctly.
Args:
openapi_spec: A dictionary representing the OpenAPI specification.
Returns:
A dictionary representing the OpenAPI specification with all references
resolved.
"""
openapi_spec = copy.deepcopy(openapi_spec) # Work on a copy
resolved_cache = {} # Cache resolved references
def resolve_ref(ref_string, current_doc):
"""Resolves a single $ref string."""
parts = ref_string.split("/")
if parts[0] != "#":
raise ValueError(f"External references not supported: {ref_string}")
current = current_doc
for part in parts[1:]:
if part in current:
current = current[part]
else:
return None # Reference not found
return current
def recursive_resolve(obj, current_doc, seen_refs=None):
"""Recursively resolves references, handling circularity.
Args:
obj: The object to traverse.
current_doc: Document to search for refs.
seen_refs: A set to track already-visited references (for circularity
detection).
Returns:
The resolved object.
"""
if seen_refs is None:
seen_refs = set() # Initialize the set if it's the first call
if isinstance(obj, dict):
if "$ref" in obj and isinstance(obj["$ref"], str):
ref_string = obj["$ref"]
# Check for circularity
if ref_string in seen_refs and ref_string not in resolved_cache:
# Circular reference detected! Return a *copy* of the object,
# but *without* the $ref. This breaks the cycle while
# still maintaining the overall structure.
return {k: v for k, v in obj.items() if k != "$ref"}
seen_refs.add(ref_string) # Add the reference to the set
# Check if we have a cached resolved value
if ref_string in resolved_cache:
return copy.deepcopy(resolved_cache[ref_string])
resolved_value = resolve_ref(ref_string, current_doc)
if resolved_value is not None:
# Recursively resolve the *resolved* value,
# passing along the 'seen_refs' set
resolved_value = recursive_resolve(
resolved_value, current_doc, seen_refs
)
resolved_cache[ref_string] = resolved_value
return copy.deepcopy(resolved_value) # return the cached result
else:
return obj # return original if no resolved value.
else:
new_dict = {}
for key, value in obj.items():
new_dict[key] = recursive_resolve(value, current_doc, seen_refs)
return new_dict
elif isinstance(obj, list):
return [recursive_resolve(item, current_doc, seen_refs) for item in obj]
else:
return obj
return recursive_resolve(openapi_spec, openapi_spec)

View File

@@ -0,0 +1,144 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import Any
from typing import Dict
from typing import Final
from typing import List
from typing import Literal
from typing import Optional
import yaml
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from .openapi_spec_parser import OpenApiSpecParser
from .rest_api_tool import RestApiTool
logger = logging.getLogger(__name__)
class OpenAPIToolset:
"""Class for parsing OpenAPI spec into a list of RestApiTool.
Usage:
```
# Initialize OpenAPI toolset from a spec string.
openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
spec_str_type="json")
# Or, initialize OpenAPI toolset from a spec dictionary.
openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)
# Add all tools to an agent.
agent = Agent(
tools=[*openapi_toolset.get_tools()]
)
# Or, add a single tool to an agent.
agent = Agent(
tools=[openapi_toolset.get_tool('tool_name')]
)
```
"""
def __init__(
self,
*,
spec_dict: Optional[Dict[str, Any]] = None,
spec_str: Optional[str] = None,
spec_str_type: Literal["json", "yaml"] = "json",
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
):
"""Initializes the OpenAPIToolset.
Usage:
```
# Initialize OpenAPI toolset from a spec string.
openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
spec_str_type="json")
# Or, initialize OpenAPI toolset from a spec dictionary.
openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)
# Add all tools to an agent.
agent = Agent(
tools=[*openapi_toolset.get_tools()]
)
# Or, add a single tool to an agent.
agent = Agent(
tools=[openapi_toolset.get_tool('tool_name')]
)
```
Args:
spec_dict: The OpenAPI spec dictionary. If provided, it will be used
instead of loading the spec from a string.
spec_str: The OpenAPI spec string in JSON or YAML format. It will be used
when spec_dict is not provided.
spec_str_type: The type of the OpenAPI spec string. Can be "json" or
"yaml".
auth_scheme: The auth scheme to use for all tools. Use AuthScheme or use
helpers in `google.adk.tools.openapi_tool.auth.auth_helpers`
auth_credential: The auth credential to use for all tools. Use
AuthCredential or use helpers in
`google.adk.tools.openapi_tool.auth.auth_helpers`
"""
if not spec_dict:
spec_dict = self._load_spec(spec_str, spec_str_type)
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
if auth_scheme or auth_credential:
self._configure_auth_all(auth_scheme, auth_credential)
def _configure_auth_all(
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
):
"""Configure auth scheme and credential for all tools."""
for tool in self.tools:
if auth_scheme:
tool.configure_auth_scheme(auth_scheme)
if auth_credential:
tool.configure_auth_credential(auth_credential)
def get_tools(self) -> List[RestApiTool]:
"""Get all tools in the toolset."""
return self.tools
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
"""Get a tool by name."""
matching_tool = filter(lambda t: t.name == tool_name, self.tools)
return next(matching_tool, None)
def _load_spec(
self, spec_str: str, spec_type: Literal["json", "yaml"]
) -> Dict[str, Any]:
"""Loads the OpenAPI spec string into adictionary."""
if spec_type == "json":
return json.loads(spec_str)
elif spec_type == "yaml":
return yaml.safe_load(spec_str)
else:
raise ValueError(f"Unsupported spec type: {spec_type}")
def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]:
"""Parse OpenAPI spec into a list of RestApiTool."""
operations = OpenApiSpecParser().parse(openapi_spec_dict)
tools = []
for o in operations:
tool = RestApiTool.from_parsed_operation(o)
logger.info("Parsed tool: %s", tool.name)
tools.append(tool)
return tools

View File

@@ -0,0 +1,260 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from textwrap import dedent
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter
from fastapi.openapi.models import Schema
from ..common.common import ApiParameter
from ..common.common import PydocHelper
from ..common.common import to_snake_case
class OperationParser:
"""Generates parameters for Python functions from an OpenAPI operation.
This class processes an OpenApiOperation object and provides helper methods
to extract information needed to generate Python function declarations,
docstrings, signatures, and JSON schemas. It handles parameter processing,
name deduplication, and type hint generation.
"""
def __init__(
self, operation: Union[Operation, Dict[str, Any], str], should_parse=True
):
"""Initializes the OperationParser with an OpenApiOperation.
Args:
operation: The OpenApiOperation object or a dictionary to process.
should_parse: Whether to parse the operation during initialization.
"""
if isinstance(operation, dict):
self.operation = Operation.model_validate(operation)
elif isinstance(operation, str):
self.operation = Operation.model_validate_json(operation)
else:
self.operation = operation
self.params: List[ApiParameter] = []
self.return_value: Optional[ApiParameter] = None
if should_parse:
self._process_operation_parameters()
self._process_request_body()
self._process_return_value()
self._dedupe_param_names()
@classmethod
def load(
cls,
operation: Union[Operation, Dict[str, Any]],
params: List[ApiParameter],
return_value: Optional[ApiParameter] = None,
) -> 'OperationParser':
parser = cls(operation, should_parse=False)
parser.params = params
parser.return_value = return_value
return parser
def _process_operation_parameters(self):
"""Processes parameters from the OpenAPI operation."""
parameters = self.operation.parameters or []
for param in parameters:
if isinstance(param, Parameter):
original_name = param.name
description = param.description or ''
location = param.in_ or ''
schema = param.schema_ or {} # Use schema_ instead of .schema
self.params.append(
ApiParameter(
original_name=original_name,
param_location=location,
param_schema=schema,
description=description,
)
)
def _process_request_body(self):
"""Processes the request body from the OpenAPI operation."""
request_body = self.operation.requestBody
if not request_body:
return
content = request_body.content or {}
if not content:
return
# If request body is an object, expand the properties as parameters
for _, media_type_object in content.items():
schema = media_type_object.schema_ or {}
description = request_body.description or ''
if schema and schema.type == 'object':
for prop_name, prop_details in schema.properties.items():
self.params.append(
ApiParameter(
original_name=prop_name,
param_location='body',
param_schema=prop_details,
description=prop_details.description,
)
)
elif schema and schema.type == 'array':
self.params.append(
ApiParameter(
original_name='array',
param_location='body',
param_schema=schema,
description=description,
)
)
else:
self.params.append(
# Empty name for unnamed body param
ApiParameter(
original_name='',
param_location='body',
param_schema=schema,
description=description,
)
)
break # Process first mime type only
def _dedupe_param_names(self):
"""Deduplicates parameter names to avoid conflicts."""
params_cnt = {}
for param in self.params:
name = param.py_name
if name not in params_cnt:
params_cnt[name] = 0
else:
params_cnt[name] += 1
param.py_name = f'{name}_{params_cnt[name] -1}'
def _process_return_value(self) -> Parameter:
"""Returns a Parameter object representing the return type."""
responses = self.operation.responses or {}
# Default to Any if no 2xx response or if schema is missing
return_schema = Schema(type='Any')
# Take the 20x response with the smallest response code.
valid_codes = list(
filter(lambda k: k.startswith('2'), list(responses.keys()))
)
min_20x_status_code = min(valid_codes) if valid_codes else None
if min_20x_status_code and responses[min_20x_status_code].content:
content = responses[min_20x_status_code].content
for mime_type in content:
if content[mime_type].schema_:
return_schema = content[mime_type].schema_
break
self.return_value = ApiParameter(
original_name='',
param_location='',
param_schema=return_schema,
)
def get_function_name(self) -> str:
"""Returns the generated function name."""
operation_id = self.operation.operationId
if not operation_id:
raise ValueError('Operation ID is missing')
return to_snake_case(operation_id)[:60]
def get_return_type_hint(self) -> str:
"""Returns the return type hint string (like 'str', 'int', etc.)."""
return self.return_value.type_hint
def get_return_type_value(self) -> Any:
"""Returns the return type value (like str, int, List[str], etc.)."""
return self.return_value.type_value
def get_parameters(self) -> List[ApiParameter]:
"""Returns the list of Parameter objects."""
return self.params
def get_return_value(self) -> ApiParameter:
"""Returns the list of Parameter objects."""
return self.return_value
def get_auth_scheme_name(self) -> str:
"""Returns the name of the auth scheme for this operation from the spec."""
if self.operation.security:
scheme_name = list(self.operation.security[0].keys())[0]
return scheme_name
return ''
def get_pydoc_string(self) -> str:
"""Returns the generated PyDoc string."""
pydoc_params = [param.to_pydoc_string() for param in self.params]
pydoc_description = (
self.operation.summary or self.operation.description or ''
)
pydoc_return = PydocHelper.generate_return_doc(
self.operation.responses or {}
)
pydoc_arg_list = chr(10).join(
f' {param_doc}' for param_doc in pydoc_params
)
return dedent(f"""
\"\"\"{pydoc_description}
Args:
{pydoc_arg_list}
{pydoc_return}
\"\"\"
""").strip()
def get_json_schema(self) -> Dict[str, Any]:
"""Returns the JSON schema for the function arguments."""
properties = {
p.py_name: jsonable_encoder(p.param_schema, exclude_none=True)
for p in self.params
}
return {
'properties': properties,
'required': [p.py_name for p in self.params],
'title': f"{self.operation.operationId or 'unnamed'}_Arguments",
'type': 'object',
}
def get_signature_parameters(self) -> List[inspect.Parameter]:
"""Returns a list of inspect.Parameter objects for the function."""
return [
inspect.Parameter(
param.py_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=param.type_value,
)
for param in self.params
]
def get_annotations(self) -> Dict[str, Any]:
"""Returns a dictionary of parameter annotations for the function."""
annotations = {p.py_name: p.type_value for p in self.params}
annotations['return'] = self.get_return_type_value()
return annotations

View File

@@ -0,0 +1,496 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union
from fastapi.openapi.models import Operation
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
import requests
from typing_extensions import override
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ....tools import BaseTool
from ...tool_context import ToolContext
from ..auth.auth_helpers import credential_to_param
from ..auth.auth_helpers import dict_to_auth_scheme
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
from ..common.common import ApiParameter
from ..common.common import to_snake_case
from .openapi_spec_parser import OperationEndpoint
from .openapi_spec_parser import ParsedOperation
from .operation_parser import OperationParser
from .tool_auth_handler import ToolAuthHandler
def snake_to_lower_camel(snake_case_string: str):
"""Converts a snake_case string to a lower_camel_case string.
Args:
snake_case_string: The input snake_case string.
Returns:
The lower_camel_case string.
"""
if "_" not in snake_case_string:
return snake_case_string
return "".join([
s.lower() if i == 0 else s.capitalize()
for i, s in enumerate(snake_case_string.split("_"))
])
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
"""Converts an OpenAPI schema dictionary to a Gemini Schema object.
Args:
openapi_schema: The OpenAPI schema dictionary.
Returns:
A Pydantic Schema object. Returns None if input is None.
Raises TypeError if input is not a dict.
"""
if openapi_schema is None:
return None
if not isinstance(openapi_schema, dict):
raise TypeError("openapi_schema must be a dictionary")
pydantic_schema_data = {}
# Adding this to force adding a type to an empty dict
# This avoid "... one_of or any_of must specify a type" error
if not openapi_schema.get("type"):
openapi_schema["type"] = "object"
# Adding this to avoid "properties: should be non-empty for OBJECT type" error
# See b/385165182
if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
"properties"
):
openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}
for key, value in openapi_schema.items():
snake_case_key = to_snake_case(key)
# Check if the snake_case_key exists in the Schema model's fields.
if snake_case_key in Schema.model_fields:
if snake_case_key in ["title", "default", "format"]:
# Ignore these fields as Gemini backend doesn't recognize them, and will
# throw exception if they appear in the schema.
# Format: properties[expiration].format: only 'enum' and 'date-time' are
# supported for STRING type
continue
if snake_case_key == "properties" and isinstance(value, dict):
pydantic_schema_data[snake_case_key] = {
k: to_gemini_schema(v) for k, v in value.items()
}
elif snake_case_key == "items" and isinstance(value, dict):
pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
elif snake_case_key == "any_of" and isinstance(value, list):
pydantic_schema_data[snake_case_key] = [
to_gemini_schema(item) for item in value
]
# Important: Handle cases where the OpenAPI schema might contain lists
# or other structures that need to be recursively processed.
elif isinstance(value, list) and snake_case_key not in (
"enum",
"required",
"property_ordering",
):
new_list = []
for item in value:
if isinstance(item, dict):
new_list.append(to_gemini_schema(item))
else:
new_list.append(item)
pydantic_schema_data[snake_case_key] = new_list
elif isinstance(value, dict) and snake_case_key not in ("properties"):
# Handle dictionary which is neither properties or items
pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
else:
# Simple value assignment (int, str, bool, etc.)
pydantic_schema_data[snake_case_key] = value
return Schema(**pydantic_schema_data)
AuthPreparationState = Literal["pending", "done"]
class RestApiTool(BaseTool):
"""A generic tool that interacts with a REST API.
* Generates request params and body
* Attaches auth credentials to API call.
Example:
```
# Each API operation in the spec will be turned into its own tool
# Name of the tool is the operationId of that operation, in snake case
operations = OperationGenerator().parse(openapi_spec_dict)
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
```
"""
def __init__(
self,
name: str,
description: str,
endpoint: Union[OperationEndpoint, str],
operation: Union[Operation, str],
auth_scheme: Optional[Union[AuthScheme, str]] = None,
auth_credential: Optional[Union[AuthCredential, str]] = None,
should_parse_operation=True,
):
"""Initializes the RestApiTool with the given parameters.
To generate RestApiTool from OpenAPI Specs, use OperationGenerator.
Example:
```
# Each API operation in the spec will be turned into its own tool
# Name of the tool is the operationId of that operation, in snake case
operations = OperationGenerator().parse(openapi_spec_dict)
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
```
Hint: Use google.adk.tools.openapi_tool.auth.auth_helpers to construct
auth_scheme and auth_credential.
Args:
name: The name of the tool.
description: The description of the tool.
endpoint: Include the base_url, path, and method of the tool.
operation: Pydantic object or a dict. Representing the OpenAPI Operation
object
(https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#operation-object)
auth_scheme: The auth scheme of the tool. Representing the OpenAPI
SecurityScheme object
(https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#security-scheme-object)
auth_credential: The authentication credential of the tool.
should_parse_operation: Whether to parse the operation.
"""
# Gemini restrict the length of function name to be less than 64 characters
self.name = name[:60]
self.description = description
self.endpoint = (
OperationEndpoint.model_validate_json(endpoint)
if isinstance(endpoint, str)
else endpoint
)
self.operation = (
Operation.model_validate_json(operation)
if isinstance(operation, str)
else operation
)
self.auth_credential, self.auth_scheme = None, None
self.configure_auth_credential(auth_credential)
self.configure_auth_scheme(auth_scheme)
# Private properties
self.credential_exchanger = AutoAuthCredentialExchanger()
if should_parse_operation:
self._operation_parser = OperationParser(self.operation)
@classmethod
def from_parsed_operation(cls, parsed: ParsedOperation) -> "RestApiTool":
"""Initializes the RestApiTool from a ParsedOperation object.
Args:
parsed: A ParsedOperation object.
Returns:
A RestApiTool object.
"""
operation_parser = OperationParser.load(
parsed.operation, parsed.parameters, parsed.return_value
)
tool_name = to_snake_case(operation_parser.get_function_name())
generated = cls(
name=tool_name,
description=parsed.operation.description
or parsed.operation.summary
or "",
endpoint=parsed.endpoint,
operation=parsed.operation,
auth_scheme=parsed.auth_scheme,
auth_credential=parsed.auth_credential,
)
generated._operation_parser = operation_parser
return generated
@classmethod
def from_parsed_operation_str(
cls, parsed_operation_str: str
) -> "RestApiTool":
"""Initializes the RestApiTool from a dict.
Args:
parsed: A dict representation of a ParsedOperation object.
Returns:
A RestApiTool object.
"""
operation = ParsedOperation.model_validate_json(parsed_operation_str)
return RestApiTool.from_parsed_operation(operation)
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Returns the function declaration in the Gemini Schema format."""
schema_dict = self._operation_parser.get_json_schema()
parameters = to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
return function_decl
def configure_auth_scheme(
self, auth_scheme: Union[AuthScheme, Dict[str, Any]]
):
"""Configures the authentication scheme for the API call.
Args:
auth_scheme: AuthScheme|dict -: The authentication scheme. The dict is
converted to a AuthScheme object.
"""
if isinstance(auth_scheme, dict):
auth_scheme = dict_to_auth_scheme(auth_scheme)
self.auth_scheme = auth_scheme
def configure_auth_credential(
self, auth_credential: Optional[Union[AuthCredential, str]] = None
):
"""Configures the authentication credential for the API call.
Args:
auth_credential: AuthCredential|dict - The authentication credential.
The dict is converted to an AuthCredential object.
"""
if isinstance(auth_credential, str):
auth_credential = AuthCredential.model_validate_json(auth_credential)
self.auth_credential = auth_credential
def _prepare_auth_request_params(
self,
auth_scheme: AuthScheme,
auth_credential: AuthCredential,
) -> Tuple[List[ApiParameter], Dict[str, Any]]:
# Handle Authentication
if not auth_scheme or not auth_credential:
return
return credential_to_param(auth_scheme, auth_credential)
def _prepare_request_params(
self, parameters: List[ApiParameter], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
"""Prepares the request parameters for the API call.
Args:
parameters: A list of ApiParameter objects representing the parameters
for the API call.
kwargs: The keyword arguments passed to the call function from the Tool
caller.
Returns:
A dictionary containing the request parameters for the API call. This
initializes a requests.request() call.
Example:
self._prepare_request_params({"input_id": "test-id"})
"""
method = self.endpoint.method.lower()
if not method:
raise ValueError("Operation method not found.")
path_params: Dict[str, Any] = {}
query_params: Dict[str, Any] = {}
header_params: Dict[str, Any] = {}
cookie_params: Dict[str, Any] = {}
params_map: Dict[str, ApiParameter] = {p.py_name: p for p in parameters}
# Fill in path, query, header and cookie parameters to the request
for param_k, v in kwargs.items():
param_obj = params_map.get(param_k)
if not param_obj:
continue # If input arg not in the ApiParameter list, ignore it.
original_k = param_obj.original_name
param_location = param_obj.param_location
if param_location == "path":
path_params[original_k] = v
elif param_location == "query":
if v:
query_params[original_k] = v
elif param_location == "header":
header_params[original_k] = v
elif param_location == "cookie":
cookie_params[original_k] = v
# Construct URL
base_url = self.endpoint.base_url or ""
base_url = base_url[:-1] if base_url.endswith("/") else base_url
url = f"{base_url}{self.endpoint.path.format(**path_params)}"
# Construct body
body_kwargs: Dict[str, Any] = {}
request_body = self.operation.requestBody
if request_body:
for mime_type, media_type_object in request_body.content.items():
schema = media_type_object.schema_
body_data = None
if schema.type == "object":
body_data = {}
for param in parameters:
if param.param_location == "body" and param.py_name in kwargs:
body_data[param.original_name] = kwargs[param.py_name]
elif schema.type == "array":
for param in parameters:
if param.param_location == "body" and param.py_name == "array":
body_data = kwargs.get("array")
break
else: # like string
for param in parameters:
# original_name = '' indicating this param applies to the full body.
if param.param_location == "body" and not param.original_name:
body_data = (
kwargs.get(param.py_name) if param.py_name in kwargs else None
)
break
if mime_type == "application/json" or mime_type.endswith("+json"):
if body_data is not None:
body_kwargs["json"] = body_data
elif mime_type == "application/x-www-form-urlencoded":
body_kwargs["data"] = body_data
elif mime_type == "multipart/form-data":
body_kwargs["files"] = body_data
elif mime_type == "application/octet-stream":
body_kwargs["data"] = body_data
elif mime_type == "text/plain":
body_kwargs["data"] = body_data
if mime_type:
header_params["Content-Type"] = mime_type
break # Process only the first mime_type
filtered_query_params: Dict[str, Any] = {
k: v for k, v in query_params.items() if v is not None
}
request_params: Dict[str, Any] = {
"method": method,
"url": url,
"params": filtered_query_params,
"headers": header_params,
"cookies": cookie_params,
**body_kwargs,
}
return request_params
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
return self.call(args=args, tool_context=tool_context)
def call(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
"""Executes the REST API call.
Args:
args: Keyword arguments representing the operation parameters.
tool_context: The tool context (not used here, but required by the
interface).
Returns:
The API response as a dictionary.
"""
# Prepare auth credentials for the API call
tool_auth_handler = ToolAuthHandler.from_tool_context(
tool_context, self.auth_scheme, self.auth_credential
)
auth_result = tool_auth_handler.prepare_auth_credentials()
auth_state, auth_scheme, auth_credential = (
auth_result.state,
auth_result.auth_scheme,
auth_result.auth_credential,
)
if auth_state == "pending":
return {
"pending": True,
"message": "Needs your authorization to access your data.",
}
# Attach parameters from auth into main parameters list
api_params, api_args = self._operation_parser.get_parameters().copy(), args
if auth_credential:
# Attach parameters from auth into main parameters list
auth_param, auth_args = self._prepare_auth_request_params(
auth_scheme, auth_credential
)
if auth_param and auth_args:
api_params = [auth_param] + api_params
api_args.update(auth_args)
# Got all parameters. Call the API.
request_params = self._prepare_request_params(api_params, api_args)
response = requests.request(**request_params)
# Parse API response
try:
response.raise_for_status() # Raise HTTPError for bad responses
return response.json() # Try to decode JSON
except requests.exceptions.HTTPError:
error_details = response.content.decode("utf-8")
return {
"error": (
f"Tool {self.name} execution failed. Analyze this execution error"
" and your inputs. Retry with adjustments if applicable. But"
" make sure don't retry more than 3 times. Execution Error:"
f" {error_details}"
)
}
except ValueError:
return {"text": response.text} # Return text if not JSON
def __str__(self):
return (
f'RestApiTool(name="{self.name}", description="{self.description}",'
f' endpoint="{self.endpoint}")'
)
def __repr__(self):
return (
f'RestApiTool(name="{self.name}", description="{self.description}",'
f' endpoint="{self.endpoint}", operation="{self.operation}",'
f' auth_scheme="{self.auth_scheme}",'
f' auth_credential="{self.auth_credential}")'
)

View File

@@ -0,0 +1,268 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Literal
from typing import Optional
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from ....auth.auth_credential import AuthCredential
from ....auth.auth_credential import AuthCredentialTypes
from ....auth.auth_schemes import AuthScheme
from ....auth.auth_schemes import AuthSchemeType
from ....auth.auth_tool import AuthConfig
from ...tool_context import ToolContext
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
from ..auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
from ..auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
logger = logging.getLogger(__name__)
AuthPreparationState = Literal["pending", "done"]
class AuthPreparationResult(BaseModel):
"""Result of the credential preparation process."""
state: AuthPreparationState
auth_scheme: Optional[AuthScheme] = None
auth_credential: Optional[AuthCredential] = None
class ToolContextCredentialStore:
"""Handles storage and retrieval of credentials within a ToolContext."""
def __init__(self, tool_context: ToolContext):
self.tool_context = tool_context
def get_credential_key(
self,
auth_scheme: Optional[AuthScheme],
auth_credential: Optional[AuthCredential],
) -> str:
"""Generates a unique key for the given auth scheme and credential."""
scheme_name = (
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
if auth_scheme
else ""
)
credential_name = (
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
if auth_credential
else ""
)
# no need to prepend temp: namespace, session state is a copy, changes to
# it won't be persisted , only changes in event_action.state_delta will be
# persisted. temp: namespace will be cleared after current run. but tool
# want access token to be there stored across runs
return f"{scheme_name}_{credential_name}_existing_exchanged_credential"
def get_credential(
self,
auth_scheme: Optional[AuthScheme],
auth_credential: Optional[AuthCredential],
) -> Optional[AuthCredential]:
if not self.tool_context:
return None
token_key = self.get_credential_key(auth_scheme, auth_credential)
# TODO try not to use session state, this looks a hacky way, depend on
# session implementation, we don't want session to persist the token,
# meanwhile we want the token shared across runs.
serialized_credential = self.tool_context.state.get(token_key)
if not serialized_credential:
return None
return AuthCredential.model_validate(serialized_credential)
def store_credential(
self,
key: str,
auth_credential: Optional[AuthCredential],
):
if self.tool_context:
serializable_credential = jsonable_encoder(
auth_credential, exclude_none=True
)
self.tool_context.state[key] = serializable_credential
def remove_credential(self, key: str):
del self.tool_context.state[key]
class ToolAuthHandler:
"""Handles the preparation and exchange of authentication credentials for tools."""
def __init__(
self,
tool_context: ToolContext,
auth_scheme: Optional[AuthScheme],
auth_credential: Optional[AuthCredential],
credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
credential_store: Optional["ToolContextCredentialStore"] = None,
):
self.tool_context = tool_context
self.auth_scheme = (
auth_scheme.model_copy(deep=True) if auth_scheme else None
)
self.auth_credential = (
auth_credential.model_copy(deep=True) if auth_credential else None
)
self.credential_exchanger = (
credential_exchanger or AutoAuthCredentialExchanger()
)
self.credential_store = credential_store
self.should_store_credential = True
@classmethod
def from_tool_context(
cls,
tool_context: ToolContext,
auth_scheme: Optional[AuthScheme],
auth_credential: Optional[AuthCredential],
credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
) -> "ToolAuthHandler":
"""Creates a ToolAuthHandler instance from a ToolContext."""
credential_store = ToolContextCredentialStore(tool_context)
return cls(
tool_context,
auth_scheme,
auth_credential,
credential_exchanger,
credential_store,
)
def _handle_existing_credential(
self,
) -> Optional[AuthPreparationResult]:
"""Checks for and returns an existing, exchanged credential."""
if self.credential_store:
existing_credential = self.credential_store.get_credential(
self.auth_scheme, self.auth_credential
)
if existing_credential:
return AuthPreparationResult(
state="done",
auth_scheme=self.auth_scheme,
auth_credential=existing_credential,
)
return None
def _exchange_credential(
self, auth_credential: AuthCredential
) -> Optional[AuthPreparationResult]:
"""Handles an OpenID Connect authorization response."""
exchanged_credential = None
try:
exchanged_credential = self.credential_exchanger.exchange_credential(
self.auth_scheme, auth_credential
)
except Exception as e:
logger.error("Failed to exchange credential: %s", e)
return exchanged_credential
def _store_credential(self, auth_credential: AuthCredential) -> None:
"""stores the auth_credential."""
if self.credential_store:
key = self.credential_store.get_credential_key(
self.auth_scheme, self.auth_credential
)
self.credential_store.store_credential(key, auth_credential)
def _reqeust_credential(self) -> None:
"""Handles the case where an OpenID Connect or OAuth2 authentication request is needed."""
if self.auth_scheme.type_ in (
AuthSchemeType.openIdConnect,
AuthSchemeType.oauth2,
):
if not self.auth_credential or not self.auth_credential.oauth2:
raise ValueError(
f"auth_credential is empty for scheme {self.auth_scheme.type_}."
"Please create AuthCredential using OAuth2Auth."
)
if not self.auth_credential.oauth2.client_id:
raise AuthCredentialMissingError(
"OAuth2 credentials client_id is missing."
)
if not self.auth_credential.oauth2.client_secret:
raise AuthCredentialMissingError(
"OAuth2 credentials client_secret is missing."
)
self.tool_context.request_credential(
AuthConfig(
auth_scheme=self.auth_scheme,
raw_auth_credential=self.auth_credential,
)
)
return None
def _get_auth_response(self) -> AuthCredential:
return self.tool_context.get_auth_response(
AuthConfig(
auth_scheme=self.auth_scheme,
raw_auth_credential=self.auth_credential,
)
)
def _request_credential(self, auth_config: AuthConfig):
if not self.tool_context:
return
self.tool_context.request_credential(auth_config)
def prepare_auth_credentials(
self,
) -> AuthPreparationResult:
"""Prepares authentication credentials, handling exchange and user interaction."""
# no auth is needed
if not self.auth_scheme:
return AuthPreparationResult(state="done")
# Check for existing credential.
existing_result = self._handle_existing_credential()
if existing_result:
return existing_result
# fetch credential from adk framework
# Some auth scheme like OAuth2 AuthCode & OpenIDConnect may require
# multi-step exchange:
# client_id , client_secret -> auth_uri -> auth_code -> access_token
# -> bearer token
# adk framework supports exchange access_token already
fetched_credential = self._get_auth_response() or self.auth_credential
exchanged_credential = self._exchange_credential(fetched_credential)
if exchanged_credential:
self._store_credential(exchanged_credential)
return AuthPreparationResult(
state="done",
auth_scheme=self.auth_scheme,
auth_credential=exchanged_credential,
)
else:
self._reqeust_credential()
return AuthPreparationResult(
state="pending",
auth_scheme=self.auth_scheme,
auth_credential=self.auth_credential,
)