adk-python/src/google/adk/auth/auth_handler.py
Xiang (Sean) Zhou 956fb912e8 feat(auth)!: expose access_token and refresh_token at top level of auth credentials
BREAKING CHANGE: the `token` attribute of OAuth2Auth credentials used to be a dict containing both access_token and refresh_token. Because that could cause confusion, it is now replaced by access_token and refresh_token at the top level of the auth credentials.

PiperOrigin-RevId: 750346172
2025-04-22 15:23:21 -07:00

273 lines
9.3 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import SecurityBase
from .auth_credential import AuthCredential
from .auth_credential import AuthCredentialTypes
from .auth_credential import OAuth2Auth
from .auth_schemes import AuthSchemeType
from .auth_schemes import OAuthGrantType
from .auth_schemes import OpenIdConnectWithConfig
from .auth_tool import AuthConfig
if TYPE_CHECKING:
    # State is referenced only in type annotations in this module, so it is
    # imported for type checkers only and never at runtime.
    from ..sessions.state import State

# authlib is an optional dependency. SUPPORT_TOKEN_EXCHANGE records whether
# OAuth2Session could be imported; when it is False, exchange_auth_token()
# returns the credential it was given without contacting a token endpoint.
try:
    from authlib.integrations.requests_client import OAuth2Session

    SUPPORT_TOKEN_EXCHANGE = True
except ImportError:
    SUPPORT_TOKEN_EXCHANGE = False
