mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-16 12:12:56 -06:00

BREAKING CHANGE: the `token` attribute of OAuth2Auth credentials used to be a dict containing both access_token and refresh_token. Because that could cause confusion, it has been replaced with top-level access_token and refresh_token fields on the auth credential. PiperOrigin-RevId: 750346172
273 lines
9.3 KiB
Python
273 lines
9.3 KiB
Python
# Copyright 2025 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
from fastapi.openapi.models import OAuth2
|
|
from fastapi.openapi.models import SecurityBase
|
|
|
|
from .auth_credential import AuthCredential
|
|
from .auth_credential import AuthCredentialTypes
|
|
from .auth_credential import OAuth2Auth
|
|
from .auth_schemes import AuthSchemeType
|
|
from .auth_schemes import OAuthGrantType
|
|
from .auth_schemes import OpenIdConnectWithConfig
|
|
from .auth_tool import AuthConfig
|
|
|
|
if TYPE_CHECKING:
|
|
from ..sessions.state import State
|
|
|
|
# authlib is an optional dependency. When it cannot be imported,
# SUPPORT_TOKEN_EXCHANGE is False and AuthHandler.exchange_auth_token returns
# the stored credential without contacting the token endpoint.
try:
  from authlib.integrations.requests_client import OAuth2Session
except ImportError:
  SUPPORT_TOKEN_EXCHANGE = False
else:
  SUPPORT_TOKEN_EXCHANGE = True
