adk-python/tests/unittests/auth/test_auth_handler.py
Jack Sun 05142a07cc
Moves unittests to root folder and adds github action to run unit tests. (#72)
* Move unit tests to root package.

* Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py

* Adds github workflow

* minor fix in lite_llm.py for python 3.9.

* format pyproject.toml
2025-04-11 08:25:59 -07:00

579 lines
19 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from unittest.mock import patch
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlows
import pytest
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_handler import AuthHandler
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
from google.adk.auth.auth_tool import AuthConfig
# Mock classes for testing
class MockState(dict):
  """Minimal stand-in for a session state object in these tests.

  This is a plain ``dict`` subclass. Construction, item access and
  ``get(key, default=None)`` all use the inherited ``dict`` behaviour, so
  no overrides are defined: the previous ``__init__`` and ``get`` merely
  forwarded their arguments unchanged to ``dict``.
  """
class MockOAuth2Session:
  """Test double patched in place of ``auth_handler.OAuth2Session``.

  It records the constructor arguments and returns canned values, so no
  network traffic happens during the tests.
  """

  def __init__(
      self,
      client_id=None,
      client_secret=None,
      scope=None,
      redirect_uri=None,
      state=None,
  ):
    """Store the client configuration verbatim on the instance."""
    self.client_id = client_id
    self.client_secret = client_secret
    self.scope = scope
    self.redirect_uri = redirect_uri
    self.state = state

  def create_authorization_url(self, url, **kwargs):
    """Return a fake ``(authorization_url, state)`` pair.

    Args:
      url: Base authorization endpoint the query string is appended to.
      **kwargs: Extra authorization parameters. They are accepted and
        ignored so callers that pass additional request parameters do not
        fail with a ``TypeError`` against the mock.

    Returns:
      A tuple of the URL (with ``client_id`` and ``scope`` appended) and
      the fixed state string ``"mock_state"``.
    """
    return f"{url}?client_id={self.client_id}&scope={self.scope}", "mock_state"

  def fetch_token(
      self,
      token_endpoint,
      authorization_response=None,
      code=None,
      grant_type=None,
      **kwargs,
  ):
    """Return a canned token response regardless of the arguments.

    Args:
      token_endpoint: Token URL (unused by the mock).
      authorization_response: Callback URL (unused by the mock).
      code: Authorization code (unused by the mock).
      grant_type: OAuth2 grant type (unused by the mock).
      **kwargs: Any further parameters; accepted and ignored.

    Returns:
      A dict with ``access_token``, ``token_type``, ``expires_in`` and
      ``refresh_token`` entries.
    """
    return {
        "access_token": "mock_access_token",
        "token_type": "bearer",
        "expires_in": 3600,
        "refresh_token": "mock_refresh_token",
    }
# Fixtures for common test objects
@pytest.fixture
def oauth2_auth_scheme():
  """Provide an OAuth2 scheme configured with an authorization-code flow."""
  # Build the single flow first, then wrap it in OAuthFlows / OAuth2.
  code_flow = OAuthFlowAuthorizationCode(
      authorizationUrl="https://example.com/oauth2/authorize",
      tokenUrl="https://example.com/oauth2/token",
      scopes={"read": "Read access", "write": "Write access"},
  )
  return OAuth2(flows=OAuthFlows(authorizationCode=code_flow))
@pytest.fixture
def openid_auth_scheme():
  """Provide an OpenID Connect scheme with explicit endpoints and scopes."""
  scheme = OpenIdConnectWithConfig(
      openIdConnectUrl="https://example.com/.well-known/openid-configuration",
      authorization_endpoint="https://example.com/oauth2/authorize",
      token_endpoint="https://example.com/oauth2/token",
      scopes=["openid", "profile", "email"],
  )
  return scheme
@pytest.fixture
def oauth2_credentials():
  """Provide a bare OAuth2 credential: client id, secret and redirect URI."""
  client_settings = OAuth2Auth(
      client_id="mock_client_id",
      client_secret="mock_client_secret",
      redirect_uri="https://example.com/callback",
  )
  return AuthCredential(
      auth_type=AuthCredentialTypes.OAUTH2, oauth2=client_settings
  )
@pytest.fixture
def oauth2_credentials_with_token():
  """Provide an OAuth2 credential that already carries a token."""
  token_payload = {
      "access_token": "mock_access_token",
      "token_type": "bearer",
      "expires_in": 3600,
      "refresh_token": "mock_refresh_token",
  }
  client_settings = OAuth2Auth(
      client_id="mock_client_id",
      client_secret="mock_client_secret",
      redirect_uri="https://example.com/callback",
      token=token_payload,
  )
  return AuthCredential(
      auth_type=AuthCredentialTypes.OAUTH2, oauth2=client_settings
  )
@pytest.fixture
def oauth2_credentials_with_auth_uri():
  """Provide an OAuth2 credential whose auth URI and state are already set."""
  client_settings = OAuth2Auth(
      client_id="mock_client_id",
      client_secret="mock_client_secret",
      redirect_uri="https://example.com/callback",
      auth_uri="https://example.com/oauth2/authorize?client_id=mock_client_id&scope=read,write",
      state="mock_state",
  )
  return AuthCredential(
      auth_type=AuthCredentialTypes.OAUTH2, oauth2=client_settings
  )
@pytest.fixture
def oauth2_credentials_with_auth_code():
  """Provide an OAuth2 credential holding an auth code and response URI."""
  client_settings = OAuth2Auth(
      client_id="mock_client_id",
      client_secret="mock_client_secret",
      redirect_uri="https://example.com/callback",
      auth_uri="https://example.com/oauth2/authorize?client_id=mock_client_id&scope=read,write",
      state="mock_state",
      auth_code="mock_auth_code",
      auth_response_uri="https://example.com/callback?code=mock_auth_code&state=mock_state",
  )
  return AuthCredential(
      auth_type=AuthCredentialTypes.OAUTH2, oauth2=client_settings
  )
@pytest.fixture
def auth_config(oauth2_auth_scheme, oauth2_credentials):
  """Provide an AuthConfig whose exchanged credential copies the raw one."""
  config = AuthConfig(
      auth_scheme=oauth2_auth_scheme,
      raw_auth_credential=oauth2_credentials,
      # Deep copy so the raw and exchanged credentials are separate objects.
      exchanged_auth_credential=oauth2_credentials.model_copy(deep=True),
  )
  return config
@pytest.fixture
def auth_config_with_exchanged(
    oauth2_auth_scheme, oauth2_credentials, oauth2_credentials_with_auth_uri
):
  """Provide an AuthConfig whose exchanged credential has an auth URI."""
  config = AuthConfig(
      auth_scheme=oauth2_auth_scheme,
      raw_auth_credential=oauth2_credentials,
      exchanged_auth_credential=oauth2_credentials_with_auth_uri,
  )
  return config
@pytest.fixture
def auth_config_with_auth_code(
    oauth2_auth_scheme, oauth2_credentials, oauth2_credentials_with_auth_code
):
  """Provide an AuthConfig whose exchanged credential has an auth code."""
  config = AuthConfig(
      auth_scheme=oauth2_auth_scheme,
      raw_auth_credential=oauth2_credentials,
      exchanged_auth_credential=oauth2_credentials_with_auth_code,
  )
  return config
class TestAuthHandlerInit:
  """Covers construction of AuthHandler."""

  def test_init(self, auth_config):
    """The handler exposes the config it was constructed with."""
    created = AuthHandler(auth_config)
    assert created.auth_config == auth_config
class TestGetCredentialKey:
  """Tests for the get_credential_key method."""

  def test_get_credential_key(self, auth_config):
    """The key has the expected prefix and embeds the credential type."""
    handler = AuthHandler(auth_config)
    key = handler.get_credential_key()
    assert key.startswith("temp:adk_oauth2_")
    assert "_oauth2_" in key

  def test_get_credential_key_with_extras(self, auth_config):
    """Extra model fields neither change the key nor get removed."""
    original_key = AuthHandler(auth_config).get_credential_key()

    # Add the extras BEFORE computing the second key. Previously both keys
    # were computed before this mutation, so the equality check below
    # compared two identical calls and could never fail.
    auth_config.auth_scheme.model_extra["extra_field"] = "value"
    auth_config.raw_auth_credential.model_extra["extra_field"] = "value"
    key = AuthHandler(auth_config).get_credential_key()

    # The key is expected to be independent of model_extra contents ...
    assert original_key == key
    # ... and computing it must leave the caller's extras in place.
    assert "extra_field" in auth_config.auth_scheme.model_extra
    assert "extra_field" in auth_config.raw_auth_credential.model_extra
class TestGenerateAuthUri:
  """Covers AuthHandler.generate_auth_uri (both tests currently skipped)."""

  @pytest.mark.skip(reason="broken tests")
  @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
  def test_generate_auth_uri_oauth2(self, auth_config):
    """An OAuth2 scheme yields the mocked authorization URL and state."""
    credential = AuthHandler(auth_config).generate_auth_uri()

    generated_uri = credential.oauth2.auth_uri
    assert generated_uri.startswith("https://example.com/oauth2/authorize")
    assert "client_id=mock_client_id" in credential.oauth2.auth_uri
    assert credential.oauth2.state == "mock_state"

  @pytest.mark.skip(reason="broken tests")
  @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
  def test_generate_auth_uri_openid(
      self, openid_auth_scheme, oauth2_credentials
  ):
    """An OpenID Connect scheme yields the mocked authorization URL."""
    openid_config = AuthConfig(
        auth_scheme=openid_auth_scheme,
        raw_auth_credential=oauth2_credentials,
        # Separate deep copy used as the exchanged credential.
        exchanged_auth_credential=oauth2_credentials.model_copy(deep=True),
    )

    credential = AuthHandler(openid_config).generate_auth_uri()

    generated_uri = credential.oauth2.auth_uri
    assert generated_uri.startswith("https://example.com/oauth2/authorize")
    assert "client_id=mock_client_id" in credential.oauth2.auth_uri
    assert credential.oauth2.state == "mock_state"
class TestGenerateAuthRequest:
  """Covers AuthHandler.generate_auth_request across its input cases."""

  def test_non_oauth_scheme(self):
    """An API-key scheme produces a result equal to the input config."""
    key_scheme = APIKey(**{"name": "test_api_key", "in": APIKeyIn.header})
    raw = AuthCredential(
        auth_type=AuthCredentialTypes.API_KEY, api_key="test_api_key"
    )
    config = AuthConfig(
        auth_scheme=key_scheme,
        raw_auth_credential=raw,
        # Deep copy keeps raw and exchanged credentials distinct objects.
        exchanged_auth_credential=raw.model_copy(deep=True),
    )

    result = AuthHandler(config).generate_auth_request()

    assert result == config

  def test_with_existing_auth_uri(self, auth_config_with_exchanged):
    """An auth_uri already on the exchanged credential is preserved."""
    result = AuthHandler(auth_config_with_exchanged).generate_auth_request()

    returned_uri = result.exchanged_auth_credential.oauth2.auth_uri
    configured = auth_config_with_exchanged.exchanged_auth_credential
    assert returned_uri == configured.oauth2.auth_uri

  def test_missing_raw_credential(self, oauth2_auth_scheme):
    """A config without raw_auth_credential raises ValueError."""
    scheme_only = AuthConfig(
        auth_scheme=oauth2_auth_scheme,
    )
    handler = AuthHandler(scheme_only)

    with pytest.raises(ValueError, match="requires auth_credential"):
      handler.generate_auth_request()

  def test_missing_oauth2_in_raw_credential(self, oauth2_auth_scheme):
    """A raw credential lacking an oauth2 section raises ValueError."""
    api_key_credential = AuthCredential(
        auth_type=AuthCredentialTypes.API_KEY, api_key="test_api_key"
    )
    config = AuthConfig(
        auth_scheme=oauth2_auth_scheme,
        raw_auth_credential=api_key_credential,
        exchanged_auth_credential=api_key_credential.model_copy(deep=True),
    )
    handler = AuthHandler(config)

    with pytest.raises(ValueError, match="requires oauth2 in auth_credential"):
      handler.generate_auth_request()

  def test_auth_uri_in_raw_credential(
      self, oauth2_auth_scheme, oauth2_credentials_with_auth_uri
  ):
    """An auth_uri on the raw credential shows up in the result."""
    exchanged_copy = oauth2_credentials_with_auth_uri.model_copy(deep=True)
    config = AuthConfig(
        auth_scheme=oauth2_auth_scheme,
        raw_auth_credential=oauth2_credentials_with_auth_uri,
        exchanged_auth_credential=exchanged_copy,
    )

    result = AuthHandler(config).generate_auth_request()

    returned_uri = result.exchanged_auth_credential.oauth2.auth_uri
    assert returned_uri == oauth2_credentials_with_auth_uri.oauth2.auth_uri

  def test_missing_client_credentials(self, oauth2_auth_scheme):
    """A credential without client_id/client_secret raises ValueError."""
    incomplete = AuthCredential(
        auth_type=AuthCredentialTypes.OAUTH2,
        oauth2=OAuth2Auth(redirect_uri="https://example.com/callback"),
    )
    config = AuthConfig(
        auth_scheme=oauth2_auth_scheme,
        raw_auth_credential=incomplete,
        exchanged_auth_credential=incomplete.model_copy(deep=True),
    )
    handler = AuthHandler(config)

    with pytest.raises(
        ValueError, match="requires both client_id and client_secret"
    ):
      handler.generate_auth_request()

  @patch("google.adk.auth.auth_handler.AuthHandler.generate_auth_uri")
  def test_generate_new_auth_uri(self, mock_generate_auth_uri, auth_config):
    """With no auth_uri present, generate_auth_uri supplies the credential."""
    generated = AuthCredential(
        auth_type=AuthCredentialTypes.OAUTH2,
        oauth2=OAuth2Auth(
            client_id="mock_client_id",
            client_secret="mock_client_secret",
            redirect_uri="https://example.com/callback",
            auth_uri="https://example.com/generated",
            state="generated_state",
        ),
    )
    # generate_auth_uri is patched, so the handler gets this canned value.
    mock_generate_auth_uri.return_value = generated

    result = AuthHandler(auth_config).generate_auth_request()

    assert mock_generate_auth_uri.called
    assert result.exchanged_auth_credential == generated
class TestGetAuthResponse:
  """Covers AuthHandler.get_auth_response."""

  def test_get_auth_response_exists(
      self, auth_config, oauth2_credentials_with_auth_uri
  ):
    """A credential stored under the handler's key is returned."""
    handler = AuthHandler(auth_config)
    # Seed the state with a credential under the key the handler computes.
    store = MockState()
    store[handler.get_credential_key()] = oauth2_credentials_with_auth_uri

    fetched = handler.get_auth_response(store)

    assert fetched == oauth2_credentials_with_auth_uri

  def test_get_auth_response_not_exists(self, auth_config):
    """An empty state yields None."""
    empty_store = MockState()
    fetched = AuthHandler(auth_config).get_auth_response(empty_store)
    assert fetched is None
class TestParseAndStoreAuthResponse:
  """Covers AuthHandler.parse_and_store_auth_response."""

  def test_non_oauth_scheme(self, auth_config_with_exchanged):
    """With an API-key scheme the exchanged credential is stored as-is."""
    # Work on a deep copy so the fixture's config is not modified.
    api_key_config = copy.deepcopy(auth_config_with_exchanged)
    api_key_config.auth_scheme = APIKey(
        **{"name": "test_api_key", "in": APIKeyIn.header}
    )
    handler = AuthHandler(api_key_config)
    store = MockState()

    handler.parse_and_store_auth_response(store)

    stored_under = handler.get_credential_key()
    assert store[stored_under] == api_key_config.exchanged_auth_credential

  @patch("google.adk.auth.auth_handler.AuthHandler.exchange_auth_token")
  def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged):
    """With an OAuth2 scheme the result of exchange_auth_token is stored."""
    # exchange_auth_token is patched to return this canned credential.
    mock_exchange_token.return_value = AuthCredential(
        auth_type=AuthCredentialTypes.OAUTH2,
        oauth2=OAuth2Auth(token={"access_token": "exchanged_token"}),
    )
    handler = AuthHandler(auth_config_with_exchanged)
    store = MockState()

    handler.parse_and_store_auth_response(store)

    stored_under = handler.get_credential_key()
    assert store[stored_under] == mock_exchange_token.return_value
    assert mock_exchange_token.called