class AuthHandler:
    """Runs the OAuth2 / OpenID Connect steps for a single ``AuthConfig``.

    The handler builds the authorization request (auth URI + state), exchanges
    the authorization response for tokens, and stores / retrieves the resulting
    credential in session state under a key derived from the auth scheme and
    the raw credential.
    """

    def __init__(self, auth_config: AuthConfig):
        """Stores the auth configuration this handler operates on."""
        self.auth_config = auth_config

    def _is_oauth_scheme(self) -> bool:
        """Returns True if the auth scheme is an oauth2 or openIdConnect scheme."""
        auth_scheme = self.auth_config.auth_scheme
        return isinstance(auth_scheme, SecurityBase) and auth_scheme.type_ in (
            AuthSchemeType.oauth2,
            AuthSchemeType.openIdConnect,
        )

    def exchange_auth_token(
        self,
    ) -> AuthCredential:
        """Exchanges the authorization response for access / refresh tokens.

        The exchanged credential from the auth config is returned unchanged
        when no exchange is possible or needed, i.e. when:

        * authlib is not installed (``SUPPORT_TOKEN_EXCHANGE`` is False);
        * the auth scheme is neither ``OpenIdConnectWithConfig`` nor ``OAuth2``;
        * the scheme has no token endpoint (for ``OAuth2`` only the
          authorization-code flow's ``tokenUrl`` is considered);
        * the credential has no ``oauth2`` part, or lacks ``client_id`` /
          ``client_secret``;
        * the credential already carries an access token or refresh token.

        Returns:
            The unchanged exchanged credential in the cases above; otherwise a
            new OAUTH2 ``AuthCredential`` whose ``oauth2`` part contains only
            the ``access_token`` and ``refresh_token`` returned by the token
            endpoint.

        Raises:
            Any error raised by ``OAuth2Session.fetch_token`` is propagated
            unchanged; this method does not catch or translate it.
        """
        auth_scheme = self.auth_config.auth_scheme
        auth_credential = self.auth_config.exchanged_auth_credential

        if not SUPPORT_TOKEN_EXCHANGE:
            return auth_credential

        if isinstance(auth_scheme, OpenIdConnectWithConfig):
            # A falsy check (rather than hasattr) also covers an endpoint that
            # is declared but left as None / "", matching the OAuth2 branch.
            token_endpoint = getattr(auth_scheme, "token_endpoint", None)
            if not token_endpoint:
                return auth_credential
            scopes = auth_scheme.scopes or []
        elif isinstance(auth_scheme, OAuth2):
            code_flow = auth_scheme.flows.authorizationCode
            if not code_flow or not code_flow.tokenUrl:
                return auth_credential
            token_endpoint = code_flow.tokenUrl
            scopes = list(code_flow.scopes or {})
        else:
            return auth_credential

        oauth2 = auth_credential.oauth2 if auth_credential else None
        if (
            not oauth2
            or not oauth2.client_id
            or not oauth2.client_secret
            # Tokens already present: nothing left to exchange.
            or oauth2.access_token
            or oauth2.refresh_token
        ):
            return auth_credential

        client = OAuth2Session(
            oauth2.client_id,
            oauth2.client_secret,
            scope=" ".join(scopes),
            redirect_uri=oauth2.redirect_uri,
            state=oauth2.state,
        )
        tokens = client.fetch_token(
            token_endpoint,
            authorization_response=oauth2.auth_response_uri,
            code=oauth2.auth_code,
            grant_type=OAuthGrantType.AUTHORIZATION_CODE,
        )

        # Only the tokens are carried over; client id/secret, redirect URI and
        # state are intentionally not copied into the returned credential.
        return AuthCredential(
            auth_type=AuthCredentialTypes.OAUTH2,
            oauth2=OAuth2Auth(
                access_token=tokens.get("access_token"),
                refresh_token=tokens.get("refresh_token"),
            ),
        )

    def parse_and_store_auth_response(self, state: State) -> None:
        """Stores the auth response credential in ``state``.

        The exchanged credential is stored first under the credential key. For
        oauth2 / openIdConnect schemes it is then replaced by the result of
        ``exchange_auth_token()``.

        Args:
            state: Session state mapping that receives the credential.
        """
        credential_key = self.get_credential_key()
        # Stored before the exchange so that the un-exchanged credential is
        # still in state if exchange_auth_token() raises.
        state[credential_key] = self.auth_config.exchanged_auth_credential
        if not self._is_oauth_scheme():
            return
        state[credential_key] = self.exchange_auth_token()

    def _validate(self) -> None:
        """Checks that the auth config carries an auth scheme.

        Raises:
            ValueError: If the auth scheme is missing.
        """
        # The scheme lives on the auth config; the handler itself has no
        # ``auth_scheme`` attribute.
        if not self.auth_config.auth_scheme:
            raise ValueError("auth_scheme is empty.")

    def get_auth_response(self, state: State) -> AuthCredential:
        """Returns the credential stored in ``state`` for this config, or None.

        Args:
            state: Session state mapping to read from.
        """
        credential_key = self.get_credential_key()
        return state.get(credential_key, None)

    def generate_auth_request(self) -> AuthConfig:
        """Builds the auth config to hand to the client to start authentication.

        Returns:
            * A deep copy of the current config if the scheme is not
              oauth2 / openIdConnect, or if the exchanged credential already
              has an ``auth_uri``.
            * A config whose exchanged credential is a copy of the raw
              credential if the raw credential already has an ``auth_uri``.
            * Otherwise a config whose exchanged credential comes from
              ``generate_auth_uri()``.

        Raises:
            ValueError: If the raw credential, its ``oauth2`` part, or its
                ``client_id`` / ``client_secret`` is missing.
        """
        if not self._is_oauth_scheme():
            return self.auth_config.model_copy(deep=True)

        # auth_uri already in exchanged credential
        exchanged = self.auth_config.exchanged_auth_credential
        if exchanged and exchanged.oauth2 and exchanged.oauth2.auth_uri:
            return self.auth_config.model_copy(deep=True)

        scheme_type = self.auth_config.auth_scheme.type_
        raw_credential = self.auth_config.raw_auth_credential

        if not raw_credential:
            raise ValueError(
                f"Auth Scheme {scheme_type} requires auth_credential."
            )

        if not raw_credential.oauth2:
            raise ValueError(
                f"Auth Scheme {scheme_type} requires oauth2 in"
                " auth_credential."
            )

        # auth_uri in raw credential
        if raw_credential.oauth2.auth_uri:
            return AuthConfig(
                auth_scheme=self.auth_config.auth_scheme,
                raw_auth_credential=raw_credential,
                exchanged_auth_credential=raw_credential.model_copy(deep=True),
            )

        if (
            not raw_credential.oauth2.client_id
            or not raw_credential.oauth2.client_secret
        ):
            raise ValueError(
                f"Auth Scheme {scheme_type} requires both"
                " client_id and client_secret in auth_credential.oauth2."
            )

        # Generate new auth URI
        return AuthConfig(
            auth_scheme=self.auth_config.auth_scheme,
            raw_auth_credential=raw_credential,
            exchanged_auth_credential=self.generate_auth_uri(),
        )

    def get_credential_key(self) -> str:
        """Generates a key for the configured auth scheme and raw credential.

        Extra (undeclared) model fields are dropped from a copy of each object
        before hashing, so they do not influence the key. A missing scheme or
        credential contributes an empty component.

        NOTE: the key uses the built-in ``hash()`` of a JSON string, which
        Python salts per process, so the key is only stable within a single
        interpreter process.

        Returns:
            A key of the form ``temp:adk_<scheme part>_<credential part>``.
        """
        auth_scheme = self.auth_config.auth_scheme
        auth_credential = self.auth_config.raw_auth_credential

        scheme_name = ""
        if auth_scheme:
            if auth_scheme.model_extra:
                auth_scheme = auth_scheme.model_copy(deep=True)
                auth_scheme.model_extra.clear()
            scheme_hash = hash(auth_scheme.model_dump_json())
            scheme_name = f"{auth_scheme.type_.name}_{scheme_hash}"

        credential_name = ""
        if auth_credential:
            if auth_credential.model_extra:
                auth_credential = auth_credential.model_copy(deep=True)
                auth_credential.model_extra.clear()
            credential_hash = hash(auth_credential.model_dump_json())
            credential_name = (
                f"{auth_credential.auth_type.value}_{credential_hash}"
            )

        return f"temp:adk_{scheme_name}_{credential_name}"

    @staticmethod
    def _select_oauth2_endpoint_and_scopes(flows):
        """Picks the authorization endpoint and scope names from OAuth2 flows.

        Flows are consulted in the order implicit, authorizationCode,
        clientCredentials, password. The endpoint is the first non-empty URL
        (``authorizationUrl`` for the first two flows, ``tokenUrl`` for the
        last two); the scopes are the keys of the first non-empty ``scopes``
        mapping. The two are chosen independently.

        Returns:
            A ``(endpoint, scopes)`` tuple; ``endpoint`` is None when no flow
            provides a URL and ``scopes`` is an empty list when no flow
            declares any scope.
        """
        candidates = (
            (flows.implicit, "authorizationUrl"),
            (flows.authorizationCode, "authorizationUrl"),
            (flows.clientCredentials, "tokenUrl"),
            (flows.password, "tokenUrl"),
        )
        endpoint = next(
            (
                getattr(flow, url_attr)
                for flow, url_attr in candidates
                if flow and getattr(flow, url_attr)
            ),
            None,
        )
        scope_map = next(
            (flow.scopes for flow, _ in candidates if flow and flow.scopes),
            {},
        )
        return endpoint, list(scope_map)

    def generate_auth_uri(
        self,
    ) -> AuthCredential:
        """Generates a response containing the auth URI for the user to sign in.

        Expects the raw credential to have an ``oauth2`` part with the client
        id and secret (``generate_auth_request()`` checks this before calling).

        Returns:
            A deep copy of the raw credential with ``oauth2.auth_uri`` and
            ``oauth2.state`` filled in. If authlib is not installed, the copy
            is returned without those fields set (None if there is no raw
            credential).

        Raises:
            ValueError: If the authorization endpoint is not configured in the
                auth scheme.
        """
        auth_scheme = self.auth_config.auth_scheme
        auth_credential = self.auth_config.raw_auth_credential

        if not SUPPORT_TOKEN_EXCHANGE:
            # OAuth2Session is unavailable without authlib; degrade the same
            # way exchange_auth_token() does instead of failing on the name.
            return (
                auth_credential.model_copy(deep=True)
                if auth_credential
                else None
            )

        if isinstance(auth_scheme, OpenIdConnectWithConfig):
            authorization_endpoint = auth_scheme.authorization_endpoint
            scopes = auth_scheme.scopes or []
        else:
            authorization_endpoint, scopes = (
                self._select_oauth2_endpoint_and_scopes(auth_scheme.flows)
            )

        if not authorization_endpoint:
            raise ValueError(
                "Authorization endpoint is not configured in the auth scheme."
            )

        client = OAuth2Session(
            auth_credential.oauth2.client_id,
            auth_credential.oauth2.client_secret,
            scope=" ".join(scopes),
            redirect_uri=auth_credential.oauth2.redirect_uri,
        )
        # access_type / prompt are passed through as extra parameters of the
        # authorization URL. NOTE(review): presumably meant to obtain a
        # refresh token from Google-style providers; confirm.
        uri, state = client.create_authorization_url(
            url=authorization_endpoint, access_type="offline", prompt="consent"
        )

        exchanged_auth_credential = auth_credential.model_copy(deep=True)
        exchanged_auth_credential.oauth2.auth_uri = uri
        exchanged_auth_credential.oauth2.state = state
        return exchanged_auth_credential