Moves unittests to root folder and adds github action to run unit tests. (#72)

* Move unit tests to root package.

* Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py

* Adds github workflow

* minor fix in lite_llm.py for python 3.9.

* format pyproject.toml
This commit is contained in:
Jack Sun
2025-04-11 08:25:59 -07:00
committed by GitHub
parent 59117b9b96
commit 05142a07cc
66 changed files with 50 additions and 2 deletions

View File

@@ -0,0 +1,145 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for AutoAuthCredentialExchanger."""
from typing import Dict
from typing import Optional
from typing import Type
from unittest.mock import MagicMock
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_schemes import AuthScheme
from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
from google.adk.tools.openapi_tool.auth.credential_exchangers.oauth2_exchanger import OAuth2CredentialExchanger
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
import pytest
class MockCredentialExchanger(BaseAuthCredentialExchanger):
"""Mock credential exchanger for testing."""
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
"""Mock exchange credential method."""
return auth_credential
@pytest.fixture
def auto_exchanger():
"""Fixture for creating an AutoAuthCredentialExchanger instance."""
return AutoAuthCredentialExchanger()
@pytest.fixture
def auth_scheme():
"""Fixture for creating a mock AuthScheme instance."""
scheme = MagicMock(spec=AuthScheme)
return scheme
def test_init_with_custom_exchangers():
"""Test initialization with custom exchangers."""
custom_exchangers: Dict[str, Type[BaseAuthCredentialExchanger]] = {
AuthCredentialTypes.API_KEY: MockCredentialExchanger
}
auto_exchanger = AutoAuthCredentialExchanger(
custom_exchangers=custom_exchangers
)
assert (
auto_exchanger.exchangers[AuthCredentialTypes.API_KEY]
== MockCredentialExchanger
)
assert (
auto_exchanger.exchangers[AuthCredentialTypes.OPEN_ID_CONNECT]
== OAuth2CredentialExchanger
)
def test_exchange_credential_no_auth_credential(auto_exchanger, auth_scheme):
"""Test exchange_credential with no auth_credential."""
assert auto_exchanger.exchange_credential(auth_scheme, None) is None
def test_exchange_credential_no_exchange(auto_exchanger, auth_scheme):
"""Test exchange_credential with NoExchangeCredentialExchanger."""
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
assert result == auth_credential
def test_exchange_credential_open_id_connect(auto_exchanger, auth_scheme):
"""Test exchange_credential with OpenID Connect scheme."""
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT
)
mock_exchanger = MagicMock(spec=OAuth2CredentialExchanger)
mock_exchanger.exchange_credential.return_value = "exchanged_credential"
auto_exchanger.exchangers[AuthCredentialTypes.OPEN_ID_CONNECT] = (
lambda: mock_exchanger
)
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
assert result == "exchanged_credential"
mock_exchanger.exchange_credential.assert_called_once_with(
auth_scheme, auth_credential
)
def test_exchange_credential_service_account(auto_exchanger, auth_scheme):
"""Test exchange_credential with Service Account scheme."""
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT
)
mock_exchanger = MagicMock(spec=ServiceAccountCredentialExchanger)
mock_exchanger.exchange_credential.return_value = "exchanged_credential_sa"
auto_exchanger.exchangers[AuthCredentialTypes.SERVICE_ACCOUNT] = (
lambda: mock_exchanger
)
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
assert result == "exchanged_credential_sa"
mock_exchanger.exchange_credential.assert_called_once_with(
auth_scheme, auth_credential
)
def test_exchange_credential_custom_exchanger(auto_exchanger, auth_scheme):
"""Test that exchange_credential calls the correct (custom) exchanger."""
# Use a custom exchanger via the initialization
mock_exchanger = MagicMock(spec=MockCredentialExchanger)
mock_exchanger.exchange_credential.return_value = "custom_credential"
auto_exchanger.exchangers[AuthCredentialTypes.API_KEY] = (
lambda: mock_exchanger
)
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
assert result == "custom_credential"
mock_exchanger.exchange_credential.assert_called_once_with(
auth_scheme, auth_credential
)

View File

@@ -0,0 +1,68 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the BaseAuthCredentialExchanger class."""
from typing import Optional
from unittest.mock import MagicMock
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_schemes import AuthScheme
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
import pytest
class MockAuthCredentialExchanger(BaseAuthCredentialExchanger):
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
return AuthCredential(token="some-token")
class TestBaseAuthCredentialExchanger:
"""Tests for the BaseAuthCredentialExchanger class."""
@pytest.fixture
def base_exchanger(self):
return BaseAuthCredentialExchanger()
@pytest.fixture
def auth_scheme(self):
scheme = MagicMock(spec=AuthScheme)
scheme.type = "apiKey"
scheme.name = "x-api-key"
return scheme
def test_exchange_credential_not_implemented(
self, base_exchanger, auth_scheme
):
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, token="some-token"
)
with pytest.raises(NotImplementedError) as exc_info:
base_exchanger.exchange_credential(auth_scheme, auth_credential)
assert "Subclasses must implement exchange_credential." in str(
exc_info.value
)
def test_auth_credential_missing_error(self):
error_message = "Test missing credential"
error = AuthCredentialMissingError(error_message)
# assert error.message == error_message
assert str(error) == error_message

