structure saas with tools
This commit is contained in:
@@ -0,0 +1,272 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import SecurityBase
|
||||
|
||||
from .auth_credential import AuthCredential
|
||||
from .auth_credential import AuthCredentialTypes
|
||||
from .auth_credential import OAuth2Auth
|
||||
from .auth_schemes import AuthSchemeType
|
||||
from .auth_schemes import OAuthGrantType
|
||||
from .auth_schemes import OpenIdConnectWithConfig
|
||||
from .auth_tool import AuthConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..sessions.state import State
|
||||
|
||||
try:
|
||||
from authlib.integrations.requests_client import OAuth2Session
|
||||
|
||||
SUPPORT_TOKEN_EXCHANGE = True
|
||||
except ImportError:
|
||||
SUPPORT_TOKEN_EXCHANGE = False
|
||||
|
||||
|
||||
class AuthHandler:
|
||||
|
||||
def __init__(self, auth_config: AuthConfig):
|
||||
self.auth_config = auth_config
|
||||
|
||||
def exchange_auth_token(
|
||||
self,
|
||||
) -> AuthCredential:
|
||||
"""Generates an auth token from the authorization response.
|
||||
|
||||
Returns:
|
||||
An AuthCredential object containing the access token.
|
||||
|
||||
Raises:
|
||||
ValueError: If the token endpoint is not configured in the auth
|
||||
scheme.
|
||||
AuthCredentialMissingError: If the access token cannot be retrieved
|
||||
from the token endpoint.
|
||||
"""
|
||||
auth_scheme = self.auth_config.auth_scheme
|
||||
auth_credential = self.auth_config.exchanged_auth_credential
|
||||
if not SUPPORT_TOKEN_EXCHANGE:
|
||||
return auth_credential
|
||||
if isinstance(auth_scheme, OpenIdConnectWithConfig):
|
||||
if not hasattr(auth_scheme, "token_endpoint"):
|
||||
return self.auth_config.exchanged_auth_credential
|
||||
token_endpoint = auth_scheme.token_endpoint
|
||||
scopes = auth_scheme.scopes
|
||||
elif isinstance(auth_scheme, OAuth2):
|
||||
if (
|
||||
not auth_scheme.flows.authorizationCode
|
||||
or not auth_scheme.flows.authorizationCode.tokenUrl
|
||||
):
|
||||
return self.auth_config.exchanged_auth_credential
|
||||
token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl
|
||||
scopes = list(auth_scheme.flows.authorizationCode.scopes.keys())
|
||||
else:
|
||||
return self.auth_config.exchanged_auth_credential
|
||||
|
||||
if (
|
||||
not auth_credential
|
||||
or not auth_credential.oauth2
|
||||
or not auth_credential.oauth2.client_id
|
||||
or not auth_credential.oauth2.client_secret
|
||||
or auth_credential.oauth2.access_token
|
||||
or auth_credential.oauth2.refresh_token
|
||||
):
|
||||
return self.auth_config.exchanged_auth_credential
|
||||
|
||||
client = OAuth2Session(
|
||||
auth_credential.oauth2.client_id,
|
||||
auth_credential.oauth2.client_secret,
|
||||
scope=" ".join(scopes),
|
||||
redirect_uri=auth_credential.oauth2.redirect_uri,
|
||||
state=auth_credential.oauth2.state,
|
||||
)
|
||||
tokens = client.fetch_token(
|
||||
token_endpoint,
|
||||
authorization_response=auth_credential.oauth2.auth_response_uri,
|
||||
code=auth_credential.oauth2.auth_code,
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
)
|
||||
|
||||
updated_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
access_token=tokens.get("access_token"),
|
||||
refresh_token=tokens.get("refresh_token"),
|
||||
),
|
||||
)
|
||||
return updated_credential
|
||||
|
||||
def parse_and_store_auth_response(self, state: State) -> None:
|
||||
|
||||
credential_key = self.get_credential_key()
|
||||
|
||||
state[credential_key] = self.auth_config.exchanged_auth_credential
|
||||
if not isinstance(
|
||||
self.auth_config.auth_scheme, SecurityBase
|
||||
) or self.auth_config.auth_scheme.type_ not in (
|
||||
AuthSchemeType.oauth2,
|
||||
AuthSchemeType.openIdConnect,
|
||||
):
|
||||
return
|
||||
|
||||
state[credential_key] = self.exchange_auth_token()
|
||||
|
||||
def _validate(self) -> None:
|
||||
if not self.auth_scheme:
|
||||
raise ValueError("auth_scheme is empty.")
|
||||
|
||||
def get_auth_response(self, state: State) -> AuthCredential:
|
||||
credential_key = self.get_credential_key()
|
||||
return state.get(credential_key, None)
|
||||
|
||||
def generate_auth_request(self) -> AuthConfig:
|
||||
if not isinstance(
|
||||
self.auth_config.auth_scheme, SecurityBase
|
||||
) or self.auth_config.auth_scheme.type_ not in (
|
||||
AuthSchemeType.oauth2,
|
||||
AuthSchemeType.openIdConnect,
|
||||
):
|
||||
return self.auth_config.model_copy(deep=True)
|
||||
|
||||
# auth_uri already in exchanged credential
|
||||
if (
|
||||
self.auth_config.exchanged_auth_credential
|
||||
and self.auth_config.exchanged_auth_credential.oauth2
|
||||
and self.auth_config.exchanged_auth_credential.oauth2.auth_uri
|
||||
):
|
||||
return self.auth_config.model_copy(deep=True)
|
||||
|
||||
# Check if raw_auth_credential exists
|
||||
if not self.auth_config.raw_auth_credential:
|
||||
raise ValueError(
|
||||
f"Auth Scheme {self.auth_config.auth_scheme.type_} requires"
|
||||
" auth_credential."
|
||||
)
|
||||
|
||||
# Check if oauth2 exists in raw_auth_credential
|
||||
if not self.auth_config.raw_auth_credential.oauth2:
|
||||
raise ValueError(
|
||||
f"Auth Scheme {self.auth_config.auth_scheme.type_} requires oauth2 in"
|
||||
" auth_credential."
|
||||
)
|
||||
|
||||
# auth_uri in raw credential
|
||||
if self.auth_config.raw_auth_credential.oauth2.auth_uri:
|
||||
return AuthConfig(
|
||||
auth_scheme=self.auth_config.auth_scheme,
|
||||
raw_auth_credential=self.auth_config.raw_auth_credential,
|
||||
exchanged_auth_credential=self.auth_config.raw_auth_credential.model_copy(
|
||||
deep=True
|
||||
),
|
||||
)
|
||||
|
||||
# Check for client_id and client_secret
|
||||
if (
|
||||
not self.auth_config.raw_auth_credential.oauth2.client_id
|
||||
or not self.auth_config.raw_auth_credential.oauth2.client_secret
|
||||
):
|
||||
raise ValueError(
|
||||
f"Auth Scheme {self.auth_config.auth_scheme.type_} requires both"
|
||||
" client_id and client_secret in auth_credential.oauth2."
|
||||
)
|
||||
|
||||
# Generate new auth URI
|
||||
exchanged_credential = self.generate_auth_uri()
|
||||
return AuthConfig(
|
||||
auth_scheme=self.auth_config.auth_scheme,
|
||||
raw_auth_credential=self.auth_config.raw_auth_credential,
|
||||
exchanged_auth_credential=exchanged_credential,
|
||||
)
|
||||
|
||||
def get_credential_key(self) -> str:
|
||||
"""Generates a unique key for the given auth scheme and credential."""
|
||||
auth_scheme = self.auth_config.auth_scheme
|
||||
auth_credential = self.auth_config.raw_auth_credential
|
||||
if auth_scheme.model_extra:
|
||||
auth_scheme = auth_scheme.model_copy(deep=True)
|
||||
auth_scheme.model_extra.clear()
|
||||
scheme_name = (
|
||||
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
|
||||
if auth_scheme
|
||||
else ""
|
||||
)
|
||||
if auth_credential.model_extra:
|
||||
auth_credential = auth_credential.model_copy(deep=True)
|
||||
auth_credential.model_extra.clear()
|
||||
credential_name = (
|
||||
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
|
||||
if auth_credential
|
||||
else ""
|
||||
)
|
||||
|
||||
return f"temp:adk_{scheme_name}_{credential_name}"
|
||||
|
||||
def generate_auth_uri(
|
||||
self,
|
||||
) -> AuthCredential:
|
||||
"""Generates an response containing the auth uri for user to sign in.
|
||||
|
||||
Returns:
|
||||
An AuthCredential object containing the auth URI and state.
|
||||
|
||||
Raises:
|
||||
ValueError: If the authorization endpoint is not configured in the auth
|
||||
scheme.
|
||||
"""
|
||||
auth_scheme = self.auth_config.auth_scheme
|
||||
auth_credential = self.auth_config.raw_auth_credential
|
||||
|
||||
if isinstance(auth_scheme, OpenIdConnectWithConfig):
|
||||
authorization_endpoint = auth_scheme.authorization_endpoint
|
||||
scopes = auth_scheme.scopes
|
||||
else:
|
||||
authorization_endpoint = (
|
||||
auth_scheme.flows.implicit
|
||||
and auth_scheme.flows.implicit.authorizationUrl
|
||||
or auth_scheme.flows.authorizationCode
|
||||
and auth_scheme.flows.authorizationCode.authorizationUrl
|
||||
or auth_scheme.flows.clientCredentials
|
||||
and auth_scheme.flows.clientCredentials.tokenUrl
|
||||
or auth_scheme.flows.password
|
||||
and auth_scheme.flows.password.tokenUrl
|
||||
)
|
||||
scopes = (
|
||||
auth_scheme.flows.implicit
|
||||
and auth_scheme.flows.implicit.scopes
|
||||
or auth_scheme.flows.authorizationCode
|
||||
and auth_scheme.flows.authorizationCode.scopes
|
||||
or auth_scheme.flows.clientCredentials
|
||||
and auth_scheme.flows.clientCredentials.scopes
|
||||
or auth_scheme.flows.password
|
||||
and auth_scheme.flows.password.scopes
|
||||
)
|
||||
scopes = list(scopes.keys())
|
||||
|
||||
client = OAuth2Session(
|
||||
auth_credential.oauth2.client_id,
|
||||
auth_credential.oauth2.client_secret,
|
||||
scope=" ".join(scopes),
|
||||
redirect_uri=auth_credential.oauth2.redirect_uri,
|
||||
)
|
||||
uri, state = client.create_authorization_url(
|
||||
url=authorization_endpoint, access_type="offline", prompt="consent"
|
||||
)
|
||||
exchanged_auth_credential = auth_credential.model_copy(deep=True)
|
||||
exchanged_auth_credential.oauth2.auth_uri = uri
|
||||
exchanged_auth_credential.oauth2.state = state
|
||||
|
||||
return exchanged_auth_credential
|
||||
Reference in New Issue
Block a user