Moves unittests to root folder and adds github action to run unit tests. (#72)

* Move unit tests to root package.

* Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py

* Adds github workflow

* minor fix in lite_llm.py for python 3.9.

* format pyproject.toml
This commit is contained in:
Jack Sun
2025-04-11 08:25:59 -07:00
committed by GitHub
parent 59117b9b96
commit 05142a07cc
66 changed files with 50 additions and 2 deletions

View File

@@ -0,0 +1,145 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for AutoAuthCredentialExchanger."""
from typing import Dict
from typing import Optional
from typing import Type
from unittest.mock import MagicMock
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_schemes import AuthScheme
from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
from google.adk.tools.openapi_tool.auth.credential_exchangers.oauth2_exchanger import OAuth2CredentialExchanger
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
import pytest
class MockCredentialExchanger(BaseAuthCredentialExchanger):
"""Mock credential exchanger for testing."""
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
"""Mock exchange credential method."""
return auth_credential
@pytest.fixture
def auto_exchanger():
"""Fixture for creating an AutoAuthCredentialExchanger instance."""
return AutoAuthCredentialExchanger()
@pytest.fixture
def auth_scheme():
"""Fixture for creating a mock AuthScheme instance."""
scheme = MagicMock(spec=AuthScheme)
return scheme
def test_init_with_custom_exchangers():
"""Test initialization with custom exchangers."""
custom_exchangers: Dict[str, Type[BaseAuthCredentialExchanger]] = {
AuthCredentialTypes.API_KEY: MockCredentialExchanger
}
auto_exchanger = AutoAuthCredentialExchanger(
custom_exchangers=custom_exchangers
)
assert (
auto_exchanger.exchangers[AuthCredentialTypes.API_KEY]
== MockCredentialExchanger
)
assert (
auto_exchanger.exchangers[AuthCredentialTypes.OPEN_ID_CONNECT]
== OAuth2CredentialExchanger
)
def test_exchange_credential_no_auth_credential(auto_exchanger, auth_scheme):
"""Test exchange_credential with no auth_credential."""
assert auto_exchanger.exchange_credential(auth_scheme, None) is None
def test_exchange_credential_no_exchange(auto_exchanger, auth_scheme):
"""Test exchange_credential with NoExchangeCredentialExchanger."""
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
assert result == auth_credential
def test_exchange_credential_open_id_connect(auto_exchanger, auth_scheme):
"""Test exchange_credential with OpenID Connect scheme."""
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT
)
mock_exchanger = MagicMock(spec=OAuth2CredentialExchanger)
mock_exchanger.exchange_credential.return_value = "exchanged_credential"
auto_exchanger.exchangers[AuthCredentialTypes.OPEN_ID_CONNECT] = (
lambda: mock_exchanger
)
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
assert result == "exchanged_credential"
mock_exchanger.exchange_credential.assert_called_once_with(
auth_scheme, auth_credential
)
def test_exchange_credential_service_account(auto_exchanger, auth_scheme):
"""Test exchange_credential with Service Account scheme."""
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT
)
mock_exchanger = MagicMock(spec=ServiceAccountCredentialExchanger)
mock_exchanger.exchange_credential.return_value = "exchanged_credential_sa"
auto_exchanger.exchangers[AuthCredentialTypes.SERVICE_ACCOUNT] = (
lambda: mock_exchanger
)
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
assert result == "exchanged_credential_sa"
mock_exchanger.exchange_credential.assert_called_once_with(
auth_scheme, auth_credential
)
def test_exchange_credential_custom_exchanger(auto_exchanger, auth_scheme):
"""Test that exchange_credential calls the correct (custom) exchanger."""
# Use a custom exchanger via the initialization
mock_exchanger = MagicMock(spec=MockCredentialExchanger)
mock_exchanger.exchange_credential.return_value = "custom_credential"
auto_exchanger.exchangers[AuthCredentialTypes.API_KEY] = (
lambda: mock_exchanger
)
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
assert result == "custom_credential"
mock_exchanger.exchange_credential.assert_called_once_with(
auth_scheme, auth_credential
)

View File

@@ -0,0 +1,68 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the BaseAuthCredentialExchanger class."""
from typing import Optional
from unittest.mock import MagicMock
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_schemes import AuthScheme
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
import pytest
class MockAuthCredentialExchanger(BaseAuthCredentialExchanger):
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
return AuthCredential(token="some-token")
class TestBaseAuthCredentialExchanger:
"""Tests for the BaseAuthCredentialExchanger class."""
@pytest.fixture
def base_exchanger(self):
return BaseAuthCredentialExchanger()
@pytest.fixture
def auth_scheme(self):
scheme = MagicMock(spec=AuthScheme)
scheme.type = "apiKey"
scheme.name = "x-api-key"
return scheme
def test_exchange_credential_not_implemented(
self, base_exchanger, auth_scheme
):
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, token="some-token"
)
with pytest.raises(NotImplementedError) as exc_info:
base_exchanger.exchange_credential(auth_scheme, auth_credential)
assert "Subclasses must implement exchange_credential." in str(
exc_info.value
)
def test_auth_credential_missing_error(self):
error_message = "Test missing credential"
error = AuthCredentialMissingError(error_message)
# assert error.message == error_message
assert str(error) == error_message

View File

@@ -0,0 +1,153 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for OAuth2CredentialExchanger."""
import copy
from unittest.mock import MagicMock
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_schemes import AuthSchemeType
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
from google.adk.tools.openapi_tool.auth.credential_exchangers import OAuth2CredentialExchanger
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
import pytest
@pytest.fixture
def oauth2_exchanger():
return OAuth2CredentialExchanger()
@pytest.fixture
def auth_scheme():
openid_config = OpenIdConnectWithConfig(
type_=AuthSchemeType.openIdConnect,
authorization_endpoint="https://example.com/auth",
token_endpoint="https://example.com/token",
scopes=["openid", "profile"],
)
return openid_config
def test_check_scheme_credential_type_success(oauth2_exchanger, auth_scheme):
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client",
client_secret="test_secret",
redirect_uri="http://localhost:8080",
),
)
# Check that the method does not raise an exception
oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
def test_check_scheme_credential_type_missing_credential(
oauth2_exchanger, auth_scheme
):
# Test case: auth_credential is None
with pytest.raises(ValueError) as exc_info:
oauth2_exchanger._check_scheme_credential_type(auth_scheme, None)
assert "auth_credential is empty" in str(exc_info.value)
def test_check_scheme_credential_type_invalid_scheme_type(
oauth2_exchanger, auth_scheme: OpenIdConnectWithConfig
):
"""Test case: Invalid AuthSchemeType."""
# Test case: Invalid AuthSchemeType
invalid_scheme = copy.deepcopy(auth_scheme)
invalid_scheme.type_ = AuthSchemeType.apiKey
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client",
client_secret="test_secret",
redirect_uri="http://localhost:8080",
),
)
with pytest.raises(ValueError) as exc_info:
oauth2_exchanger._check_scheme_credential_type(
invalid_scheme, auth_credential
)
assert "Invalid security scheme" in str(exc_info.value)
def test_check_scheme_credential_type_missing_openid_connect(
oauth2_exchanger, auth_scheme
):
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
)
with pytest.raises(ValueError) as exc_info:
oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
assert "auth_credential is not configured with oauth2" in str(exc_info.value)
def test_generate_auth_token_success(
oauth2_exchanger, auth_scheme, monkeypatch
):
"""Test case: Successful generation of access token."""
# Test case: Successful generation of access token
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client",
client_secret="test_secret",
redirect_uri="http://localhost:8080",
auth_response_uri="https://example.com/callback?code=test_code",
token={"access_token": "test_access_token"},
),
)
updated_credential = oauth2_exchanger.generate_auth_token(auth_credential)
assert updated_credential.auth_type == AuthCredentialTypes.HTTP
assert updated_credential.http.scheme == "bearer"
assert updated_credential.http.credentials.token == "test_access_token"
def test_exchange_credential_generate_auth_token(
oauth2_exchanger, auth_scheme, monkeypatch
):
"""Test exchange_credential when auth_response_uri is present."""
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client",
client_secret="test_secret",
redirect_uri="http://localhost:8080",
auth_response_uri="https://example.com/callback?code=test_code",
token={"access_token": "test_access_token"},
),
)
updated_credential = oauth2_exchanger.exchange_credential(
auth_scheme, auth_credential
)
assert updated_credential.auth_type == AuthCredentialTypes.HTTP
assert updated_credential.http.scheme == "bearer"
assert updated_credential.http.credentials.token == "test_access_token"
def test_exchange_credential_auth_missing(oauth2_exchanger, auth_scheme):
"""Test exchange_credential when auth_credential is missing."""
with pytest.raises(ValueError) as exc_info:
oauth2_exchanger.exchange_credential(auth_scheme, None)
assert "auth_credential is empty. Please create AuthCredential using" in str(
exc_info.value
)

View File

@@ -0,0 +1,196 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the service account credential exchanger."""
from unittest.mock import MagicMock
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import ServiceAccount
from google.adk.auth.auth_credential import ServiceAccountCredential
from google.adk.auth.auth_schemes import AuthScheme
from google.adk.auth.auth_schemes import AuthSchemeType
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
import google.auth
import pytest
@pytest.fixture
def service_account_exchanger():
return ServiceAccountCredentialExchanger()
@pytest.fixture
def auth_scheme():
scheme = MagicMock(spec=AuthScheme)
scheme.type_ = AuthSchemeType.oauth2
scheme.description = "Google Service Account"
return scheme
def test_exchange_credential_success(
service_account_exchanger, auth_scheme, monkeypatch
):
"""Test successful exchange of service account credentials."""
mock_credentials = MagicMock()
mock_credentials.token = "mock_access_token"
# Mock the from_service_account_info method
mock_from_service_account_info = MagicMock(return_value=mock_credentials)
target_path = (
"google.adk.tools.openapi_tool.auth.credential_exchangers."
"service_account_exchanger.service_account.Credentials."
"from_service_account_info"
)
monkeypatch.setattr(
target_path,
mock_from_service_account_info,
)
# Mock the refresh method
mock_credentials.refresh = MagicMock()
# Create a valid AuthCredential with service account info
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
service_account_credential=ServiceAccountCredential(
type_="service_account",
project_id="your_project_id",
private_key_id="your_private_key_id",
private_key="-----BEGIN PRIVATE KEY-----...",
client_email="...@....iam.gserviceaccount.com",
client_id="your_client_id",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url=(
"https://www.googleapis.com/oauth2/v1/certs"
),
client_x509_cert_url=(
"https://www.googleapis.com/robot/v1/metadata/x509/..."
),
universe_domain="googleapis.com",
),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
),
)
result = service_account_exchanger.exchange_credential(
auth_scheme, auth_credential
)
assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_access_token"
mock_from_service_account_info.assert_called_once()
mock_credentials.refresh.assert_called_once()
def test_exchange_credential_use_default_credential_success(
service_account_exchanger, auth_scheme, monkeypatch
):
"""Test successful exchange of service account credentials using default credential."""
mock_credentials = MagicMock()
mock_credentials.token = "mock_access_token"
mock_google_auth_default = MagicMock(
return_value=(mock_credentials, "test_project")
)
monkeypatch.setattr(google.auth, "default", mock_google_auth_default)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
use_default_credential=True,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
),
)
result = service_account_exchanger.exchange_credential(
auth_scheme, auth_credential
)
assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_access_token"
mock_google_auth_default.assert_called_once()
mock_credentials.refresh.assert_called_once()
def test_exchange_credential_missing_auth_credential(
service_account_exchanger, auth_scheme
):
"""Test missing auth credential during exchange."""
with pytest.raises(AuthCredentialMissingError) as exc_info:
service_account_exchanger.exchange_credential(auth_scheme, None)
assert "Service account credentials are missing" in str(exc_info.value)
def test_exchange_credential_missing_service_account_info(
service_account_exchanger, auth_scheme
):
"""Test missing service account info during exchange."""
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
)
with pytest.raises(AuthCredentialMissingError) as exc_info:
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
assert "Service account credentials are missing" in str(exc_info.value)
def test_exchange_credential_exchange_failure(
service_account_exchanger, auth_scheme, monkeypatch
):
"""Test failure during service account token exchange."""
mock_from_service_account_info = MagicMock(
side_effect=Exception("Failed to load credentials")
)
target_path = (
"google.adk.tools.openapi_tool.auth.credential_exchangers."
"service_account_exchanger.service_account.Credentials."
"from_service_account_info"
)
monkeypatch.setattr(
target_path,
mock_from_service_account_info,
)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
service_account_credential=ServiceAccountCredential(
type_="service_account",
project_id="your_project_id",
private_key_id="your_private_key_id",
private_key="-----BEGIN PRIVATE KEY-----...",
client_email="...@....iam.gserviceaccount.com",
client_id="your_client_id",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url=(
"https://www.googleapis.com/oauth2/v1/certs"
),
client_x509_cert_url=(
"https://www.googleapis.com/robot/v1/metadata/x509/..."
),
universe_domain="googleapis.com",
),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
),
)
with pytest.raises(AuthCredentialMissingError) as exc_info:
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
assert "Failed to exchange service account token" in str(exc_info.value)
mock_from_service_account_info.assert_called_once()