View File

@@ -0,0 +1,153 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for OAuth2CredentialExchanger."""
import copy
from unittest.mock import MagicMock
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_schemes import AuthSchemeType
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
from google.adk.tools.openapi_tool.auth.credential_exchangers import OAuth2CredentialExchanger
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
import pytest
@pytest.fixture
def oauth2_exchanger():
return OAuth2CredentialExchanger()
@pytest.fixture
def auth_scheme():
openid_config = OpenIdConnectWithConfig(
type_=AuthSchemeType.openIdConnect,
authorization_endpoint="https://example.com/auth",
token_endpoint="https://example.com/token",
scopes=["openid", "profile"],
)
return openid_config
def test_check_scheme_credential_type_success(oauth2_exchanger, auth_scheme):
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client",
client_secret="test_secret",
redirect_uri="http://localhost:8080",
),
)
# Check that the method does not raise an exception
oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
def test_check_scheme_credential_type_missing_credential(
oauth2_exchanger, auth_scheme
):
# Test case: auth_credential is None
with pytest.raises(ValueError) as exc_info:
oauth2_exchanger._check_scheme_credential_type(auth_scheme, None)
assert "auth_credential is empty" in str(exc_info.value)
def test_check_scheme_credential_type_invalid_scheme_type(
oauth2_exchanger, auth_scheme: OpenIdConnectWithConfig
):
"""Test case: Invalid AuthSchemeType."""
# Test case: Invalid AuthSchemeType
invalid_scheme = copy.deepcopy(auth_scheme)
invalid_scheme.type_ = AuthSchemeType.apiKey
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client",
client_secret="test_secret",
redirect_uri="http://localhost:8080",
),
)
with pytest.raises(ValueError) as exc_info:
oauth2_exchanger._check_scheme_credential_type(
invalid_scheme, auth_credential
)
assert "Invalid security scheme" in str(exc_info.value)
def test_check_scheme_credential_type_missing_openid_connect(
oauth2_exchanger, auth_scheme
):
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
)
with pytest.raises(ValueError) as exc_info:
oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
assert "auth_credential is not configured with oauth2" in str(exc_info.value)
def test_generate_auth_token_success(
oauth2_exchanger, auth_scheme, monkeypatch
):
"""Test case: Successful generation of access token."""
# Test case: Successful generation of access token
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client",
client_secret="test_secret",
redirect_uri="http://localhost:8080",
auth_response_uri="https://example.com/callback?code=test_code",
token={"access_token": "test_access_token"},
),
)
updated_credential = oauth2_exchanger.generate_auth_token(auth_credential)
assert updated_credential.auth_type == AuthCredentialTypes.HTTP
assert updated_credential.http.scheme == "bearer"
assert updated_credential.http.credentials.token == "test_access_token"
def test_exchange_credential_generate_auth_token(
oauth2_exchanger, auth_scheme, monkeypatch
):
"""Test exchange_credential when auth_response_uri is present."""
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client",
client_secret="test_secret",
redirect_uri="http://localhost:8080",
auth_response_uri="https://example.com/callback?code=test_code",
token={"access_token": "test_access_token"},
),
)
updated_credential = oauth2_exchanger.exchange_credential(
auth_scheme, auth_credential
)
assert updated_credential.auth_type == AuthCredentialTypes.HTTP
assert updated_credential.http.scheme == "bearer"
assert updated_credential.http.credentials.token == "test_access_token"
def test_exchange_credential_auth_missing(oauth2_exchanger, auth_scheme):
"""Test exchange_credential when auth_credential is missing."""
with pytest.raises(ValueError) as exc_info:
oauth2_exchanger.exchange_credential(auth_scheme, None)
assert "auth_credential is empty. Please create AuthCredential using" in str(
exc_info.value
)

View File

@@ -0,0 +1,196 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the service account credential exchanger."""
from unittest.mock import MagicMock
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import ServiceAccount
from google.adk.auth.auth_credential import ServiceAccountCredential
from google.adk.auth.auth_schemes import AuthScheme
from google.adk.auth.auth_schemes import AuthSchemeType
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
import google.auth
import pytest
@pytest.fixture
def service_account_exchanger():
return ServiceAccountCredentialExchanger()
@pytest.fixture
def auth_scheme():
scheme = MagicMock(spec=AuthScheme)
scheme.type_ = AuthSchemeType.oauth2
scheme.description = "Google Service Account"
return scheme
def test_exchange_credential_success(
service_account_exchanger, auth_scheme, monkeypatch
):
"""Test successful exchange of service account credentials."""
mock_credentials = MagicMock()
mock_credentials.token = "mock_access_token"
# Mock the from_service_account_info method
mock_from_service_account_info = MagicMock(return_value=mock_credentials)
target_path = (
"google.adk.tools.openapi_tool.auth.credential_exchangers."
"service_account_exchanger.service_account.Credentials."
"from_service_account_info"
)
monkeypatch.setattr(
target_path,
mock_from_service_account_info,
)
# Mock the refresh method
mock_credentials.refresh = MagicMock()
# Create a valid AuthCredential with service account info
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
service_account_credential=ServiceAccountCredential(
type_="service_account",
project_id="your_project_id",
private_key_id="your_private_key_id",
private_key="-----BEGIN PRIVATE KEY-----...",
client_email="...@....iam.gserviceaccount.com",
client_id="your_client_id",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url=(
"https://www.googleapis.com/oauth2/v1/certs"
),
client_x509_cert_url=(
"https://www.googleapis.com/robot/v1/metadata/x509/..."
),
universe_domain="googleapis.com",
),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
),
)
result = service_account_exchanger.exchange_credential(
auth_scheme, auth_credential
)
assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_access_token"
mock_from_service_account_info.assert_called_once()
mock_credentials.refresh.assert_called_once()
def test_exchange_credential_use_default_credential_success(
service_account_exchanger, auth_scheme, monkeypatch
):
"""Test successful exchange of service account credentials using default credential."""
mock_credentials = MagicMock()
mock_credentials.token = "mock_access_token"
mock_google_auth_default = MagicMock(
return_value=(mock_credentials, "test_project")
)
monkeypatch.setattr(google.auth, "default", mock_google_auth_default)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
use_default_credential=True,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
),
)
result = service_account_exchanger.exchange_credential(
auth_scheme, auth_credential
)
assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_access_token"
mock_google_auth_default.assert_called_once()
mock_credentials.refresh.assert_called_once()
def test_exchange_credential_missing_auth_credential(
service_account_exchanger, auth_scheme
):
"""Test missing auth credential during exchange."""
with pytest.raises(AuthCredentialMissingError) as exc_info:
service_account_exchanger.exchange_credential(auth_scheme, None)
assert "Service account credentials are missing" in str(exc_info.value)
def test_exchange_credential_missing_service_account_info(
service_account_exchanger, auth_scheme
):
"""Test missing service account info during exchange."""
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
)
with pytest.raises(AuthCredentialMissingError) as exc_info:
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
assert "Service account credentials are missing" in str(exc_info.value)
def test_exchange_credential_exchange_failure(
service_account_exchanger, auth_scheme, monkeypatch
):
"""Test failure during service account token exchange."""
mock_from_service_account_info = MagicMock(
side_effect=Exception("Failed to load credentials")
)
target_path = (
"google.adk.tools.openapi_tool.auth.credential_exchangers."
"service_account_exchanger.service_account.Credentials."
"from_service_account_info"
)
monkeypatch.setattr(
target_path,
mock_from_service_account_info,
)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
service_account_credential=ServiceAccountCredential(
type_="service_account",
project_id="your_project_id",
private_key_id="your_private_key_id",
private_key="-----BEGIN PRIVATE KEY-----...",
client_email="...@....iam.gserviceaccount.com",
client_id="your_client_id",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url=(
"https://www.googleapis.com/oauth2/v1/certs"
),
client_x509_cert_url=(
"https://www.googleapis.com/robot/v1/metadata/x509/..."
),
universe_domain="googleapis.com",
),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
),
)
with pytest.raises(AuthCredentialMissingError) as exc_info:
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
assert "Failed to exchange service account token" in str(exc_info.value)
mock_from_service_account_info.assert_called_once()