|
|
|
|
|
|
class AuthHandler:
  """Runs the OAuth2 / OpenID Connect sign-in steps for one AuthConfig.

  Builds the authorization request for the user, exchanges an authorization
  code for tokens, and stores / loads the resulting credential in session
  state under a key derived from the auth scheme and the raw credential.
  """

  def __init__(self, auth_config: AuthConfig):
    self.auth_config = auth_config

  def _is_oauth_scheme(self) -> bool:
    """Returns True if the auth scheme is an OAuth2 or OpenID Connect scheme."""
    auth_scheme = self.auth_config.auth_scheme
    return isinstance(auth_scheme, SecurityBase) and auth_scheme.type_ in (
        AuthSchemeType.oauth2,
        AuthSchemeType.openIdConnect,
    )

  def exchange_auth_token(
      self,
  ) -> AuthCredential:
    """Exchanges the authorization code in the auth response for tokens.

    The exchange is attempted only when all of the following hold:

    * authlib is importable (``SUPPORT_TOKEN_EXCHANGE``);
    * the scheme is an ``OpenIdConnectWithConfig`` with a token endpoint, or an
      ``OAuth2`` scheme whose authorization-code flow declares a ``tokenUrl``;
    * the exchanged credential has a client_id and client_secret but no
      access_token or refresh_token yet.

    Otherwise ``auth_config.exchanged_auth_credential`` is returned unchanged.

    Returns:
      A new OAuth2 AuthCredential that holds only the access_token and
      refresh_token from the token endpoint response, or the existing
      exchanged credential when no exchange is performed.

    Errors raised by authlib's ``fetch_token`` are not caught here.
    """
    auth_scheme = self.auth_config.auth_scheme
    auth_credential = self.auth_config.exchanged_auth_credential

    if not SUPPORT_TOKEN_EXCHANGE:
      return auth_credential

    if isinstance(auth_scheme, OpenIdConnectWithConfig):
      token_endpoint = getattr(auth_scheme, "token_endpoint", None)
      if not token_endpoint:
        return auth_credential
      # scopes may be unset on an OIDC config; treat that as "no scopes".
      scopes = auth_scheme.scopes or []
    elif isinstance(auth_scheme, OAuth2):
      auth_code_flow = auth_scheme.flows.authorizationCode
      if not auth_code_flow or not auth_code_flow.tokenUrl:
        return auth_credential
      token_endpoint = auth_code_flow.tokenUrl
      scopes = list((auth_code_flow.scopes or {}).keys())
    else:
      return auth_credential

    oauth2 = auth_credential.oauth2 if auth_credential else None
    if (
        not oauth2
        or not oauth2.client_id
        or not oauth2.client_secret
        # A token is already present, so there is nothing to exchange.
        or oauth2.access_token
        or oauth2.refresh_token
    ):
      return auth_credential

    client = OAuth2Session(
        oauth2.client_id,
        oauth2.client_secret,
        scope=" ".join(scopes),
        redirect_uri=oauth2.redirect_uri,
        state=oauth2.state,
    )
    tokens = client.fetch_token(
        token_endpoint,
        authorization_response=oauth2.auth_response_uri,
        code=oauth2.auth_code,
        grant_type=OAuthGrantType.AUTHORIZATION_CODE,
    )

    return AuthCredential(
        auth_type=AuthCredentialTypes.OAUTH2,
        oauth2=OAuth2Auth(
            access_token=tokens.get("access_token"),
            refresh_token=tokens.get("refresh_token"),
        ),
    )

  def parse_and_store_auth_response(self, state: State) -> None:
    """Stores the user's auth response in ``state``.

    The exchanged credential is written first. For OAuth2 / OpenID Connect
    schemes it is then overwritten with the result of ``exchange_auth_token``.
    If that call raises, the first write is left in place.

    Args:
      state: Session state to write the credential into.
    """
    credential_key = self.get_credential_key()

    state[credential_key] = self.auth_config.exchanged_auth_credential
    if not self._is_oauth_scheme():
      return

    state[credential_key] = self.exchange_auth_token()

  def _validate(self) -> None:
    """Raises ValueError if the config has no auth scheme."""
    if not self.auth_config.auth_scheme:
      raise ValueError("auth_scheme is empty.")

  def get_auth_response(self, state: State) -> AuthCredential:
    """Returns the credential stored in ``state`` for this config, or None."""
    credential_key = self.get_credential_key()
    return state.get(credential_key, None)

  def generate_auth_request(self) -> AuthConfig:
    """Builds the AuthConfig used to ask the user to sign in.

    Returns a deep copy of the current config when the scheme is not
    OAuth2 / OpenID Connect, or when the exchanged credential already has an
    auth_uri. Otherwise the auth_uri from raw_auth_credential is reused if
    present, or a new one is generated from the client_id / client_secret.

    Returns:
      An AuthConfig whose exchanged_auth_credential carries the auth_uri.

    Raises:
      ValueError: If raw_auth_credential or its oauth2 section is missing, if a
        new auth URI is needed but client_id or client_secret is empty, or if
        no authorization endpoint can be derived from the scheme.
    """
    if not self._is_oauth_scheme():
      return self.auth_config.model_copy(deep=True)

    # auth_uri already in exchanged credential
    exchanged = self.auth_config.exchanged_auth_credential
    if exchanged and exchanged.oauth2 and exchanged.oauth2.auth_uri:
      return self.auth_config.model_copy(deep=True)

    raw_credential = self.auth_config.raw_auth_credential
    scheme_type = self.auth_config.auth_scheme.type_

    # Check if raw_auth_credential exists
    if not raw_credential:
      raise ValueError(
          f"Auth Scheme {scheme_type} requires"
          " auth_credential."
      )

    # Check if oauth2 exists in raw_auth_credential
    if not raw_credential.oauth2:
      raise ValueError(
          f"Auth Scheme {scheme_type} requires oauth2 in"
          " auth_credential."
      )

    # auth_uri in raw credential: reuse it as-is.
    if raw_credential.oauth2.auth_uri:
      return AuthConfig(
          auth_scheme=self.auth_config.auth_scheme,
          raw_auth_credential=raw_credential,
          exchanged_auth_credential=raw_credential.model_copy(deep=True),
      )

    # Generating a new auth URI needs client_id and client_secret.
    if (
        not raw_credential.oauth2.client_id
        or not raw_credential.oauth2.client_secret
    ):
      raise ValueError(
          f"Auth Scheme {scheme_type} requires both"
          " client_id and client_secret in auth_credential.oauth2."
      )

    # Generate new auth URI
    exchanged_credential = self.generate_auth_uri()
    return AuthConfig(
        auth_scheme=self.auth_config.auth_scheme,
        raw_auth_credential=raw_credential,
        exchanged_auth_credential=exchanged_credential,
    )

  def get_credential_key(self) -> str:
    """Generates a state key for the given auth scheme and raw credential.

    Either part may be missing. A missing part contributes an empty segment
    to the key.

    NOTE: built-in ``hash()`` of a str is salted per interpreter process
    (PYTHONHASHSEED), so this key is only stable within one process.
    """
    auth_scheme = self.auth_config.auth_scheme
    auth_credential = self.auth_config.raw_auth_credential

    scheme_name = ""
    if auth_scheme:
      if auth_scheme.model_extra:
        # Drop extra (non-schema) fields on a copy so they do not affect the key.
        auth_scheme = auth_scheme.model_copy(deep=True)
        auth_scheme.model_extra.clear()
      scheme_name = (
          f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
      )

    credential_name = ""
    if auth_credential:
      if auth_credential.model_extra:
        auth_credential = auth_credential.model_copy(deep=True)
        auth_credential.model_extra.clear()
      credential_name = (
          f"{auth_credential.auth_type.value}_"
          f"{hash(auth_credential.model_dump_json())}"
      )

    return f"temp:adk_{scheme_name}_{credential_name}"

  def generate_auth_uri(
      self,
  ) -> AuthCredential:
    """Generates the authorization URI the user visits to sign in.

    For OAuth2 schemes, the endpoint is the first truthy value among the
    implicit/authorizationCode ``authorizationUrl`` and the
    clientCredentials/password ``tokenUrl``. The scopes are the first
    non-empty scopes dict in the same flow order. The endpoint and scopes are
    resolved independently.

    Returns:
      A deep copy of raw_auth_credential with ``oauth2.auth_uri`` and
      ``oauth2.state`` set from authlib's ``create_authorization_url``.

    Raises:
      ValueError: If no authorization endpoint can be derived from the auth
        scheme.
    """
    auth_scheme = self.auth_config.auth_scheme
    auth_credential = self.auth_config.raw_auth_credential

    if isinstance(auth_scheme, OpenIdConnectWithConfig):
      authorization_endpoint = auth_scheme.authorization_endpoint
      # scopes may be unset on an OIDC config; treat that as "no scopes".
      scopes = auth_scheme.scopes or []
    else:
      flows = auth_scheme.flows
      authorization_endpoint = (
          (flows.implicit and flows.implicit.authorizationUrl)
          or (
              flows.authorizationCode
              and flows.authorizationCode.authorizationUrl
          )
          or (flows.clientCredentials and flows.clientCredentials.tokenUrl)
          or (flows.password and flows.password.tokenUrl)
      )
      scope_map = (
          (flows.implicit and flows.implicit.scopes)
          or (flows.authorizationCode and flows.authorizationCode.scopes)
          or (flows.clientCredentials and flows.clientCredentials.scopes)
          or (flows.password and flows.password.scopes)
          # Every configured flow may have empty scopes; fall back to none.
          or {}
      )
      scopes = list(scope_map.keys())

    if not authorization_endpoint:
      raise ValueError(
          f"Auth Scheme {auth_scheme.type_} requires an authorization"
          " endpoint."
      )

    # NOTE(review): OAuth2Session is only bound when the authlib import
    # succeeded; unlike exchange_auth_token, this method does not check
    # SUPPORT_TOKEN_EXCHANGE first.
    client = OAuth2Session(
        auth_credential.oauth2.client_id,
        auth_credential.oauth2.client_secret,
        scope=" ".join(scopes),
        redirect_uri=auth_credential.oauth2.redirect_uri,
    )
    # access_type/prompt are passed through as extra authorization-URL query
    # parameters; presumably meant to request a refresh token (Google-style
    # offline access) - confirm for other providers.
    uri, state = client.create_authorization_url(
        url=authorization_endpoint, access_type="offline", prompt="consent"
    )
    exchanged_auth_credential = auth_credential.model_copy(deep=True)
    exchanged_auth_credential.oauth2.auth_uri = uri
    exchanged_auth_credential.oauth2.state = state

    return exchanged_auth_credential