View File

@@ -0,0 +1,573 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from fastapi.openapi.models import HTTPBase
from fastapi.openapi.models import HTTPBearer
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OpenIdConnect
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.auth.auth_credential import ServiceAccount
from google.adk.auth.auth_credential import ServiceAccountCredential
from google.adk.auth.auth_schemes import AuthSchemeType
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
from google.adk.tools.openapi_tool.auth.auth_helpers import credential_to_param
from google.adk.tools.openapi_tool.auth.auth_helpers import dict_to_auth_scheme
from google.adk.tools.openapi_tool.auth.auth_helpers import INTERNAL_AUTH_PREFIX
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_url_to_scheme_credential
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_dict_to_scheme_credential
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
import pytest
import requests
def test_token_to_scheme_credential_api_key_header():
scheme, credential = token_to_scheme_credential(
"apikey", "header", "X-API-Key", "test_key"
)
assert isinstance(scheme, APIKey)
assert scheme.type_ == AuthSchemeType.apiKey
assert scheme.in_ == APIKeyIn.header
assert scheme.name == "X-API-Key"
assert credential == AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
)
def test_token_to_scheme_credential_api_key_query():
scheme, credential = token_to_scheme_credential(
"apikey", "query", "api_key", "test_key"
)
assert isinstance(scheme, APIKey)
assert scheme.type_ == AuthSchemeType.apiKey
assert scheme.in_ == APIKeyIn.query
assert scheme.name == "api_key"
assert credential == AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
)
def test_token_to_scheme_credential_api_key_cookie():
scheme, credential = token_to_scheme_credential(
"apikey", "cookie", "session_id", "test_key"
)
assert isinstance(scheme, APIKey)
assert scheme.type_ == AuthSchemeType.apiKey
assert scheme.in_ == APIKeyIn.cookie
assert scheme.name == "session_id"
assert credential == AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
)
def test_token_to_scheme_credential_api_key_no_credential():
scheme, credential = token_to_scheme_credential(
"apikey", "cookie", "session_id"
)
assert isinstance(scheme, APIKey)
assert credential is None
def test_token_to_scheme_credential_oauth2_token():
scheme, credential = token_to_scheme_credential(
"oauth2Token", "header", "Authorization", "test_token"
)
assert isinstance(scheme, HTTPBearer)
assert scheme.bearerFormat == "JWT"
assert credential == AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="test_token")
),
)
def test_token_to_scheme_credential_oauth2_no_credential():
scheme, credential = token_to_scheme_credential(
"oauth2Token", "header", "Authorization"
)
assert isinstance(scheme, HTTPBearer)
assert credential is None
def test_service_account_dict_to_scheme_credential():
config = {
"type": "service_account",
"project_id": "project_id",
"private_key_id": "private_key_id",
"private_key": "private_key",
"client_email": "client_email",
"client_id": "client_id",
"auth_uri": "auth_uri",
"token_uri": "token_uri",
"auth_provider_x509_cert_url": "auth_provider_x509_cert_url",
"client_x509_cert_url": "client_x509_cert_url",
"universe_domain": "universe_domain",
}
scopes = ["scope1", "scope2"]
scheme, credential = service_account_dict_to_scheme_credential(config, scopes)
assert isinstance(scheme, HTTPBearer)
assert scheme.bearerFormat == "JWT"
assert credential.auth_type == AuthCredentialTypes.SERVICE_ACCOUNT
assert credential.service_account.scopes == scopes
assert (
credential.service_account.service_account_credential.project_id
== "project_id"
)
def test_service_account_scheme_credential():
config = ServiceAccount(
service_account_credential=ServiceAccountCredential(
type="service_account",
project_id="project_id",
private_key_id="private_key_id",
private_key="private_key",
client_email="client_email",
client_id="client_id",
auth_uri="auth_uri",
token_uri="token_uri",
auth_provider_x509_cert_url="auth_provider_x509_cert_url",
client_x509_cert_url="client_x509_cert_url",
universe_domain="universe_domain",
),
scopes=["scope1", "scope2"],
)
scheme, credential = service_account_scheme_credential(config)
assert isinstance(scheme, HTTPBearer)
assert scheme.bearerFormat == "JWT"
assert credential.auth_type == AuthCredentialTypes.SERVICE_ACCOUNT
assert credential.service_account == config
def test_openid_dict_to_scheme_credential():
config_dict = {
"authorization_endpoint": "auth_url",
"token_endpoint": "token_url",
"openIdConnectUrl": "openid_url",
}
credential_dict = {
"client_id": "client_id",
"client_secret": "client_secret",
"redirect_uri": "redirect_uri",
}
scopes = ["scope1", "scope2"]
scheme, credential = openid_dict_to_scheme_credential(
config_dict, scopes, credential_dict
)
assert isinstance(scheme, OpenIdConnectWithConfig)
assert scheme.authorization_endpoint == "auth_url"
assert scheme.token_endpoint == "token_url"
assert scheme.scopes == scopes
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
assert credential.oauth2.client_id == "client_id"
assert credential.oauth2.client_secret == "client_secret"
assert credential.oauth2.redirect_uri == "redirect_uri"
def test_openid_dict_to_scheme_credential_no_openid_url():
config_dict = {
"authorization_endpoint": "auth_url",
"token_endpoint": "token_url",
}
credential_dict = {
"client_id": "client_id",
"client_secret": "client_secret",
"redirect_uri": "redirect_uri",
}
scopes = ["scope1", "scope2"]
scheme, credential = openid_dict_to_scheme_credential(
config_dict, scopes, credential_dict
)
assert scheme.openIdConnectUrl == ""
def test_openid_dict_to_scheme_credential_google_oauth_credential():
config_dict = {
"authorization_endpoint": "auth_url",
"token_endpoint": "token_url",
"openIdConnectUrl": "openid_url",
}
credential_dict = {
"web": {
"client_id": "client_id",
"client_secret": "client_secret",
"redirect_uri": "redirect_uri",
}
}
scopes = ["scope1", "scope2"]
scheme, credential = openid_dict_to_scheme_credential(
config_dict, scopes, credential_dict
)
assert isinstance(scheme, OpenIdConnectWithConfig)
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
assert credential.oauth2.client_id == "client_id"
assert credential.oauth2.client_secret == "client_secret"
assert credential.oauth2.redirect_uri == "redirect_uri"
def test_openid_dict_to_scheme_credential_invalid_config():
config_dict = {
"invalid_field": "value",
}
credential_dict = {
"client_id": "client_id",
"client_secret": "client_secret",
}
scopes = ["scope1", "scope2"]
with pytest.raises(ValueError, match="Invalid OpenID Connect configuration"):
openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
def test_openid_dict_to_scheme_credential_missing_credential_fields():
config_dict = {
"authorization_endpoint": "auth_url",
"token_endpoint": "token_url",
}
credential_dict = {
"client_id": "client_id",
}
scopes = ["scope1", "scope2"]
with pytest.raises(
ValueError,
match="Missing required fields in credential_dict: client_secret",
):
openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
@patch("requests.get")
def test_openid_url_to_scheme_credential(mock_get):
mock_response = {
"authorization_endpoint": "auth_url",
"token_endpoint": "token_url",
"userinfo_endpoint": "userinfo_url",
}
mock_get.return_value.json.return_value = mock_response
mock_get.return_value.raise_for_status.return_value = None
credential_dict = {
"client_id": "client_id",
"client_secret": "client_secret",
"redirect_uri": "redirect_uri",
}
scopes = ["scope1", "scope2"]
scheme, credential = openid_url_to_scheme_credential(
"openid_url", scopes, credential_dict
)
assert isinstance(scheme, OpenIdConnectWithConfig)
assert scheme.authorization_endpoint == "auth_url"
assert scheme.token_endpoint == "token_url"
assert scheme.scopes == scopes
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
assert credential.oauth2.client_id == "client_id"
assert credential.oauth2.client_secret == "client_secret"
assert credential.oauth2.redirect_uri == "redirect_uri"
mock_get.assert_called_once_with("openid_url", timeout=10)
@patch("requests.get")
def test_openid_url_to_scheme_credential_no_openid_url(mock_get):
mock_response = {
"authorization_endpoint": "auth_url",
"token_endpoint": "token_url",
"userinfo_endpoint": "userinfo_url",
}
mock_get.return_value.json.return_value = mock_response
mock_get.return_value.raise_for_status.return_value = None
credential_dict = {
"client_id": "client_id",
"client_secret": "client_secret",
"redirect_uri": "redirect_uri",
}
scopes = ["scope1", "scope2"]
scheme, credential = openid_url_to_scheme_credential(
"openid_url", scopes, credential_dict
)
assert scheme.openIdConnectUrl == "openid_url"
@patch("requests.get")
def test_openid_url_to_scheme_credential_request_exception(mock_get):
mock_get.side_effect = requests.exceptions.RequestException("Test Error")
credential_dict = {"client_id": "client_id", "client_secret": "client_secret"}
with pytest.raises(
ValueError, match="Failed to fetch OpenID configuration from openid_url"
):
openid_url_to_scheme_credential("openid_url", [], credential_dict)
@patch("requests.get")
def test_openid_url_to_scheme_credential_invalid_json(mock_get):
mock_get.return_value.json.side_effect = ValueError("Invalid JSON")
mock_get.return_value.raise_for_status.return_value = None
credential_dict = {"client_id": "client_id", "client_secret": "client_secret"}
with pytest.raises(
ValueError,
match=(
"Invalid JSON response from OpenID configuration endpoint openid_url"
),
):
openid_url_to_scheme_credential("openid_url", [], credential_dict)
def test_credential_to_param_api_key_header():
auth_scheme = APIKey(
**{"type": "apiKey", "in": "header", "name": "X-API-Key"}
)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
)
param, kwargs = credential_to_param(auth_scheme, auth_credential)
assert param.original_name == "X-API-Key"
assert param.param_location == "header"
assert kwargs == {INTERNAL_AUTH_PREFIX + "X-API-Key": "test_key"}
def test_credential_to_param_api_key_query():
auth_scheme = APIKey(**{"type": "apiKey", "in": "query", "name": "api_key"})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
)
param, kwargs = credential_to_param(auth_scheme, auth_credential)
assert param.original_name == "api_key"
assert param.param_location == "query"
assert kwargs == {INTERNAL_AUTH_PREFIX + "api_key": "test_key"}
def test_credential_to_param_api_key_cookie():
auth_scheme = APIKey(
**{"type": "apiKey", "in": "cookie", "name": "session_id"}
)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
)
param, kwargs = credential_to_param(auth_scheme, auth_credential)
assert param.original_name == "session_id"
assert param.param_location == "cookie"
assert kwargs == {INTERNAL_AUTH_PREFIX + "session_id": "test_key"}
def test_credential_to_param_http_bearer():
auth_scheme = HTTPBearer(bearerFormat="JWT")
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="test_token")
),
)
param, kwargs = credential_to_param(auth_scheme, auth_credential)
assert param.original_name == "Authorization"
assert param.param_location == "header"
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
def test_credential_to_param_http_basic_not_supported():
auth_scheme = HTTPBase(scheme="basic")
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="basic",
credentials=HttpCredentials(username="user", password="password"),
),
)
with pytest.raises(
NotImplementedError, match="Basic Authentication is not supported."
):
credential_to_param(auth_scheme, auth_credential)
def test_credential_to_param_http_invalid_credentials_no_http():
auth_scheme = HTTPBase(scheme="basic")
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.HTTP)
with pytest.raises(ValueError, match="Invalid HTTP auth credentials"):
credential_to_param(auth_scheme, auth_credential)
def test_credential_to_param_oauth2():
auth_scheme = OAuth2(flows={})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="test_token")
),
)
param, kwargs = credential_to_param(auth_scheme, auth_credential)
assert param.original_name == "Authorization"
assert param.param_location == "header"
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
def test_credential_to_param_openid_connect():
auth_scheme = OpenIdConnect(openIdConnectUrl="openid_url")
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="test_token")
),
)
param, kwargs = credential_to_param(auth_scheme, auth_credential)
assert param.original_name == "Authorization"
assert param.param_location == "header"
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
def test_credential_to_param_openid_no_credential():
auth_scheme = OpenIdConnect(openIdConnectUrl="openid_url")
param, kwargs = credential_to_param(auth_scheme, None)
assert param == None
assert kwargs == None
def test_credential_to_param_oauth2_no_credential():
auth_scheme = OAuth2(flows={})
param, kwargs = credential_to_param(auth_scheme, None)
assert param == None
assert kwargs == None
def test_dict_to_auth_scheme_api_key():
data = {"type": "apiKey", "in": "header", "name": "X-API-Key"}
scheme = dict_to_auth_scheme(data)
assert isinstance(scheme, APIKey)
assert scheme.type_ == AuthSchemeType.apiKey
assert scheme.in_ == APIKeyIn.header
assert scheme.name == "X-API-Key"
def test_dict_to_auth_scheme_http_bearer():
data = {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
scheme = dict_to_auth_scheme(data)
assert isinstance(scheme, HTTPBearer)
assert scheme.scheme == "bearer"
assert scheme.bearerFormat == "JWT"
def test_dict_to_auth_scheme_http_base():
data = {"type": "http", "scheme": "basic"}
scheme = dict_to_auth_scheme(data)
assert isinstance(scheme, HTTPBase)
assert scheme.scheme == "basic"
def test_dict_to_auth_scheme_oauth2():
data = {
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "https://example.com/auth",
"tokenUrl": "https://example.com/token",
}
},
}
scheme = dict_to_auth_scheme(data)
assert isinstance(scheme, OAuth2)
assert hasattr(scheme.flows, "authorizationCode")
def test_dict_to_auth_scheme_openid_connect():
data = {
"type": "openIdConnect",
"openIdConnectUrl": (
"https://example.com/.well-known/openid-configuration"
),
}
scheme = dict_to_auth_scheme(data)
assert isinstance(scheme, OpenIdConnect)
assert (
scheme.openIdConnectUrl
== "https://example.com/.well-known/openid-configuration"
)
def test_dict_to_auth_scheme_missing_type():
data = {"in": "header", "name": "X-API-Key"}
with pytest.raises(
ValueError, match="Missing 'type' field in security scheme dictionary."
):
dict_to_auth_scheme(data)
def test_dict_to_auth_scheme_invalid_type():
data = {"type": "invalid", "in": "header", "name": "X-API-Key"}
with pytest.raises(ValueError, match="Invalid security scheme type: invalid"):
dict_to_auth_scheme(data)
def test_dict_to_auth_scheme_invalid_data():
data = {"type": "apiKey", "in": "header"} # Missing 'name'
with pytest.raises(ValueError, match="Invalid security scheme data"):
dict_to_auth_scheme(data)
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -0,0 +1,436 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import List
from fastapi.openapi.models import Response, Schema
from google.adk.tools.openapi_tool.common.common import ApiParameter
from google.adk.tools.openapi_tool.common.common import PydocHelper
from google.adk.tools.openapi_tool.common.common import rename_python_keywords
from google.adk.tools.openapi_tool.common.common import to_snake_case
from google.adk.tools.openapi_tool.common.common import TypeHintHelper
import pytest
def dict_to_responses(input: Dict[str, Any]) -> Dict[str, Response]:
return {k: Response.model_validate(input[k]) for k in input}
class TestToSnakeCase:
@pytest.mark.parametrize(
'input_str, expected_output',
[
('lowerCamelCase', 'lower_camel_case'),
('UpperCamelCase', 'upper_camel_case'),
('space separated', 'space_separated'),
('REST API', 'rest_api'),
('Mixed_CASE with_Spaces', 'mixed_case_with_spaces'),
('__init__', 'init'),
('APIKey', 'api_key'),
('SomeLongURL', 'some_long_url'),
('CONSTANT_CASE', 'constant_case'),
('already_snake_case', 'already_snake_case'),
('single', 'single'),
('', ''),
(' spaced ', 'spaced'),
('with123numbers', 'with123numbers'),
('With_Mixed_123_and_SPACES', 'with_mixed_123_and_spaces'),
('HTMLParser', 'html_parser'),
('HTTPResponseCode', 'http_response_code'),
('a_b_c', 'a_b_c'),
('A_B_C', 'a_b_c'),
('fromAtoB', 'from_ato_b'),
('XMLHTTPRequest', 'xmlhttp_request'),
('_leading', 'leading'),
('trailing_', 'trailing'),
(' leading_and_trailing_ ', 'leading_and_trailing'),
('Multiple___Underscores', 'multiple_underscores'),
(' spaces_and___underscores ', 'spaces_and_underscores'),
(' _mixed_Case ', 'mixed_case'),
('123Start', '123_start'),
('End123', 'end123'),
('Mid123dle', 'mid123dle'),
],
)
def test_to_snake_case(self, input_str, expected_output):
assert to_snake_case(input_str) == expected_output
class TestRenamePythonKeywords:
@pytest.mark.parametrize(
'input_str, expected_output',
[
('in', 'param_in'),
('for', 'param_for'),
('class', 'param_class'),
('normal', 'normal'),
('param_if', 'param_if'),
('', ''),
],
)
def test_rename_python_keywords(self, input_str, expected_output):
assert rename_python_keywords(input_str) == expected_output
class TestApiParameter:
def test_api_parameter_initialization(self):
schema = Schema(type='string', description='A string parameter')
param = ApiParameter(
original_name='testParam',
description='A string description',
param_location='query',
param_schema=schema,
)
assert param.original_name == 'testParam'
assert param.param_location == 'query'
assert param.param_schema.type == 'string'
assert param.param_schema.description == 'A string parameter'
assert param.py_name == 'test_param'
assert param.type_hint == 'str'
assert param.type_value == str
assert param.description == 'A string description'
def test_api_parameter_keyword_rename(self):
schema = Schema(type='string')
param = ApiParameter(
original_name='in',
param_location='query',
param_schema=schema,
)
assert param.py_name == 'param_in'
def test_api_parameter_custom_py_name(self):
schema = Schema(type='integer')
param = ApiParameter(
original_name='testParam',
param_location='query',
param_schema=schema,
py_name='custom_name',
)
assert param.py_name == 'custom_name'
def test_api_parameter_str_representation(self):
schema = Schema(type='number')
param = ApiParameter(
original_name='testParam',
param_location='query',
param_schema=schema,
)
assert str(param) == 'test_param: float'
def test_api_parameter_to_arg_string(self):
schema = Schema(type='boolean')
param = ApiParameter(
original_name='testParam',
param_location='query',
param_schema=schema,
)
assert param.to_arg_string() == 'test_param=test_param'
def test_api_parameter_to_dict_property(self):
schema = Schema(type='string')
param = ApiParameter(
original_name='testParam',
param_location='path',
param_schema=schema,
)
assert param.to_dict_property() == '"test_param": test_param'
def test_api_parameter_model_serializer(self):
schema = Schema(type='string', description='test description')
param = ApiParameter(
original_name='TestParam',
param_location='path',
param_schema=schema,
py_name='test_param_custom',
description='test description',
)
serialized_param = param.model_dump(mode='json', exclude_none=True)
assert serialized_param == {
'original_name': 'TestParam',
'param_location': 'path',
'param_schema': {'type': 'string', 'description': 'test description'},
'description': 'test description',
'py_name': 'test_param_custom',
}
@pytest.mark.parametrize(
'schema, expected_type_value, expected_type_hint',
[
({'type': 'integer'}, int, 'int'),
({'type': 'number'}, float, 'float'),
({'type': 'boolean'}, bool, 'bool'),
({'type': 'string'}, str, 'str'),
(
{'type': 'string', 'format': 'date'},
str,
'str',
),
(
{'type': 'string', 'format': 'date-time'},
str,
'str',
),
(
{'type': 'array', 'items': {'type': 'integer'}},
List[int],
'List[int]',
),
(
{'type': 'array', 'items': {'type': 'string'}},
List[str],
'List[str]',
),
(
{
'type': 'array',
'items': {'type': 'object'},
},
List[Dict[str, Any]],
'List[Dict[str, Any]]',
),
({'type': 'object'}, Dict[str, Any], 'Dict[str, Any]'),
({'type': 'unknown'}, Any, 'Any'),
({}, Any, 'Any'),
],
)
def test_api_parameter_type_hint_helper(
self, schema, expected_type_value, expected_type_hint
):
param = ApiParameter(
original_name='test', param_location='query', param_schema=schema
)
assert param.type_value == expected_type_value
assert param.type_hint == expected_type_hint
assert (
TypeHintHelper.get_type_hint(param.param_schema) == expected_type_hint
)
assert (
TypeHintHelper.get_type_value(param.param_schema) == expected_type_value
)
def test_api_parameter_description(self):
schema = Schema(type='string')
param = ApiParameter(
original_name='param1',
param_location='query',
param_schema=schema,
description='The description',
)
assert param.description == 'The description'
def test_api_parameter_description_use_schema_fallback(self):
schema = Schema(type='string', description='The description')
param = ApiParameter(
original_name='param1',
param_location='query',
param_schema=schema,
)
assert param.description == 'The description'
class TestTypeHintHelper:
@pytest.mark.parametrize(
'schema, expected_type_value, expected_type_hint',
[
({'type': 'integer'}, int, 'int'),
({'type': 'number'}, float, 'float'),
({'type': 'string'}, str, 'str'),
(
{
'type': 'array',
'items': {'type': 'string'},
},
List[str],
'List[str]',
),
],
)
def test_get_type_value_and_hint(
self, schema, expected_type_value, expected_type_hint
):
param = ApiParameter(
original_name='test_param',
param_location='query',
param_schema=schema,
description='Test parameter',
)
assert (
TypeHintHelper.get_type_value(param.param_schema) == expected_type_value
)
assert (
TypeHintHelper.get_type_hint(param.param_schema) == expected_type_hint
)
class TestPydocHelper:
def test_generate_param_doc_simple(self):
schema = Schema(type='string')
param = ApiParameter(
original_name='test_param',
param_location='query',
param_schema=schema,
description='Test description',
)
expected_doc = 'test_param (str): Test description'
assert PydocHelper.generate_param_doc(param) == expected_doc
def test_generate_param_doc_no_description(self):
schema = Schema(type='integer')
param = ApiParameter(
original_name='test_param',
param_location='query',
param_schema=schema,
)
expected_doc = 'test_param (int): '
assert PydocHelper.generate_param_doc(param) == expected_doc
def test_generate_param_doc_object(self):
schema = Schema(
type='object',
properties={
'prop1': {'type': 'string', 'description': 'Prop1 desc'},
'prop2': {'type': 'integer'},
},
)
param = ApiParameter(
original_name='test_param',
param_location='query',
param_schema=schema,
description='Test object parameter',
)
expected_doc = (
'test_param (Dict[str, Any]): Test object parameter Object'
' properties:\n prop1 (str): Prop1 desc\n prop2'
' (int): \n'
)
assert PydocHelper.generate_param_doc(param) == expected_doc
def test_generate_param_doc_object_no_properties(self):
schema = Schema(type='object', description='A test schema')
param = ApiParameter(
original_name='test_param',
param_location='query',
param_schema=schema,
description='The description.',
)
expected_doc = 'test_param (Dict[str, Any]): The description.'
assert PydocHelper.generate_param_doc(param) == expected_doc
def test_generate_return_doc_simple(self):
responses = {
'200': {
'description': 'Successful response',
'content': {'application/json': {'schema': {'type': 'string'}}},
}
}
expected_doc = 'Returns (str): Successful response'
assert (
PydocHelper.generate_return_doc(dict_to_responses(responses))
== expected_doc
)
def test_generate_return_doc_no_content(self):
responses = {'204': {'description': 'No content'}}
assert not PydocHelper.generate_return_doc(dict_to_responses(responses))
def test_generate_return_doc_object(self):
responses = {
'200': {
'description': 'Successful object response',
'content': {
'application/json': {
'schema': {
'type': 'object',
'properties': {
'prop1': {
'type': 'string',
'description': 'Prop1 desc',
},
'prop2': {'type': 'integer'},
},
}
}
},
}
}
return_doc = PydocHelper.generate_return_doc(dict_to_responses(responses))
assert 'Returns (Dict[str, Any]): Successful object response' in return_doc
assert 'prop1 (str): Prop1 desc' in return_doc
assert 'prop2 (int):' in return_doc
def test_generate_return_doc_multiple_success(self):
responses = {
'200': {
'description': 'Successful response',
'content': {'application/json': {'schema': {'type': 'string'}}},
},
'400': {'description': 'Bad request'},
}
expected_doc = 'Returns (str): Successful response'
assert (
PydocHelper.generate_return_doc(dict_to_responses(responses))
== expected_doc
)
def test_generate_return_doc_2xx_smallest_status_code_response(self):
responses = {
'201': {
'description': '201 response',
'content': {'application/json': {'schema': {'type': 'integer'}}},
},
'200': {
'description': '200 response',
'content': {'application/json': {'schema': {'type': 'string'}}},
},
'400': {'description': 'Bad request'},
}
expected_doc = 'Returns (str): 200 response'
assert (
PydocHelper.generate_return_doc(dict_to_responses(responses))
== expected_doc
)
def test_generate_return_doc_contentful_response(self):
responses = {
'200': {'description': 'No content response'},
'201': {
'description': '201 response',
'content': {'application/json': {'schema': {'type': 'string'}}},
},
'400': {'description': 'Bad request'},
}
expected_doc = 'Returns (str): 201 response'
assert (
PydocHelper.generate_return_doc(dict_to_responses(responses))
== expected_doc
)
if __name__ == '__main__':
pytest.main([__file__])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,628 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
import pytest
def create_minimal_openapi_spec() -> Dict[str, Any]:
"""Creates a minimal valid OpenAPI spec."""
return {
"openapi": "3.1.0",
"info": {"title": "Minimal API", "version": "1.0.0"},
"paths": {
"/test": {
"get": {
"summary": "Test GET endpoint",
"operationId": "testGet",
"responses": {
"200": {
"description": "Successful response",
"content": {
"application/json": {"schema": {"type": "string"}}
},
}
},
}
}
},
}
@pytest.fixture
def openapi_spec_generator():
"""Fixture for creating an OperationGenerator instance."""
return OpenApiSpecParser()
def test_parse_minimal_spec(openapi_spec_generator):
"""Test parsing a minimal OpenAPI specification."""
openapi_spec = create_minimal_openapi_spec()
parsed_operations = openapi_spec_generator.parse(openapi_spec)
op = parsed_operations[0]
assert len(parsed_operations) == 1
assert op.name == "test_get"
assert op.endpoint.path == "/test"
assert op.endpoint.method == "get"
assert op.return_value.type_value == str
def test_parse_spec_with_no_operation_id(openapi_spec_generator):
"""Test parsing a spec where operationId is missing (auto-generation)."""
openapi_spec = create_minimal_openapi_spec()
del openapi_spec["paths"]["/test"]["get"]["operationId"] # Remove operationId
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
# Check if operationId is auto generated based on path and method.
assert parsed_operations[0].name == "test_get"
def test_parse_spec_with_multiple_methods(openapi_spec_generator):
"""Test parsing a spec with multiple HTTP methods for the same path."""
openapi_spec = create_minimal_openapi_spec()
openapi_spec["paths"]["/test"]["post"] = {
"summary": "Test POST endpoint",
"operationId": "testPost",
"responses": {"200": {"description": "Successful response"}},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
operation_names = {op.name for op in parsed_operations}
assert len(parsed_operations) == 2
assert "test_get" in operation_names
assert "test_post" in operation_names
def test_parse_spec_with_parameters(openapi_spec_generator):
openapi_spec = create_minimal_openapi_spec()
openapi_spec["paths"]["/test"]["get"]["parameters"] = [
{"name": "param1", "in": "query", "schema": {"type": "string"}},
{"name": "param2", "in": "header", "schema": {"type": "integer"}},
]
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations[0].parameters) == 2
assert parsed_operations[0].parameters[0].original_name == "param1"
assert parsed_operations[0].parameters[0].param_location == "query"
assert parsed_operations[0].parameters[1].original_name == "param2"
assert parsed_operations[0].parameters[1].param_location == "header"
def test_parse_spec_with_request_body(openapi_spec_generator):
openapi_spec = create_minimal_openapi_spec()
openapi_spec["paths"]["/test"]["post"] = {
"summary": "Endpoint with request body",
"operationId": "testPostWithBody",
"requestBody": {
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {"name": {"type": "string"}},
}
}
}
},
"responses": {"200": {"description": "OK"}},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
post_operations = [
op for op in parsed_operations if op.endpoint.method == "post"
]
op = post_operations[0]
assert len(post_operations) == 1
assert op.name == "test_post_with_body"
assert len(op.parameters) == 1
assert op.parameters[0].original_name == "name"
assert op.parameters[0].type_value == str
def test_parse_spec_with_reference(openapi_spec_generator):
"""Test parsing a specification with $ref."""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "API with Refs", "version": "1.0.0"},
"paths": {
"/test_ref": {
"get": {
"summary": "Endpoint with ref",
"operationId": "testGetRef",
"responses": {
"200": {
"description": "Success",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/MySchema"
}
}
},
}
},
}
}
},
"components": {
"schemas": {
"MySchema": {
"type": "object",
"properties": {"name": {"type": "string"}},
}
}
},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
op = parsed_operations[0]
assert len(parsed_operations) == 1
assert op.return_value.type_value.__origin__ is dict
def test_parse_spec_with_circular_reference(openapi_spec_generator):
"""Test correct handling of circular $ref (important!)."""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "Circular Ref API", "version": "1.0.0"},
"paths": {
"/circular": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/A"}
}
},
}
}
}
}
},
"components": {
"schemas": {
"A": {
"type": "object",
"properties": {"b": {"$ref": "#/components/schemas/B"}},
},
"B": {
"type": "object",
"properties": {"a": {"$ref": "#/components/schemas/A"}},
},
}
},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
op = parsed_operations[0]
assert op.return_value.type_value.__origin__ is dict
assert op.return_value.type_hint == "Dict[str, Any]"
def test_parse_no_paths(openapi_spec_generator):
"""Test with a spec that has no paths defined."""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "No Paths API", "version": "1.0.0"},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 0 # Should be empty
def test_parse_empty_path_item(openapi_spec_generator):
"""Test a path item that is present but empty."""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "Empty Path Item API", "version": "1.0.0"},
"paths": {"/empty": None},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 0
def test_parse_spec_with_global_auth_scheme(openapi_spec_generator):
"""Test parsing with a global security scheme."""
openapi_spec = create_minimal_openapi_spec()
openapi_spec["security"] = [{"api_key": []}]
openapi_spec["components"] = {
"securitySchemes": {
"api_key": {"type": "apiKey", "in": "header", "name": "X-API-Key"}
}
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
op = parsed_operations[0]
assert len(parsed_operations) == 1
assert op.auth_scheme is not None
assert op.auth_scheme.type_.value == "apiKey"
def test_parse_spec_with_local_auth_scheme(openapi_spec_generator):
"""Test parsing with a local (operation-level) security scheme."""
openapi_spec = create_minimal_openapi_spec()
openapi_spec["paths"]["/test"]["get"]["security"] = [{"local_auth": []}]
openapi_spec["components"] = {
"securitySchemes": {"local_auth": {"type": "http", "scheme": "bearer"}}
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
op = parsed_operations[0]
assert op.auth_scheme is not None
assert op.auth_scheme.type_.value == "http"
assert op.auth_scheme.scheme == "bearer"
def test_parse_spec_with_servers(openapi_spec_generator):
"""Test parsing with server URLs."""
openapi_spec = create_minimal_openapi_spec()
openapi_spec["servers"] = [
{"url": "https://api.example.com"},
{"url": "http://localhost:8000"},
]
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
assert parsed_operations[0].endpoint.base_url == "https://api.example.com"
def test_parse_spec_with_no_servers(openapi_spec_generator):
"""Test with no servers defined (should default to empty string)."""
openapi_spec = create_minimal_openapi_spec()
if "servers" in openapi_spec:
del openapi_spec["servers"]
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
assert parsed_operations[0].endpoint.base_url == ""
def test_parse_spec_with_description(openapi_spec_generator):
openapi_spec = create_minimal_openapi_spec()
expected_description = "This is a test description."
openapi_spec["paths"]["/test"]["get"]["description"] = expected_description
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
assert parsed_operations[0].description == expected_description
def test_parse_spec_with_empty_description(openapi_spec_generator):
openapi_spec = create_minimal_openapi_spec()
openapi_spec["paths"]["/test"]["get"]["description"] = ""
openapi_spec["paths"]["/test"]["get"]["summary"] = ""
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
assert parsed_operations[0].description == ""
def test_parse_spec_with_no_description(openapi_spec_generator):
openapi_spec = create_minimal_openapi_spec()
# delete description
if "description" in openapi_spec["paths"]["/test"]["get"]:
del openapi_spec["paths"]["/test"]["get"]["description"]
if "summary" in openapi_spec["paths"]["/test"]["get"]:
del openapi_spec["paths"]["/test"]["get"]["summary"]
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
assert (
parsed_operations[0].description == ""
) # it should be initialized with empty string
def test_parse_invalid_openapi_spec_type(openapi_spec_generator):
"""Test that passing a non-dict object to parse raises TypeError"""
with pytest.raises(AttributeError):
openapi_spec_generator.parse(123) # type: ignore
with pytest.raises(AttributeError):
openapi_spec_generator.parse("openapi_spec") # type: ignore
with pytest.raises(AttributeError):
openapi_spec_generator.parse([]) # type: ignore
def test_parse_external_ref_raises_error(openapi_spec_generator):
"""Check that external references (not starting with #) raise ValueError."""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "External Ref API", "version": "1.0.0"},
"paths": {
"/external": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": (
"external_file.json#/components/schemas/ExternalSchema"
)
}
}
},
}
}
}
}
},
}
with pytest.raises(ValueError):
openapi_spec_generator.parse(openapi_spec)
def test_parse_spec_with_multiple_paths_deep_refs(openapi_spec_generator):
"""Test specs with multiple paths, request/response bodies using deep refs."""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "Multiple Paths Deep Refs API", "version": "1.0.0"},
"paths": {
"/path1": {
"post": {
"operationId": "postPath1",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Request1"
}
}
}
},
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Response1"
}
}
},
}
},
}
},
"/path2": {
"put": {
"operationId": "putPath2",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Request2"
}
}
}
},
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Response2"
}
}
},
}
},
},
"get": {
"operationId": "getPath2",
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Response2"
}
}
},
}
},
},
},
},
"components": {
"schemas": {
"Request1": {
"type": "object",
"properties": {
"req1_prop1": {"$ref": "#/components/schemas/Level1_1"}
},
},
"Response1": {
"type": "object",
"properties": {
"res1_prop1": {"$ref": "#/components/schemas/Level1_2"}
},
},
"Request2": {
"type": "object",
"properties": {
"req2_prop1": {"$ref": "#/components/schemas/Level1_1"}
},
},
"Response2": {
"type": "object",
"properties": {
"res2_prop1": {"$ref": "#/components/schemas/Level1_2"}
},
},
"Level1_1": {
"type": "object",
"properties": {
"level1_1_prop1": {
"$ref": "#/components/schemas/Level2_1"
}
},
},
"Level1_2": {
"type": "object",
"properties": {
"level1_2_prop1": {
"$ref": "#/components/schemas/Level2_2"
}
},
},
"Level2_1": {
"type": "object",
"properties": {
"level2_1_prop1": {"$ref": "#/components/schemas/Level3"}
},
},
"Level2_2": {
"type": "object",
"properties": {"level2_2_prop1": {"type": "string"}},
},
"Level3": {"type": "integer"},
}
},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 3
# Verify Path 1
path1_ops = [op for op in parsed_operations if op.endpoint.path == "/path1"]
assert len(path1_ops) == 1
path1_op = path1_ops[0]
assert path1_op.name == "post_path1"
assert len(path1_op.parameters) == 1
assert path1_op.parameters[0].original_name == "req1_prop1"
assert (
path1_op.parameters[0]
.param_schema.properties["level1_1_prop1"]
.properties["level2_1_prop1"]
.type
== "integer"
)
assert (
path1_op.return_value.param_schema.properties["res1_prop1"]
.properties["level1_2_prop1"]
.properties["level2_2_prop1"]
.type
== "string"
)
# Verify Path 2
path2_ops = [
op
for op in parsed_operations
if op.endpoint.path == "/path2" and op.name == "put_path2"
]
path2_op = path2_ops[0]
assert path2_op is not None
assert len(path2_op.parameters) == 1
assert path2_op.parameters[0].original_name == "req2_prop1"
assert (
path2_op.parameters[0]
.param_schema.properties["level1_1_prop1"]
.properties["level2_1_prop1"]
.type
== "integer"
)
assert (
path2_op.return_value.param_schema.properties["res2_prop1"]
.properties["level1_2_prop1"]
.properties["level2_2_prop1"]
.type
== "string"
)
def test_parse_spec_with_duplicate_parameter_names(openapi_spec_generator):
"""Test handling of duplicate parameter names (one in query, one in body).
The expected behavior is that both parameters should be captured but with
different suffix, and
their `original_name` attributes should reflect their origin (query or body).
"""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "Duplicate Parameter Names API", "version": "1.0.0"},
"paths": {
"/duplicate": {
"post": {
"operationId": "createWithDuplicate",
"parameters": [{
"name": "name",
"in": "query",
"schema": {"type": "string"},
}],
"requestBody": {
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {"name": {"type": "integer"}},
}
}
}
},
"responses": {"200": {"description": "OK"}},
}
}
},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
op = parsed_operations[0]
assert op.name == "create_with_duplicate"
assert len(op.parameters) == 2
query_param = None
body_param = None
for param in op.parameters:
if param.param_location == "query" and param.original_name == "name":
query_param = param
elif param.param_location == "body" and param.original_name == "name":
body_param = param
assert query_param is not None
assert query_param.original_name == "name"
assert query_param.py_name == "name"
assert body_param is not None
assert body_param.original_name == "name"
assert body_param.py_name == "name_0"