class TestExchangeAuthToken:
  """Covers AuthHandler.exchange_auth_token."""

  def test_token_exchange_not_supported(
      self, auth_config_with_auth_code, monkeypatch
  ):
    """With SUPPORT_TOKEN_EXCHANGE off, the credential is returned as-is."""
    monkeypatch.setattr(
        "google.adk.auth.auth_handler.SUPPORT_TOKEN_EXCHANGE", False
    )

    outcome = AuthHandler(auth_config_with_auth_code).exchange_auth_token()

    assert outcome == auth_config_with_auth_code.exchanged_auth_credential

  def test_openid_missing_token_endpoint(
      self, openid_auth_scheme, oauth2_credentials_with_auth_code
  ):
    """An OpenID scheme lacking token_endpoint returns the credential."""
    # Copy the fixture's scheme, then remove token_endpoint from the copy.
    stripped_scheme = copy.deepcopy(openid_auth_scheme)
    delattr(stripped_scheme, "token_endpoint")
    config = AuthConfig(
        auth_scheme=stripped_scheme,
        raw_auth_credential=oauth2_credentials_with_auth_code,
        exchanged_auth_credential=oauth2_credentials_with_auth_code,
    )

    outcome = AuthHandler(config).exchange_auth_token()

    assert outcome == oauth2_credentials_with_auth_code

  def test_oauth2_missing_token_url(
      self, oauth2_auth_scheme, oauth2_credentials_with_auth_code
  ):
    """An OAuth2 scheme with tokenUrl set to None returns the credential."""
    # Copy the fixture's scheme, then blank the tokenUrl on the copy.
    stripped_scheme = copy.deepcopy(oauth2_auth_scheme)
    stripped_scheme.flows.authorizationCode.tokenUrl = None
    config = AuthConfig(
        auth_scheme=stripped_scheme,
        raw_auth_credential=oauth2_credentials_with_auth_code,
        exchanged_auth_credential=oauth2_credentials_with_auth_code,
    )

    outcome = AuthHandler(config).exchange_auth_token()

    assert outcome == oauth2_credentials_with_auth_code

  def test_non_oauth_scheme(self, auth_config_with_auth_code):
    """With an API-key scheme the exchanged credential is returned."""
    # Work on a deep copy so the fixture's config is not modified.
    api_key_config = copy.deepcopy(auth_config_with_auth_code)
    api_key_config.auth_scheme = APIKey(
        **{"name": "test_api_key", "in": APIKeyIn.header}
    )

    outcome = AuthHandler(api_key_config).exchange_auth_token()

    assert outcome == api_key_config.exchanged_auth_credential

  def test_missing_credentials(self, oauth2_auth_scheme):
    """A credential with no oauth2 data is returned without exchange."""
    bare_credential = AuthCredential(auth_type=AuthCredentialTypes.OAUTH2)
    config = AuthConfig(
        auth_scheme=oauth2_auth_scheme,
        exchanged_auth_credential=bare_credential,
    )

    outcome = AuthHandler(config).exchange_auth_token()

    assert outcome == bare_credential

  def test_credentials_with_token(
      self, auth_config, oauth2_credentials_with_token
  ):
    """A credential that already holds a token is returned unchanged."""
    config = AuthConfig(
        auth_scheme=auth_config.auth_scheme,
        raw_auth_credential=auth_config.raw_auth_credential,
        exchanged_auth_credential=oauth2_credentials_with_token,
    )

    outcome = AuthHandler(config).exchange_auth_token()

    assert outcome == oauth2_credentials_with_token

  @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
  def test_successful_token_exchange(self, auth_config_with_auth_code):
    """With the mocked session, the token fields come from fetch_token."""
    outcome = AuthHandler(auth_config_with_auth_code).exchange_auth_token()

    issued_token = outcome.oauth2.token
    assert issued_token["access_token"] == "mock_access_token"
    assert outcome.oauth2.token["refresh_token"] == "mock_refresh_token"
    assert outcome.auth_type == AuthCredentialTypes.OAUTH2