View File

@@ -0,0 +1,573 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from fastapi.openapi.models import HTTPBase
from fastapi.openapi.models import HTTPBearer
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OpenIdConnect
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.auth.auth_credential import ServiceAccount
from google.adk.auth.auth_credential import ServiceAccountCredential
from google.adk.auth.auth_schemes import AuthSchemeType
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
from google.adk.tools.openapi_tool.auth.auth_helpers import credential_to_param
from google.adk.tools.openapi_tool.auth.auth_helpers import dict_to_auth_scheme
from google.adk.tools.openapi_tool.auth.auth_helpers import INTERNAL_AUTH_PREFIX
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_url_to_scheme_credential
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_dict_to_scheme_credential
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
import pytest
import requests
def test_token_to_scheme_credential_api_key_header():
scheme, credential = token_to_scheme_credential(
"apikey", "header", "X-API-Key", "test_key"
)
assert isinstance(scheme, APIKey)
assert scheme.type_ == AuthSchemeType.apiKey
assert scheme.in_ == APIKeyIn.header
assert scheme.name == "X-API-Key"
assert credential == AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
)
def test_token_to_scheme_credential_api_key_query():
scheme, credential = token_to_scheme_credential(
"apikey", "query", "api_key", "test_key"
)
assert isinstance(scheme, APIKey)
assert scheme.type_ == AuthSchemeType.apiKey
assert scheme.in_ == APIKeyIn.query
assert scheme.name == "api_key"
assert credential == AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
)
def test_token_to_scheme_credential_api_key_cookie():
scheme, credential = token_to_scheme_credential(
"apikey", "cookie", "session_id", "test_key"
)
assert isinstance(scheme, APIKey)
assert scheme.type_ == AuthSchemeType.apiKey
assert scheme.in_ == APIKeyIn.cookie
assert scheme.name == "session_id"
assert credential == AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
)
def test_token_to_scheme_credential_api_key_no_credential():
scheme, credential = token_to_scheme_credential(
"apikey", "cookie", "session_id"
)
assert isinstance(scheme, APIKey)
assert credential is None
def test_token_to_scheme_credential_oauth2_token():
scheme, credential = token_to_scheme_credential(
"oauth2Token", "header", "Authorization", "test_token"
)
assert isinstance(scheme, HTTPBearer)
assert scheme.bearerFormat == "JWT"
assert credential == AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="test_token")
),
)
def test_token_to_scheme_credential_oauth2_no_credential():
scheme, credential = token_to_scheme_credential(
"oauth2Token", "header", "Authorization"
)
assert isinstance(scheme, HTTPBearer)
assert credential is None
def test_service_account_dict_to_scheme_credential():
config = {
"type": "service_account",
"project_id": "project_id",
"private_key_id": "private_key_id",
"private_key": "private_key",
"client_email": "client_email",
"client_id": "client_id",
"auth_uri": "auth_uri",
"token_uri": "token_uri",
"auth_provider_x509_cert_url": "auth_provider_x509_cert_url",
"client_x509_cert_url": "client_x509_cert_url",
"universe_domain": "universe_domain",
}
scopes = ["scope1", "scope2"]
scheme, credential = service_account_dict_to_scheme_credential(config, scopes)
assert isinstance(scheme, HTTPBearer)
assert scheme.bearerFormat == "JWT"
assert credential.auth_type == AuthCredentialTypes.SERVICE_ACCOUNT
assert credential.service_account.scopes == scopes
assert (
credential.service_account.service_account_credential.project_id
== "project_id"
)
def test_service_account_scheme_credential():
config = ServiceAccount(
service_account_credential=ServiceAccountCredential(
type="service_account",
project_id="project_id",
private_key_id="private_key_id",
private_key="private_key",
client_email="client_email",
client_id="client_id",
auth_uri="auth_uri",
token_uri="token_uri",
auth_provider_x509_cert_url="auth_provider_x509_cert_url",
client_x509_cert_url="client_x509_cert_url",
universe_domain="universe_domain",
),
scopes=["scope1", "scope2"],
)
scheme, credential = service_account_scheme_credential(config)
assert isinstance(scheme, HTTPBearer)
assert scheme.bearerFormat == "JWT"
assert credential.auth_type == AuthCredentialTypes.SERVICE_ACCOUNT
assert credential.service_account == config
def test_openid_dict_to_scheme_credential():
config_dict = {
"authorization_endpoint": "auth_url",
"token_endpoint": "token_url",
"openIdConnectUrl": "openid_url",
}
credential_dict = {
"client_id": "client_id",
"client_secret": "client_secret",
"redirect_uri": "redirect_uri",
}
scopes = ["scope1", "scope2"]
scheme, credential = openid_dict_to_scheme_credential(
config_dict, scopes, credential_dict
)
assert isinstance(scheme, OpenIdConnectWithConfig)
assert scheme.authorization_endpoint == "auth_url"
assert scheme.token_endpoint == "token_url"
assert scheme.scopes == scopes
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
assert credential.oauth2.client_id == "client_id"
assert credential.oauth2.client_secret == "client_secret"
assert credential.oauth2.redirect_uri == "redirect_uri"
def test_openid_dict_to_scheme_credential_no_openid_url():
config_dict = {
"authorization_endpoint": "auth_url",
"token_endpoint": "token_url",
}
credential_dict = {
"client_id": "client_id",
"client_secret": "client_secret",
"redirect_uri": "redirect_uri",
}
scopes = ["scope1", "scope2"]
scheme, credential = openid_dict_to_scheme_credential(
config_dict, scopes, credential_dict
)
assert scheme.openIdConnectUrl == ""
def test_openid_dict_to_scheme_credential_google_oauth_credential():
config_dict = {
"authorization_endpoint": "auth_url",
"token_endpoint": "token_url",
"openIdConnectUrl": "openid_url",
}
credential_dict = {
"web": {
"client_id": "client_id",
"client_secret": "client_secret",
"redirect_uri": "redirect_uri",
}
}
scopes = ["scope1", "scope2"]
scheme, credential = openid_dict_to_scheme_credential(
config_dict, scopes, credential_dict
)
assert isinstance(scheme, OpenIdConnectWithConfig)
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
assert credential.oauth2.client_id == "client_id"
assert credential.oauth2.client_secret == "client_secret"
assert credential.oauth2.redirect_uri == "redirect_uri"
def test_openid_dict_to_scheme_credential_invalid_config():
config_dict = {
"invalid_field": "value",
}
credential_dict = {
"client_id": "client_id",
"client_secret": "client_secret",
}
scopes = ["scope1", "scope2"]
with pytest.raises(ValueError, match="Invalid OpenID Connect configuration"):
openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
def test_openid_dict_to_scheme_credential_missing_credential_fields():
config_dict = {
"authorization_endpoint": "auth_url",
"token_endpoint": "token_url",
}
credential_dict = {
"client_id": "client_id",
}
scopes = ["scope1", "scope2"]
with pytest.raises(
ValueError,
match="Missing required fields in credential_dict: client_secret",
):
openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
@patch("requests.get")
def test_openid_url_to_scheme_credential(mock_get):
mock_response = {
"authorization_endpoint": "auth_url",
"token_endpoint": "token_url",
"userinfo_endpoint": "userinfo_url",
}
mock_get.return_value.json.return_value = mock_response
mock_get.return_value.raise_for_status.return_value = None
credential_dict = {
"client_id": "client_id",
"client_secret": "client_secret",
"redirect_uri": "redirect_uri",
}
scopes = ["scope1", "scope2"]
scheme, credential = openid_url_to_scheme_credential(
"openid_url", scopes, credential_dict
)
assert isinstance(scheme, OpenIdConnectWithConfig)
assert scheme.authorization_endpoint == "auth_url"
assert scheme.token_endpoint == "token_url"
assert scheme.scopes == scopes
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
assert credential.oauth2.client_id == "client_id"
assert credential.oauth2.client_secret == "client_secret"
assert credential.oauth2.redirect_uri == "redirect_uri"
mock_get.assert_called_once_with("openid_url", timeout=10)
@patch("requests.get")
def test_openid_url_to_scheme_credential_no_openid_url(mock_get):
mock_response = {
"authorization_endpoint": "auth_url",
"token_endpoint": "token_url",
"userinfo_endpoint": "userinfo_url",
}
mock_get.return_value.json.return_value = mock_response
mock_get.return_value.raise_for_status.return_value = None
credential_dict = {
"client_id": "client_id",
"client_secret": "client_secret",
"redirect_uri": "redirect_uri",
}
scopes = ["scope1", "scope2"]
scheme, credential = openid_url_to_scheme_credential(
"openid_url", scopes, credential_dict
)
assert scheme.openIdConnectUrl == "openid_url"
@patch("requests.get")
def test_openid_url_to_scheme_credential_request_exception(mock_get):
mock_get.side_effect = requests.exceptions.RequestException("Test Error")
credential_dict = {"client_id": "client_id", "client_secret": "client_secret"}
with pytest.raises(
ValueError, match="Failed to fetch OpenID configuration from openid_url"
):
openid_url_to_scheme_credential("openid_url", [], credential_dict)
@patch("requests.get")
def test_openid_url_to_scheme_credential_invalid_json(mock_get):
mock_get.return_value.json.side_effect = ValueError("Invalid JSON")
mock_get.return_value.raise_for_status.return_value = None
credential_dict = {"client_id": "client_id", "client_secret": "client_secret"}
with pytest.raises(
ValueError,
match=(
"Invalid JSON response from OpenID configuration endpoint openid_url"
),
):
openid_url_to_scheme_credential("openid_url", [], credential_dict)
def test_credential_to_param_api_key_header():
auth_scheme = APIKey(
**{"type": "apiKey", "in": "header", "name": "X-API-Key"}
)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
)
param, kwargs = credential_to_param(auth_scheme, auth_credential)
assert param.original_name == "X-API-Key"
assert param.param_location == "header"
assert kwargs == {INTERNAL_AUTH_PREFIX + "X-API-Key": "test_key"}
def test_credential_to_param_api_key_query():
auth_scheme = APIKey(**{"type": "apiKey", "in": "query", "name": "api_key"})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
)
param, kwargs = credential_to_param(auth_scheme, auth_credential)
assert param.original_name == "api_key"
assert param.param_location == "query"
assert kwargs == {INTERNAL_AUTH_PREFIX + "api_key": "test_key"}
def test_credential_to_param_api_key_cookie():
auth_scheme = APIKey(
**{"type": "apiKey", "in": "cookie", "name": "session_id"}
)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
)
param, kwargs = credential_to_param(auth_scheme, auth_credential)
assert param.original_name == "session_id"
assert param.param_location == "cookie"
assert kwargs == {INTERNAL_AUTH_PREFIX + "session_id": "test_key"}
def test_credential_to_param_http_bearer():
auth_scheme = HTTPBearer(bearerFormat="JWT")
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="test_token")
),
)
param, kwargs = credential_to_param(auth_scheme, auth_credential)
assert param.original_name == "Authorization"
assert param.param_location == "header"
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
def test_credential_to_param_http_basic_not_supported():
auth_scheme = HTTPBase(scheme="basic")
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="basic",
credentials=HttpCredentials(username="user", password="password"),
),
)
with pytest.raises(
NotImplementedError, match="Basic Authentication is not supported."
):
credential_to_param(auth_scheme, auth_credential)
def test_credential_to_param_http_invalid_credentials_no_http():
auth_scheme = HTTPBase(scheme="basic")
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.HTTP)
with pytest.raises(ValueError, match="Invalid HTTP auth credentials"):
credential_to_param(auth_scheme, auth_credential)
def test_credential_to_param_oauth2():
auth_scheme = OAuth2(flows={})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="test_token")
),
)
param, kwargs = credential_to_param(auth_scheme, auth_credential)
assert param.original_name == "Authorization"
assert param.param_location == "header"
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
def test_credential_to_param_openid_connect():
auth_scheme = OpenIdConnect(openIdConnectUrl="openid_url")
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="test_token")
),
)
param, kwargs = credential_to_param(auth_scheme, auth_credential)
assert param.original_name == "Authorization"
assert param.param_location == "header"
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
def test_credential_to_param_openid_no_credential():
auth_scheme = OpenIdConnect(openIdConnectUrl="openid_url")
param, kwargs = credential_to_param(auth_scheme, None)
assert param == None
assert kwargs == None
def test_credential_to_param_oauth2_no_credential():
auth_scheme = OAuth2(flows={})
param, kwargs = credential_to_param(auth_scheme, None)
assert param == None
assert kwargs == None
def test_dict_to_auth_scheme_api_key():
data = {"type": "apiKey", "in": "header", "name": "X-API-Key"}
scheme = dict_to_auth_scheme(data)
assert isinstance(scheme, APIKey)
assert scheme.type_ == AuthSchemeType.apiKey
assert scheme.in_ == APIKeyIn.header
assert scheme.name == "X-API-Key"
def test_dict_to_auth_scheme_http_bearer():
data = {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
scheme = dict_to_auth_scheme(data)
assert isinstance(scheme, HTTPBearer)
assert scheme.scheme == "bearer"
assert scheme.bearerFormat == "JWT"
def test_dict_to_auth_scheme_http_base():
data = {"type": "http", "scheme": "basic"}
scheme = dict_to_auth_scheme(data)
assert isinstance(scheme, HTTPBase)
assert scheme.scheme == "basic"
def test_dict_to_auth_scheme_oauth2():
data = {
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "https://example.com/auth",
"tokenUrl": "https://example.com/token",
}
},
}
scheme = dict_to_auth_scheme(data)
assert isinstance(scheme, OAuth2)
assert hasattr(scheme.flows, "authorizationCode")
def test_dict_to_auth_scheme_openid_connect():
data = {
"type": "openIdConnect",
"openIdConnectUrl": (
"https://example.com/.well-known/openid-configuration"
),
}
scheme = dict_to_auth_scheme(data)
assert isinstance(scheme, OpenIdConnect)
assert (
scheme.openIdConnectUrl
== "https://example.com/.well-known/openid-configuration"
)
def test_dict_to_auth_scheme_missing_type():
data = {"in": "header", "name": "X-API-Key"}
with pytest.raises(
ValueError, match="Missing 'type' field in security scheme dictionary."
):
dict_to_auth_scheme(data)
def test_dict_to_auth_scheme_invalid_type():
data = {"type": "invalid", "in": "header", "name": "X-API-Key"}
with pytest.raises(ValueError, match="Invalid security scheme type: invalid"):
dict_to_auth_scheme(data)
def test_dict_to_auth_scheme_invalid_data():
data = {"type": "apiKey", "in": "header"} # Missing 'name'
with pytest.raises(ValueError, match="Invalid security scheme data"):
dict_to_auth_scheme(data)
if __name__ == "__main__":
pytest.main([__file__])