View File

@@ -0,0 +1,139 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from fastapi.openapi.models import MediaType
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import ParameterInType
from fastapi.openapi.models import SecuritySchemeType
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
import pytest
import yaml
def load_spec(file_path: str) -> Dict:
"""Loads the OpenAPI specification from a YAML file."""
with open(file_path, "r", encoding="utf-8") as f:
return yaml.safe_load(f)
@pytest.fixture
def openapi_spec() -> Dict:
"""Fixture to load the OpenAPI specification."""
current_dir = os.path.dirname(os.path.abspath(__file__))
# Join the directory path with the filename
yaml_path = os.path.join(current_dir, "test.yaml")
return load_spec(yaml_path)
def test_openapi_toolset_initialization_from_dict(openapi_spec: Dict):
"""Test initialization of OpenAPIToolset with a dictionary."""
toolset = OpenAPIToolset(spec_dict=openapi_spec)
assert isinstance(toolset.tools, list)
assert len(toolset.tools) == 5
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
def test_openapi_toolset_initialization_from_yaml_string(openapi_spec: Dict):
"""Test initialization of OpenAPIToolset with a YAML string."""
spec_str = yaml.dump(openapi_spec)
toolset = OpenAPIToolset(spec_str=spec_str, spec_str_type="yaml")
assert isinstance(toolset.tools, list)
assert len(toolset.tools) == 5
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
def test_openapi_toolset_tool_existing(openapi_spec: Dict):
"""Test the tool() method for an existing tool."""
toolset = OpenAPIToolset(spec_dict=openapi_spec)
tool_name = "calendar_calendars_insert" # Example operationId from the spec
tool = toolset.get_tool(tool_name)
assert isinstance(tool, RestApiTool)
assert tool.name == tool_name
assert tool.description == "Creates a secondary calendar."
assert tool.endpoint.method == "post"
assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
assert tool.endpoint.path == "/calendars"
assert tool.is_long_running is False
assert tool.operation.operationId == "calendar.calendars.insert"
assert tool.operation.description == "Creates a secondary calendar."
assert isinstance(
tool.operation.requestBody.content["application/json"], MediaType
)
assert len(tool.operation.responses) == 1
response = tool.operation.responses["200"]
assert response.description == "Successful response"
assert isinstance(response.content["application/json"], MediaType)
assert isinstance(tool.auth_scheme, OAuth2)
tool_name = "calendar_calendars_get"
tool = toolset.get_tool(tool_name)
assert isinstance(tool, RestApiTool)
assert tool.name == tool_name
assert tool.description == "Returns metadata for a calendar."
assert tool.endpoint.method == "get"
assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
assert tool.endpoint.path == "/calendars/{calendarId}"
assert tool.is_long_running is False
assert tool.operation.operationId == "calendar.calendars.get"
assert tool.operation.description == "Returns metadata for a calendar."
assert len(tool.operation.parameters) == 1
assert tool.operation.parameters[0].name == "calendarId"
assert tool.operation.parameters[0].in_ == ParameterInType.path
assert tool.operation.parameters[0].required is True
assert tool.operation.parameters[0].schema_.type == "string"
assert (
tool.operation.parameters[0].description
== "Calendar identifier. To retrieve calendar IDs call the"
" calendarList.list method. If you want to access the primary calendar"
' of the currently logged in user, use the "primary" keyword.'
)
assert isinstance(tool.auth_scheme, OAuth2)
assert isinstance(toolset.get_tool("calendar_calendars_update"), RestApiTool)
assert isinstance(toolset.get_tool("calendar_calendars_delete"), RestApiTool)
assert isinstance(toolset.get_tool("calendar_calendars_patch"), RestApiTool)
def test_openapi_toolset_tool_non_existing(openapi_spec: Dict):
"""Test the tool() method for a non-existing tool."""
toolset = OpenAPIToolset(spec_dict=openapi_spec)
tool = toolset.get_tool("non_existent_tool")
assert tool is None
def test_openapi_toolset_configure_auth_on_init(openapi_spec: Dict):
"""Test configuring auth during initialization."""
auth_scheme = APIKey(**{
"in": APIKeyIn.header, # Use alias name in dict
"name": "api_key",
"type": SecuritySchemeType.http,
})
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
toolset = OpenAPIToolset(
spec_dict=openapi_spec,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
for tool in toolset.tools:
assert tool.auth_scheme == auth_scheme
assert tool.auth_credential == auth_credential

View File

@@ -0,0 +1,406 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi.openapi.models import MediaType
from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter
from fastapi.openapi.models import RequestBody
from fastapi.openapi.models import Response
from fastapi.openapi.models import Schema
from google.adk.tools.openapi_tool.common.common import ApiParameter
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
import pytest
@pytest.fixture
def sample_operation() -> Operation:
"""Fixture to provide a sample OpenAPI Operation object."""
return Operation(
operationId='test_operation',
summary='Test Summary',
description='Test Description',
parameters=[
Parameter(**{
'name': 'param1',
'in': 'query',
'schema': Schema(type='string'),
'description': 'Parameter 1',
}),
Parameter(**{
'name': 'param2',
'in': 'header',
'schema': Schema(type='string'),
'description': 'Parameter 2',
}),
],
requestBody=RequestBody(
content={
'application/json': MediaType(
schema=Schema(
type='object',
properties={
'prop1': Schema(
type='string', description='Property 1'
),
'prop2': Schema(
type='integer', description='Property 2'
),
},
)
)
},
description='Request body description',
),
responses={
'200': Response(
description='Success',
content={
'application/json': MediaType(schema=Schema(type='string'))
},
),
'400': Response(description='Client Error'),
},
security=[{'oauth2': ['resource: read', 'resource: write']}],
)
def test_operation_parser_initialization(sample_operation):
"""Test initialization of OperationParser."""
parser = OperationParser(sample_operation)
assert parser.operation == sample_operation
assert len(parser.params) == 4 # 2 params + 2 request body props
assert parser.return_value is not None
def test_process_operation_parameters(sample_operation):
"""Test _process_operation_parameters method."""
parser = OperationParser(sample_operation, should_parse=False)
parser._process_operation_parameters()
assert len(parser.params) == 2
assert parser.params[0].original_name == 'param1'
assert parser.params[0].param_location == 'query'
assert parser.params[1].original_name == 'param2'
assert parser.params[1].param_location == 'header'
def test_process_request_body(sample_operation):
"""Test _process_request_body method."""
parser = OperationParser(sample_operation, should_parse=False)
parser._process_request_body()
assert len(parser.params) == 2 # 2 properties in request body
assert parser.params[0].original_name == 'prop1'
assert parser.params[0].param_location == 'body'
assert parser.params[1].original_name == 'prop2'
assert parser.params[1].param_location == 'body'
def test_process_request_body_array():
"""Test _process_request_body method with array schema."""
operation = Operation(
requestBody=RequestBody(
content={
'application/json': MediaType(
schema=Schema(
type='array',
items=Schema(
type='object',
properties={
'item_prop1': Schema(
type='string', description='Item Property 1'
),
'item_prop2': Schema(
type='integer', description='Item Property 2'
),
},
),
)
)
}
)
)
parser = OperationParser(operation, should_parse=False)
parser._process_request_body()
assert len(parser.params) == 1
assert parser.params[0].original_name == 'array'
assert parser.params[0].param_location == 'body'
# Check that schema is correctly propagated and is a dictionary
assert parser.params[0].param_schema.type == 'array'
assert parser.params[0].param_schema.items.type == 'object'
assert 'item_prop1' in parser.params[0].param_schema.items.properties
assert 'item_prop2' in parser.params[0].param_schema.items.properties
assert (
parser.params[0].param_schema.items.properties['item_prop1'].description
== 'Item Property 1'
)
assert (
parser.params[0].param_schema.items.properties['item_prop2'].description
== 'Item Property 2'
)
def test_process_request_body_no_name():
"""Test _process_request_body with a schema that has no properties (unnamed)"""
operation = Operation(
requestBody=RequestBody(
content={'application/json': MediaType(schema=Schema(type='string'))}
)
)
parser = OperationParser(operation, should_parse=False)
parser._process_request_body()
assert len(parser.params) == 1
assert parser.params[0].original_name == '' # No name
assert parser.params[0].param_location == 'body'
def test_dedupe_param_names(sample_operation):
"""Test _dedupe_param_names method."""
parser = OperationParser(sample_operation, should_parse=False)
# Add duplicate named parameters.
parser.params = [
ApiParameter(original_name='test', param_location='', param_schema={}),
ApiParameter(original_name='test', param_location='', param_schema={}),
ApiParameter(original_name='test', param_location='', param_schema={}),
]
parser._dedupe_param_names()
assert parser.params[0].py_name == 'test'
assert parser.params[1].py_name == 'test_0'
assert parser.params[2].py_name == 'test_1'
def test_process_return_value(sample_operation):
"""Test _process_return_value method."""
parser = OperationParser(sample_operation, should_parse=False)
parser._process_return_value()
assert parser.return_value is not None
assert parser.return_value.type_hint == 'str'
def test_process_return_value_no_2xx(sample_operation):
"""Tests _process_return_value when no 2xx response exists."""
operation_no_2xx = Operation(
responses={'400': Response(description='Client Error')}
)
parser = OperationParser(operation_no_2xx, should_parse=False)
parser._process_return_value()
assert parser.return_value is not None
assert parser.return_value.type_hint == 'Any'
def test_process_return_value_multiple_2xx(sample_operation):
"""Tests _process_return_value when multiple 2xx responses exist."""
operation_multi_2xx = Operation(
responses={
'201': Response(
description='Success',
content={
'application/json': MediaType(schema=Schema(type='integer'))
},
),
'202': Response(
description='Success',
content={'text/plain': MediaType(schema=Schema(type='string'))},
),
'200': Response(
description='Success',
content={
'application/pdf': MediaType(schema=Schema(type='boolean'))
},
),
'400': Response(
description='Failure',
content={
'application/xml': MediaType(schema=Schema(type='object'))
},
),
}
)
parser = OperationParser(operation_multi_2xx, should_parse=False)
parser._process_return_value()
assert parser.return_value is not None
# Take the content type of the 200 response since it's the smallest response
# code
assert parser.return_value.param_schema.type == 'boolean'
def test_process_return_value_no_content(sample_operation):
"""Test when 2xx response has no content"""
operation_no_content = Operation(
responses={'200': Response(description='Success', content={})}
)
parser = OperationParser(operation_no_content, should_parse=False)
parser._process_return_value()
assert parser.return_value.type_hint == 'Any'
def test_process_return_value_no_schema(sample_operation):
"""Tests when the 2xx response's content has no schema."""
operation_no_schema = Operation(
responses={
'200': Response(
description='Success',
content={'application/json': MediaType(schema=None)},
)
}
)
parser = OperationParser(operation_no_schema, should_parse=False)
parser._process_return_value()
assert parser.return_value.type_hint == 'Any'
def test_get_function_name(sample_operation):
"""Test get_function_name method."""
parser = OperationParser(sample_operation)
assert parser.get_function_name() == 'test_operation'
def test_get_function_name_missing_id():
"""Tests get_function_name when operationId is missing"""
operation = Operation() # No ID
parser = OperationParser(operation)
with pytest.raises(ValueError, match='Operation ID is missing'):
parser.get_function_name()
def test_get_return_type_hint(sample_operation):
"""Test get_return_type_hint method."""
parser = OperationParser(sample_operation)
assert parser.get_return_type_hint() == 'str'
def test_get_return_type_value(sample_operation):
"""Test get_return_type_value method."""
parser = OperationParser(sample_operation)
assert parser.get_return_type_value() == str
def test_get_parameters(sample_operation):
"""Test get_parameters method."""
parser = OperationParser(sample_operation)
params = parser.get_parameters()
assert len(params) == 4 # Correct count after processing
assert all(isinstance(p, ApiParameter) for p in params)
def test_get_return_value(sample_operation):
"""Test get_return_value method."""
parser = OperationParser(sample_operation)
return_value = parser.get_return_value()
assert isinstance(return_value, ApiParameter)
def test_get_auth_scheme_name(sample_operation):
"""Test get_auth_scheme_name method."""
parser = OperationParser(sample_operation)
assert parser.get_auth_scheme_name() == 'oauth2'
def test_get_auth_scheme_name_no_security():
"""Test get_auth_scheme_name when no security is present."""
operation = Operation(responses={})
parser = OperationParser(operation)
assert parser.get_auth_scheme_name() == ''
def test_get_pydoc_string(sample_operation):
"""Test get_pydoc_string method."""
parser = OperationParser(sample_operation)
pydoc_string = parser.get_pydoc_string()
assert 'Test Summary' in pydoc_string
assert 'Args:' in pydoc_string
assert 'param1 (str): Parameter 1' in pydoc_string
assert 'prop1 (str): Property 1' in pydoc_string
assert 'Returns (str):' in pydoc_string
assert 'Success' in pydoc_string
def test_get_json_schema(sample_operation):
"""Test get_json_schema method."""
parser = OperationParser(sample_operation)
json_schema = parser.get_json_schema()
assert json_schema['title'] == 'test_operation_Arguments'
assert json_schema['type'] == 'object'
assert 'param1' in json_schema['properties']
assert 'prop1' in json_schema['properties']
assert 'param1' in json_schema['required']
assert 'prop1' in json_schema['required']
def test_get_signature_parameters(sample_operation):
"""Test get_signature_parameters method."""
parser = OperationParser(sample_operation)
signature_params = parser.get_signature_parameters()
assert len(signature_params) == 4
assert signature_params[0].name == 'param1'
assert signature_params[0].annotation == str
assert signature_params[2].name == 'prop1'
assert signature_params[2].annotation == str
def test_get_annotations(sample_operation):
"""Test get_annotations method."""
parser = OperationParser(sample_operation)
annotations = parser.get_annotations()
assert len(annotations) == 5 # 4 parameters + return
assert annotations['param1'] == str
assert annotations['prop1'] == str
assert annotations['return'] == str
def test_load():
"""Test the load classmethod."""
operation = Operation(operationId='my_op') # Minimal operation
params = [
ApiParameter(
original_name='p1',
param_location='',
param_schema={'type': 'integer'},
)
]
return_value = ApiParameter(
original_name='', param_location='', param_schema={'type': 'string'}
)
parser = OperationParser.load(operation, params, return_value)
assert isinstance(parser, OperationParser)
assert parser.operation == operation
assert parser.params == params
assert parser.return_value == return_value
assert (
parser.get_function_name() == 'my_op'
) # Check that the operation is loaded
def test_operation_parser_with_dict():
"""Test initialization of OperationParser with a dictionary."""
operation_dict = {
'operationId': 'test_dict_operation',
'parameters': [
{'name': 'dict_param', 'in': 'query', 'schema': {'type': 'string'}}
],
'responses': {
'200': {
'description': 'Dict Success',
'content': {'application/json': {'schema': {'type': 'string'}}},
}
},
}
parser = OperationParser(operation_dict)
assert parser.operation.operationId == 'test_dict_operation'
assert len(parser.params) == 1
assert parser.params[0].original_name == 'dict_param'
assert parser.return_value.type_hint == 'str'

View File

@@ -0,0 +1,966 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from unittest.mock import MagicMock
from unittest.mock import patch
from fastapi.openapi.models import MediaType
from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter as OpenAPIParameter
from fastapi.openapi.models import RequestBody
from fastapi.openapi.models import Schema as OpenAPISchema
from google.adk.sessions.state import State
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
from google.adk.tools.openapi_tool.common.common import ApiParameter
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import snake_to_lower_camel
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from google.adk.tools.tool_context import ToolContext
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
from google.genai.types import Type
import pytest
class TestRestApiTool:
@pytest.fixture
def mock_tool_context(self):
"""Fixture for a mock OperationParser."""
mock_context = MagicMock(spec=ToolContext)
mock_context.state = State({}, {})
mock_context.get_auth_response.return_value = {}
mock_context.request_credential.return_value = {}
return mock_context
@pytest.fixture
def mock_operation_parser(self):
"""Fixture for a mock OperationParser."""
mock_parser = MagicMock(spec=OperationParser)
mock_parser.get_function_name.return_value = "mock_function_name"
mock_parser.get_json_schema.return_value = {}
mock_parser.get_parameters.return_value = []
mock_parser.get_return_type_hint.return_value = "str"
mock_parser.get_pydoc_string.return_value = "Mock docstring"
mock_parser.get_signature_parameters.return_value = []
mock_parser.get_return_type_value.return_value = str
mock_parser.get_annotations.return_value = {}
return mock_parser
@pytest.fixture
def sample_endpiont(self):
return OperationEndpoint(
base_url="https://example.com", path="/test", method="GET"
)
@pytest.fixture
def sample_operation(self):
return Operation(
operationId="testOperation",
description="Test operation",
parameters=[],
requestBody=RequestBody(
content={
"application/json": MediaType(
schema=OpenAPISchema(
type="object",
properties={
"testBodyParam": OpenAPISchema(type="string")
},
)
)
}
),
)
@pytest.fixture
def sample_api_parameters(self):
return [
ApiParameter(
original_name="test_param",
py_name="test_param",
param_location="query",
param_schema=OpenAPISchema(type="string"),
is_required=True,
),
ApiParameter(
original_name="",
py_name="test_body_param",
param_location="body",
param_schema=OpenAPISchema(type="string"),
is_required=True,
),
]
@pytest.fixture
def sample_return_parameter(self):
return ApiParameter(
original_name="test_param",
py_name="test_param",
param_location="query",
param_schema=OpenAPISchema(type="string"),
is_required=True,
)
@pytest.fixture
def sample_auth_scheme(self):
scheme, _ = token_to_scheme_credential(
"apikey", "header", "", "sample_auth_credential_internal_test"
)
return scheme
@pytest.fixture
def sample_auth_credential(self):
_, credential = token_to_scheme_credential(
"apikey", "header", "", "sample_auth_credential_internal_test"
)
return credential
def test_init(
self,
sample_endpiont,
sample_operation,
sample_auth_scheme,
sample_auth_credential,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_scheme=sample_auth_scheme,
auth_credential=sample_auth_credential,
)
assert tool.name == "test_tool"
assert tool.description == "Test Tool"
assert tool.endpoint == sample_endpiont
assert tool.operation == sample_operation
assert tool.auth_credential == sample_auth_credential
assert tool.auth_scheme == sample_auth_scheme
assert tool.credential_exchanger is not None
def test_from_parsed_operation_str(
self,
sample_endpiont,
sample_api_parameters,
sample_return_parameter,
sample_operation,
):
parsed_operation_str = json.dumps({
"name": "test_operation",
"description": "Test Description",
"endpoint": sample_endpiont.model_dump(),
"operation": sample_operation.model_dump(),
"auth_scheme": None,
"auth_credential": None,
"parameters": [p.model_dump() for p in sample_api_parameters],
"return_value": sample_return_parameter.model_dump(),
})
tool = RestApiTool.from_parsed_operation_str(parsed_operation_str)
assert tool.name == "test_operation"
def test_get_declaration(
self, sample_endpiont, sample_operation, mock_operation_parser
):
tool = RestApiTool(
name="test_tool",
description="Test description",
endpoint=sample_endpiont,
operation=sample_operation,
should_parse_operation=False,
)
tool._operation_parser = mock_operation_parser
declaration = tool._get_declaration()
assert isinstance(declaration, FunctionDeclaration)
assert declaration.name == "test_tool"
assert declaration.description == "Test description"
assert isinstance(declaration.parameters, Schema)
@patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
)
def test_call_success(
self,
mock_request,
mock_tool_context,
sample_endpiont,
sample_operation,
sample_auth_scheme,
sample_auth_credential,
):
mock_response = MagicMock()
mock_response.json.return_value = {"result": "success"}
mock_request.return_value = mock_response
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_scheme=sample_auth_scheme,
auth_credential=sample_auth_credential,
)
# Call the method
result = tool.call(args={}, tool_context=mock_tool_context)
# Check the result
assert result == {"result": "success"}
@patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
)
def test_call_auth_pending(
self,
mock_request,
sample_endpiont,
sample_operation,
sample_auth_scheme,
sample_auth_credential,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_scheme=sample_auth_scheme,
auth_credential=sample_auth_credential,
)
with patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
) as mock_from_tool_context:
mock_tool_auth_handler_instance = MagicMock()
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
"pending"
)
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
response = tool.call(args={}, tool_context=None)
assert response == {
"pending": True,
"message": "Needs your authorization to access your data.",
}
def test_prepare_request_params_query_body(
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
):
# Create a mock Operation object
mock_operation = Operation(
operationId="test_op",
parameters=[
OpenAPIParameter(**{
"name": "testQueryParam",
"in": "query",
"schema": OpenAPISchema(type="string"),
})
],
requestBody=RequestBody(
content={
"application/json": MediaType(
schema=OpenAPISchema(
type="object",
properties={
"param1": OpenAPISchema(type="string"),
"param2": OpenAPISchema(type="integer"),
},
)
)
}
),
)
tool = RestApiTool(
name="test_tool",
description="test",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="param1",
py_name="param1",
param_location="body",
param_schema=OpenAPISchema(type="string"),
),
ApiParameter(
original_name="param2",
py_name="param2",
param_location="body",
param_schema=OpenAPISchema(type="integer"),
),
ApiParameter(
original_name="testQueryParam",
py_name="test_query_param",
param_location="query",
param_schema=OpenAPISchema(type="string"),
),
]
kwargs = {
"param1": "value1",
"param2": 123,
"test_query_param": "query_value",
}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["method"] == "get"
assert request_params["url"] == "https://example.com/test"
assert request_params["json"] == {"param1": "value1", "param2": 123}
assert request_params["params"] == {"testQueryParam": "query_value"}
def test_prepare_request_params_array(
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
):
mock_operation = Operation(
operationId="test_op",
requestBody=RequestBody(
content={
"application/json": MediaType(
schema=OpenAPISchema(
type="array", items=OpenAPISchema(type="string")
)
)
}
),
)
tool = RestApiTool(
name="test_tool",
description="test",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="array", # Match the parameter name
py_name="array",
param_location="body",
param_schema=OpenAPISchema(
type="array", items=OpenAPISchema(type="string")
),
)
]
kwargs = {"array": ["item1", "item2"]}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["json"] == ["item1", "item2"]
def test_prepare_request_params_string(
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
):
mock_operation = Operation(
operationId="test_op",
requestBody=RequestBody(
content={
"text/plain": MediaType(schema=OpenAPISchema(type="string"))
}
),
)
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="",
py_name="input_string",
param_location="body",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"input_string": "test_value"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["data"] == "test_value"
assert request_params["headers"]["Content-Type"] == "text/plain"
def test_prepare_request_params_form_data(
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
):
mock_operation = Operation(
operationId="test_op",
requestBody=RequestBody(
content={
"application/x-www-form-urlencoded": MediaType(
schema=OpenAPISchema(
type="object",
properties={"key1": OpenAPISchema(type="string")},
)
)
}
),
)
tool = RestApiTool(
name="test_tool",
description="test",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="key1",
py_name="key1",
param_location="body",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"key1": "value1"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["data"] == {"key1": "value1"}
assert (
request_params["headers"]["Content-Type"]
== "application/x-www-form-urlencoded"
)
def test_prepare_request_params_multipart(
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
):
mock_operation = Operation(
operationId="test_op",
requestBody=RequestBody(
content={
"multipart/form-data": MediaType(
schema=OpenAPISchema(
type="object",
properties={
"file1": OpenAPISchema(
type="string", format="binary"
)
},
)
)
}
),
)
tool = RestApiTool(
name="test_tool",
description="test",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="file1",
py_name="file1",
param_location="body",
param_schema=OpenAPISchema(type="string", format="binary"),
)
]
kwargs = {"file1": b"file_content"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["files"] == {"file1": b"file_content"}
assert request_params["headers"]["Content-Type"] == "multipart/form-data"
def test_prepare_request_params_octet_stream(
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
):
mock_operation = Operation(
operationId="test_op",
requestBody=RequestBody(
content={
"application/octet-stream": MediaType(
schema=OpenAPISchema(type="string", format="binary")
)
}
),
)
tool = RestApiTool(
name="test_tool",
description="test",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="",
py_name="data",
param_location="body",
param_schema=OpenAPISchema(type="string", format="binary"),
)
]
kwargs = {"data": b"binary_data"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["data"] == b"binary_data"
assert (
request_params["headers"]["Content-Type"] == "application/octet-stream"
)
def test_prepare_request_params_path_param(
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
):
mock_operation = Operation(operationId="test_op")
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="user_id",
py_name="user_id",
param_location="path",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"user_id": "123"}
endpoint_with_path = OperationEndpoint(
base_url="https://example.com", path="/test/{user_id}", method="get"
)
tool.endpoint = endpoint_with_path
request_params = tool._prepare_request_params(params, kwargs)
assert (
request_params["url"] == "https://example.com/test/123"
) # Path param replaced
def test_prepare_request_params_header_param(
self,
sample_endpiont,
sample_auth_credential,
sample_auth_scheme,
sample_operation,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="X-Custom-Header",
py_name="x_custom_header",
param_location="header",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"x_custom_header": "header_value"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["headers"]["X-Custom-Header"] == "header_value"
def test_prepare_request_params_cookie_param(
self,
sample_endpiont,
sample_auth_credential,
sample_auth_scheme,
sample_operation,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="session_id",
py_name="session_id",
param_location="cookie",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"session_id": "cookie_value"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["cookies"]["session_id"] == "cookie_value"
def test_prepare_request_params_multiple_mime_types(
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
):
# Test what happens when multiple mime types are specified. It should take
# the first one.
mock_operation = Operation(
operationId="test_op",
requestBody=RequestBody(
content={
"application/json": MediaType(
schema=OpenAPISchema(type="string")
),
"text/plain": MediaType(schema=OpenAPISchema(type="string")),
}
),
)
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="",
py_name="input",
param_location="body",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"input": "some_value"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["headers"]["Content-Type"] == "application/json"
def test_prepare_request_params_unknown_parameter(
self,
sample_endpiont,
sample_auth_credential,
sample_auth_scheme,
sample_operation,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="known_param",
py_name="known_param",
param_location="query",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"known_param": "value", "unknown_param": "unknown"}
request_params = tool._prepare_request_params(params, kwargs)
# Make sure unknown parameters are ignored and do not raise errors.
assert "unknown_param" not in request_params["params"]
def test_prepare_request_params_base_url_handling(
self, sample_auth_credential, sample_auth_scheme, sample_operation
):
# No base_url provided, should use path as is
tool_no_base = RestApiTool(
name="test_tool_no_base",
description="Test Tool",
endpoint=OperationEndpoint(base_url="", path="/no_base", method="get"),
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = []
kwargs = {}
request_params_no_base = tool_no_base._prepare_request_params(
params, kwargs
)
assert request_params_no_base["url"] == "/no_base"
tool_trailing_slash = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=OperationEndpoint(
base_url="https://example.com/", path="/trailing", method="get"
),
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
request_params_trailing = tool_trailing_slash._prepare_request_params(
params, kwargs
)
assert request_params_trailing["url"] == "https://example.com/trailing"
def test_prepare_request_params_no_unrecognized_query_parameter(
self,
sample_endpiont,
sample_auth_credential,
sample_auth_scheme,
sample_operation,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="unrecognized_param",
py_name="unrecognized_param",
param_location="query",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"unrecognized_param": None} # Explicitly passing None
request_params = tool._prepare_request_params(params, kwargs)
# Query param not in sample_operation. It should be ignored.
assert "unrecognized_param" not in request_params["params"]
def test_prepare_request_params_no_credential(
self,
sample_endpiont,
sample_operation,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_credential=None,
auth_scheme=None,
)
params = [
ApiParameter(
original_name="param_name",
py_name="param_name",
param_location="query",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"param_name": "aaa", "empty_param": ""}
request_params = tool._prepare_request_params(params, kwargs)
assert "param_name" in request_params["params"]
assert "empty_param" not in request_params["params"]
class TestToGeminiSchema:
def test_to_gemini_schema_none(self):
assert to_gemini_schema(None) is None
def test_to_gemini_schema_not_dict(self):
with pytest.raises(TypeError, match="openapi_schema must be a dictionary"):
to_gemini_schema("not a dict")
def test_to_gemini_schema_empty_dict(self):
result = to_gemini_schema({})
assert isinstance(result, Schema)
assert result.type == Type.OBJECT
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
def test_to_gemini_schema_dict_with_only_object_type(self):
result = to_gemini_schema({"type": "object"})
assert isinstance(result, Schema)
assert result.type == Type.OBJECT
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
def test_to_gemini_schema_basic_types(self):
openapi_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"is_active": {"type": "boolean"},
},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert isinstance(gemini_schema, Schema)
assert gemini_schema.type == Type.OBJECT
assert gemini_schema.properties["name"].type == Type.STRING
assert gemini_schema.properties["age"].type == Type.INTEGER
assert gemini_schema.properties["is_active"].type == Type.BOOLEAN
def test_to_gemini_schema_nested_objects(self):
openapi_schema = {
"type": "object",
"properties": {
"address": {
"type": "object",
"properties": {
"street": {"type": "string"},
"city": {"type": "string"},
},
}
},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.properties["address"].type == Type.OBJECT
assert (
gemini_schema.properties["address"].properties["street"].type
== Type.STRING
)
assert (
gemini_schema.properties["address"].properties["city"].type
== Type.STRING
)
def test_to_gemini_schema_array(self):
openapi_schema = {
"type": "array",
"items": {"type": "string"},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.type == Type.ARRAY
assert gemini_schema.items.type == Type.STRING
def test_to_gemini_schema_nested_array(self):
openapi_schema = {
"type": "array",
"items": {
"type": "object",
"properties": {"name": {"type": "string"}},
},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.items.properties["name"].type == Type.STRING
def test_to_gemini_schema_any_of(self):
openapi_schema = {
"anyOf": [{"type": "string"}, {"type": "integer"}],
}
gemini_schema = to_gemini_schema(openapi_schema)
assert len(gemini_schema.any_of) == 2
assert gemini_schema.any_of[0].type == Type.STRING
assert gemini_schema.any_of[1].type == Type.INTEGER
def test_to_gemini_schema_general_list(self):
openapi_schema = {
"type": "array",
"properties": {
"list_field": {"type": "array", "items": {"type": "string"}},
},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.properties["list_field"].type == Type.ARRAY
assert gemini_schema.properties["list_field"].items.type == Type.STRING
def test_to_gemini_schema_enum(self):
openapi_schema = {"type": "string", "enum": ["a", "b", "c"]}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.enum == ["a", "b", "c"]
def test_to_gemini_schema_required(self):
openapi_schema = {
"type": "object",
"required": ["name"],
"properties": {"name": {"type": "string"}},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.required == ["name"]
def test_to_gemini_schema_nested_dict(self):
openapi_schema = {
"type": "object",
"properties": {"metadata": {"key1": "value1", "key2": 123}},
}
gemini_schema = to_gemini_schema(openapi_schema)
# Since metadata is not properties nor item, it will call to_gemini_schema recursively.
assert isinstance(gemini_schema.properties["metadata"], Schema)
assert (
gemini_schema.properties["metadata"].type == Type.OBJECT
) # add object type by default
assert gemini_schema.properties["metadata"].properties == {
"dummy_DO_NOT_GENERATE": Schema(type="string")
}
def test_to_gemini_schema_ignore_title_default_format(self):
openapi_schema = {
"type": "string",
"title": "Test Title",
"default": "default_value",
"format": "date",
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.title is None
assert gemini_schema.default is None
assert gemini_schema.format is None
def test_to_gemini_schema_property_ordering(self):
openapi_schema = {
"type": "object",
"propertyOrdering": ["name", "age"],
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.property_ordering == ["name", "age"]
def test_to_gemini_schema_converts_property_dict(self):
openapi_schema = {
"properties": {
"name": {"type": "string", "description": "The property key"},
"value": {"type": "string", "description": "The property value"},
},
"type": "object",
"description": "A single property entry in the Properties message.",
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.type == Type.OBJECT
assert gemini_schema.properties["name"].type == Type.STRING
assert gemini_schema.properties["value"].type == Type.STRING
def test_to_gemini_schema_remove_unrecognized_fields(self):
openapi_schema = {
"type": "string",
"description": "A single date string.",
"format": "date",
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.type == Type.STRING
assert not gemini_schema.format
def test_snake_to_lower_camel():
assert snake_to_lower_camel("single") == "single"
assert snake_to_lower_camel("two_words") == "twoWords"
assert snake_to_lower_camel("three_word_example") == "threeWordExample"
assert not snake_to_lower_camel("")
assert snake_to_lower_camel("alreadyCamelCase") == "alreadyCamelCase"

View File

@@ -0,0 +1,201 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from unittest.mock import MagicMock
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_schemes import AuthScheme
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.session import Session
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import OAuth2CredentialExchanger
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolContextCredentialStore
from google.adk.tools.tool_context import ToolContext
import pytest
# Helper function to create a mock ToolContext
def create_mock_tool_context():
return ToolContext(
function_call_id='test-fc-id',
invocation_context=InvocationContext(
agent=LlmAgent(name='test'),
session=Session(app_name='test', user_id='123', id='123'),
invocation_id='123',
session_service=InMemorySessionService(),
),
)
# Test cases for OpenID Connect
class MockOpenIdConnectCredentialExchanger(OAuth2CredentialExchanger):
def __init__(
self, expected_scheme, expected_credential, expected_access_token
):
self.expected_scheme = expected_scheme
self.expected_credential = expected_credential
self.expected_access_token = expected_access_token
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
if auth_credential.oauth2 and (
auth_credential.oauth2.auth_response_uri
or auth_credential.oauth2.auth_code
):
auth_code = (
auth_credential.oauth2.auth_response_uri
if auth_credential.oauth2.auth_response_uri
else auth_credential.oauth2.auth_code
)
# Simulate the token exchange
updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
http=HttpAuth(
scheme='bearer',
credentials=HttpCredentials(
token=auth_code + self.expected_access_token
),
),
)
return updated_credential
# simulate the case of getting auth_uri
return None
def get_mock_openid_scheme_credential():
config_dict = {
'authorization_endpoint': 'test.com',
'token_endpoint': 'test.com',
}
scopes = ['test_scope']
credential_dict = {
'client_id': '123',
'client_secret': '456',
'redirect_uri': 'test.com',
}
return openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
# Fixture for the OpenID Connect security scheme
@pytest.fixture
def openid_connect_scheme():
scheme, _ = get_mock_openid_scheme_credential()
return scheme
# Fixture for a base OpenID Connect credential
@pytest.fixture
def openid_connect_credential():
_, credential = get_mock_openid_scheme_credential()
return credential
def test_openid_connect_no_auth_response(
openid_connect_scheme, openid_connect_credential
):
# Setup Mock exchanger
mock_exchanger = MockOpenIdConnectCredentialExchanger(
openid_connect_scheme, openid_connect_credential, None
)
tool_context = create_mock_tool_context()
credential_store = ToolContextCredentialStore(tool_context=tool_context)
handler = ToolAuthHandler(
tool_context,
openid_connect_scheme,
openid_connect_credential,
credential_exchanger=mock_exchanger,
credential_store=credential_store,
)
result = handler.prepare_auth_credentials()
assert result.state == 'pending'
assert result.auth_credential == openid_connect_credential
def test_openid_connect_with_auth_response(
openid_connect_scheme, openid_connect_credential, monkeypatch
):
mock_exchanger = MockOpenIdConnectCredentialExchanger(
openid_connect_scheme,
openid_connect_credential,
'test_access_token',
)
tool_context = create_mock_tool_context()
mock_auth_handler = MagicMock()
mock_auth_handler.get_auth_response.return_value = AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
oauth2=OAuth2Auth(auth_response_uri='test_auth_response_uri'),
)
mock_auth_handler_path = 'google.adk.tools.tool_context.AuthHandler'
monkeypatch.setattr(
mock_auth_handler_path, lambda *args, **kwargs: mock_auth_handler
)
credential_store = ToolContextCredentialStore(tool_context=tool_context)
handler = ToolAuthHandler(
tool_context,
openid_connect_scheme,
openid_connect_credential,
credential_exchanger=mock_exchanger,
credential_store=credential_store,
)
result = handler.prepare_auth_credentials()
assert result.state == 'done'
assert result.auth_credential.auth_type == AuthCredentialTypes.HTTP
assert 'test_access_token' in result.auth_credential.http.credentials.token
# Verify that the credential was stored:
stored_credential = credential_store.get_credential(
openid_connect_scheme, openid_connect_credential
)
assert stored_credential == result.auth_credential
mock_auth_handler.get_auth_response.assert_called_once()
def test_openid_connect_existing_token(
openid_connect_scheme, openid_connect_credential
):
_, existing_credential = token_to_scheme_credential(
'oauth2Token', 'header', 'bearer', '123123123'
)
tool_context = create_mock_tool_context()
# Store the credential to simulate existing credential
credential_store = ToolContextCredentialStore(tool_context=tool_context)
key = credential_store.get_credential_key(
openid_connect_scheme, openid_connect_credential
)
credential_store.store_credential(key, existing_credential)
handler = ToolAuthHandler(
tool_context,
openid_connect_scheme,
openid_connect_credential,
credential_store=credential_store,
)
result = handler.prepare_auth_credentials()
assert result.state == 'done'
assert result.auth_credential == existing_credential