mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 19:32:21 -06:00
Moves unittests to root folder and adds github action to run unit tests. (#72)
* Move unit tests to root package. * Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py * Adds github workflow * minor fix in lite_llm.py for python 3.9. * format pyproject.toml
This commit is contained in:
@@ -0,0 +1,145 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for AutoAuthCredentialExchanger."""
|
||||
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.oauth2_exchanger import OAuth2CredentialExchanger
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
import pytest
|
||||
|
||||
|
||||
class MockCredentialExchanger(BaseAuthCredentialExchanger):
|
||||
"""Mock credential exchanger for testing."""
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
"""Mock exchange credential method."""
|
||||
return auth_credential
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auto_exchanger():
|
||||
"""Fixture for creating an AutoAuthCredentialExchanger instance."""
|
||||
return AutoAuthCredentialExchanger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_scheme():
|
||||
"""Fixture for creating a mock AuthScheme instance."""
|
||||
scheme = MagicMock(spec=AuthScheme)
|
||||
return scheme
|
||||
|
||||
|
||||
def test_init_with_custom_exchangers():
|
||||
"""Test initialization with custom exchangers."""
|
||||
custom_exchangers: Dict[str, Type[BaseAuthCredentialExchanger]] = {
|
||||
AuthCredentialTypes.API_KEY: MockCredentialExchanger
|
||||
}
|
||||
|
||||
auto_exchanger = AutoAuthCredentialExchanger(
|
||||
custom_exchangers=custom_exchangers
|
||||
)
|
||||
|
||||
assert (
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.API_KEY]
|
||||
== MockCredentialExchanger
|
||||
)
|
||||
assert (
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.OPEN_ID_CONNECT]
|
||||
== OAuth2CredentialExchanger
|
||||
)
|
||||
|
||||
|
||||
def test_exchange_credential_no_auth_credential(auto_exchanger, auth_scheme):
|
||||
"""Test exchange_credential with no auth_credential."""
|
||||
|
||||
assert auto_exchanger.exchange_credential(auth_scheme, None) is None
|
||||
|
||||
|
||||
def test_exchange_credential_no_exchange(auto_exchanger, auth_scheme):
|
||||
"""Test exchange_credential with NoExchangeCredentialExchanger."""
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
|
||||
|
||||
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
|
||||
assert result == auth_credential
|
||||
|
||||
|
||||
def test_exchange_credential_open_id_connect(auto_exchanger, auth_scheme):
|
||||
"""Test exchange_credential with OpenID Connect scheme."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT
|
||||
)
|
||||
mock_exchanger = MagicMock(spec=OAuth2CredentialExchanger)
|
||||
mock_exchanger.exchange_credential.return_value = "exchanged_credential"
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.OPEN_ID_CONNECT] = (
|
||||
lambda: mock_exchanger
|
||||
)
|
||||
|
||||
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
|
||||
assert result == "exchanged_credential"
|
||||
mock_exchanger.exchange_credential.assert_called_once_with(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
|
||||
def test_exchange_credential_service_account(auto_exchanger, auth_scheme):
|
||||
"""Test exchange_credential with Service Account scheme."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT
|
||||
)
|
||||
mock_exchanger = MagicMock(spec=ServiceAccountCredentialExchanger)
|
||||
mock_exchanger.exchange_credential.return_value = "exchanged_credential_sa"
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.SERVICE_ACCOUNT] = (
|
||||
lambda: mock_exchanger
|
||||
)
|
||||
|
||||
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
|
||||
assert result == "exchanged_credential_sa"
|
||||
mock_exchanger.exchange_credential.assert_called_once_with(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
|
||||
def test_exchange_credential_custom_exchanger(auto_exchanger, auth_scheme):
|
||||
"""Test that exchange_credential calls the correct (custom) exchanger."""
|
||||
# Use a custom exchanger via the initialization
|
||||
mock_exchanger = MagicMock(spec=MockCredentialExchanger)
|
||||
mock_exchanger.exchange_credential.return_value = "custom_credential"
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.API_KEY] = (
|
||||
lambda: mock_exchanger
|
||||
)
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
|
||||
|
||||
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
|
||||
assert result == "custom_credential"
|
||||
mock_exchanger.exchange_credential.assert_called_once_with(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
@@ -0,0 +1,68 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the BaseAuthCredentialExchanger class."""
|
||||
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
import pytest
|
||||
|
||||
|
||||
class MockAuthCredentialExchanger(BaseAuthCredentialExchanger):
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
return AuthCredential(token="some-token")
|
||||
|
||||
|
||||
class TestBaseAuthCredentialExchanger:
|
||||
"""Tests for the BaseAuthCredentialExchanger class."""
|
||||
|
||||
@pytest.fixture
|
||||
def base_exchanger(self):
|
||||
return BaseAuthCredentialExchanger()
|
||||
|
||||
@pytest.fixture
|
||||
def auth_scheme(self):
|
||||
scheme = MagicMock(spec=AuthScheme)
|
||||
scheme.type = "apiKey"
|
||||
scheme.name = "x-api-key"
|
||||
return scheme
|
||||
|
||||
def test_exchange_credential_not_implemented(
|
||||
self, base_exchanger, auth_scheme
|
||||
):
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, token="some-token"
|
||||
)
|
||||
with pytest.raises(NotImplementedError) as exc_info:
|
||||
base_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "Subclasses must implement exchange_credential." in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
def test_auth_credential_missing_error(self):
|
||||
error_message = "Test missing credential"
|
||||
error = AuthCredentialMissingError(error_message)
|
||||
# assert error.message == error_message
|
||||
assert str(error) == error_message
|
||||
@@ -0,0 +1,153 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for OAuth2CredentialExchanger."""
|
||||
|
||||
import copy
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers import OAuth2CredentialExchanger
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_exchanger():
|
||||
return OAuth2CredentialExchanger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_scheme():
|
||||
openid_config = OpenIdConnectWithConfig(
|
||||
type_=AuthSchemeType.openIdConnect,
|
||||
authorization_endpoint="https://example.com/auth",
|
||||
token_endpoint="https://example.com/token",
|
||||
scopes=["openid", "profile"],
|
||||
)
|
||||
return openid_config
|
||||
|
||||
|
||||
def test_check_scheme_credential_type_success(oauth2_exchanger, auth_scheme):
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
redirect_uri="http://localhost:8080",
|
||||
),
|
||||
)
|
||||
# Check that the method does not raise an exception
|
||||
oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
|
||||
|
||||
|
||||
def test_check_scheme_credential_type_missing_credential(
|
||||
oauth2_exchanger, auth_scheme
|
||||
):
|
||||
# Test case: auth_credential is None
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth2_exchanger._check_scheme_credential_type(auth_scheme, None)
|
||||
assert "auth_credential is empty" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_check_scheme_credential_type_invalid_scheme_type(
|
||||
oauth2_exchanger, auth_scheme: OpenIdConnectWithConfig
|
||||
):
|
||||
"""Test case: Invalid AuthSchemeType."""
|
||||
# Test case: Invalid AuthSchemeType
|
||||
invalid_scheme = copy.deepcopy(auth_scheme)
|
||||
invalid_scheme.type_ = AuthSchemeType.apiKey
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
redirect_uri="http://localhost:8080",
|
||||
),
|
||||
)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth2_exchanger._check_scheme_credential_type(
|
||||
invalid_scheme, auth_credential
|
||||
)
|
||||
assert "Invalid security scheme" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_check_scheme_credential_type_missing_openid_connect(
|
||||
oauth2_exchanger, auth_scheme
|
||||
):
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
|
||||
assert "auth_credential is not configured with oauth2" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_generate_auth_token_success(
|
||||
oauth2_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test case: Successful generation of access token."""
|
||||
# Test case: Successful generation of access token
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
redirect_uri="http://localhost:8080",
|
||||
auth_response_uri="https://example.com/callback?code=test_code",
|
||||
token={"access_token": "test_access_token"},
|
||||
),
|
||||
)
|
||||
updated_credential = oauth2_exchanger.generate_auth_token(auth_credential)
|
||||
|
||||
assert updated_credential.auth_type == AuthCredentialTypes.HTTP
|
||||
assert updated_credential.http.scheme == "bearer"
|
||||
assert updated_credential.http.credentials.token == "test_access_token"
|
||||
|
||||
|
||||
def test_exchange_credential_generate_auth_token(
|
||||
oauth2_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test exchange_credential when auth_response_uri is present."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
redirect_uri="http://localhost:8080",
|
||||
auth_response_uri="https://example.com/callback?code=test_code",
|
||||
token={"access_token": "test_access_token"},
|
||||
),
|
||||
)
|
||||
|
||||
updated_credential = oauth2_exchanger.exchange_credential(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
assert updated_credential.auth_type == AuthCredentialTypes.HTTP
|
||||
assert updated_credential.http.scheme == "bearer"
|
||||
assert updated_credential.http.credentials.token == "test_access_token"
|
||||
|
||||
|
||||
def test_exchange_credential_auth_missing(oauth2_exchanger, auth_scheme):
|
||||
"""Test exchange_credential when auth_credential is missing."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth2_exchanger.exchange_credential(auth_scheme, None)
|
||||
assert "auth_credential is empty. Please create AuthCredential using" in str(
|
||||
exc_info.value
|
||||
)
|
||||
@@ -0,0 +1,196 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the service account credential exchanger."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import ServiceAccount
|
||||
from google.adk.auth.auth_credential import ServiceAccountCredential
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
import google.auth
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service_account_exchanger():
|
||||
return ServiceAccountCredentialExchanger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_scheme():
|
||||
scheme = MagicMock(spec=AuthScheme)
|
||||
scheme.type_ = AuthSchemeType.oauth2
|
||||
scheme.description = "Google Service Account"
|
||||
return scheme
|
||||
|
||||
|
||||
def test_exchange_credential_success(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test successful exchange of service account credentials."""
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "mock_access_token"
|
||||
|
||||
# Mock the from_service_account_info method
|
||||
mock_from_service_account_info = MagicMock(return_value=mock_credentials)
|
||||
target_path = (
|
||||
"google.adk.tools.openapi_tool.auth.credential_exchangers."
|
||||
"service_account_exchanger.service_account.Credentials."
|
||||
"from_service_account_info"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
target_path,
|
||||
mock_from_service_account_info,
|
||||
)
|
||||
|
||||
# Mock the refresh method
|
||||
mock_credentials.refresh = MagicMock()
|
||||
|
||||
# Create a valid AuthCredential with service account info
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
service_account_credential=ServiceAccountCredential(
|
||||
type_="service_account",
|
||||
project_id="your_project_id",
|
||||
private_key_id="your_private_key_id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----...",
|
||||
client_email="...@....iam.gserviceaccount.com",
|
||||
client_id="your_client_id",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url=(
|
||||
"https://www.googleapis.com/oauth2/v1/certs"
|
||||
),
|
||||
client_x509_cert_url=(
|
||||
"https://www.googleapis.com/robot/v1/metadata/x509/..."
|
||||
),
|
||||
universe_domain="googleapis.com",
|
||||
),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
),
|
||||
)
|
||||
|
||||
result = service_account_exchanger.exchange_credential(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
assert result.auth_type == AuthCredentialTypes.HTTP
|
||||
assert result.http.scheme == "bearer"
|
||||
assert result.http.credentials.token == "mock_access_token"
|
||||
mock_from_service_account_info.assert_called_once()
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
|
||||
|
||||
def test_exchange_credential_use_default_credential_success(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test successful exchange of service account credentials using default credential."""
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "mock_access_token"
|
||||
mock_google_auth_default = MagicMock(
|
||||
return_value=(mock_credentials, "test_project")
|
||||
)
|
||||
monkeypatch.setattr(google.auth, "default", mock_google_auth_default)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
use_default_credential=True,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
),
|
||||
)
|
||||
|
||||
result = service_account_exchanger.exchange_credential(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
assert result.auth_type == AuthCredentialTypes.HTTP
|
||||
assert result.http.scheme == "bearer"
|
||||
assert result.http.credentials.token == "mock_access_token"
|
||||
mock_google_auth_default.assert_called_once()
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
|
||||
|
||||
def test_exchange_credential_missing_auth_credential(
|
||||
service_account_exchanger, auth_scheme
|
||||
):
|
||||
"""Test missing auth credential during exchange."""
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, None)
|
||||
assert "Service account credentials are missing" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_exchange_credential_missing_service_account_info(
|
||||
service_account_exchanger, auth_scheme
|
||||
):
|
||||
"""Test missing service account info during exchange."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
)
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "Service account credentials are missing" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_exchange_credential_exchange_failure(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test failure during service account token exchange."""
|
||||
mock_from_service_account_info = MagicMock(
|
||||
side_effect=Exception("Failed to load credentials")
|
||||
)
|
||||
target_path = (
|
||||
"google.adk.tools.openapi_tool.auth.credential_exchangers."
|
||||
"service_account_exchanger.service_account.Credentials."
|
||||
"from_service_account_info"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
target_path,
|
||||
mock_from_service_account_info,
|
||||
)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
service_account_credential=ServiceAccountCredential(
|
||||
type_="service_account",
|
||||
project_id="your_project_id",
|
||||
private_key_id="your_private_key_id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----...",
|
||||
client_email="...@....iam.gserviceaccount.com",
|
||||
client_id="your_client_id",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url=(
|
||||
"https://www.googleapis.com/oauth2/v1/certs"
|
||||
),
|
||||
client_x509_cert_url=(
|
||||
"https://www.googleapis.com/robot/v1/metadata/x509/..."
|
||||
),
|
||||
universe_domain="googleapis.com",
|
||||
),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
),
|
||||
)
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "Failed to exchange service account token" in str(exc_info.value)
|
||||
mock_from_service_account_info.assert_called_once()
|
||||
573
tests/unittests/tools/openapi_tool/auth/test_auth_helper.py
Normal file
573
tests/unittests/tools/openapi_tool/auth/test_auth_helper.py
Normal file
@@ -0,0 +1,573 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from fastapi.openapi.models import HTTPBase
|
||||
from fastapi.openapi.models import HTTPBearer
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OpenIdConnect
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.auth.auth_credential import ServiceAccount
|
||||
from google.adk.auth.auth_credential import ServiceAccountCredential
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import credential_to_param
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import dict_to_auth_scheme
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import INTERNAL_AUTH_PREFIX
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_url_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_dict_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_api_key_header():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"apikey", "header", "X-API-Key", "test_key"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert scheme.type_ == AuthSchemeType.apiKey
|
||||
assert scheme.in_ == APIKeyIn.header
|
||||
assert scheme.name == "X-API-Key"
|
||||
assert credential == AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_api_key_query():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"apikey", "query", "api_key", "test_key"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert scheme.type_ == AuthSchemeType.apiKey
|
||||
assert scheme.in_ == APIKeyIn.query
|
||||
assert scheme.name == "api_key"
|
||||
assert credential == AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_api_key_cookie():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"apikey", "cookie", "session_id", "test_key"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert scheme.type_ == AuthSchemeType.apiKey
|
||||
assert scheme.in_ == APIKeyIn.cookie
|
||||
assert scheme.name == "session_id"
|
||||
assert credential == AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_api_key_no_credential():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"apikey", "cookie", "session_id"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert credential is None
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_oauth2_token():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"oauth2Token", "header", "Authorization", "test_token"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert scheme.bearerFormat == "JWT"
|
||||
assert credential == AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="test_token")
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_oauth2_no_credential():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"oauth2Token", "header", "Authorization"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert credential is None
|
||||
|
||||
|
||||
def test_service_account_dict_to_scheme_credential():
|
||||
config = {
|
||||
"type": "service_account",
|
||||
"project_id": "project_id",
|
||||
"private_key_id": "private_key_id",
|
||||
"private_key": "private_key",
|
||||
"client_email": "client_email",
|
||||
"client_id": "client_id",
|
||||
"auth_uri": "auth_uri",
|
||||
"token_uri": "token_uri",
|
||||
"auth_provider_x509_cert_url": "auth_provider_x509_cert_url",
|
||||
"client_x509_cert_url": "client_x509_cert_url",
|
||||
"universe_domain": "universe_domain",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = service_account_dict_to_scheme_credential(config, scopes)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert scheme.bearerFormat == "JWT"
|
||||
assert credential.auth_type == AuthCredentialTypes.SERVICE_ACCOUNT
|
||||
assert credential.service_account.scopes == scopes
|
||||
assert (
|
||||
credential.service_account.service_account_credential.project_id
|
||||
== "project_id"
|
||||
)
|
||||
|
||||
|
||||
def test_service_account_scheme_credential():
|
||||
config = ServiceAccount(
|
||||
service_account_credential=ServiceAccountCredential(
|
||||
type="service_account",
|
||||
project_id="project_id",
|
||||
private_key_id="private_key_id",
|
||||
private_key="private_key",
|
||||
client_email="client_email",
|
||||
client_id="client_id",
|
||||
auth_uri="auth_uri",
|
||||
token_uri="token_uri",
|
||||
auth_provider_x509_cert_url="auth_provider_x509_cert_url",
|
||||
client_x509_cert_url="client_x509_cert_url",
|
||||
universe_domain="universe_domain",
|
||||
),
|
||||
scopes=["scope1", "scope2"],
|
||||
)
|
||||
|
||||
scheme, credential = service_account_scheme_credential(config)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert scheme.bearerFormat == "JWT"
|
||||
assert credential.auth_type == AuthCredentialTypes.SERVICE_ACCOUNT
|
||||
assert credential.service_account == config
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential():
|
||||
config_dict = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
"openIdConnectUrl": "openid_url",
|
||||
}
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_dict_to_scheme_credential(
|
||||
config_dict, scopes, credential_dict
|
||||
)
|
||||
|
||||
assert isinstance(scheme, OpenIdConnectWithConfig)
|
||||
assert scheme.authorization_endpoint == "auth_url"
|
||||
assert scheme.token_endpoint == "token_url"
|
||||
assert scheme.scopes == scopes
|
||||
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
|
||||
assert credential.oauth2.client_id == "client_id"
|
||||
assert credential.oauth2.client_secret == "client_secret"
|
||||
assert credential.oauth2.redirect_uri == "redirect_uri"
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential_no_openid_url():
|
||||
config_dict = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
}
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_dict_to_scheme_credential(
|
||||
config_dict, scopes, credential_dict
|
||||
)
|
||||
|
||||
assert scheme.openIdConnectUrl == ""
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential_google_oauth_credential():
|
||||
config_dict = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
"openIdConnectUrl": "openid_url",
|
||||
}
|
||||
credential_dict = {
|
||||
"web": {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_dict_to_scheme_credential(
|
||||
config_dict, scopes, credential_dict
|
||||
)
|
||||
|
||||
assert isinstance(scheme, OpenIdConnectWithConfig)
|
||||
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
|
||||
assert credential.oauth2.client_id == "client_id"
|
||||
assert credential.oauth2.client_secret == "client_secret"
|
||||
assert credential.oauth2.redirect_uri == "redirect_uri"
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential_invalid_config():
|
||||
config_dict = {
|
||||
"invalid_field": "value",
|
||||
}
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid OpenID Connect configuration"):
|
||||
openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential_missing_credential_fields():
|
||||
config_dict = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
}
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Missing required fields in credential_dict: client_secret",
|
||||
):
|
||||
openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_openid_url_to_scheme_credential(mock_get):
|
||||
mock_response = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
"userinfo_endpoint": "userinfo_url",
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_url_to_scheme_credential(
|
||||
"openid_url", scopes, credential_dict
|
||||
)
|
||||
|
||||
assert isinstance(scheme, OpenIdConnectWithConfig)
|
||||
assert scheme.authorization_endpoint == "auth_url"
|
||||
assert scheme.token_endpoint == "token_url"
|
||||
assert scheme.scopes == scopes
|
||||
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
|
||||
assert credential.oauth2.client_id == "client_id"
|
||||
assert credential.oauth2.client_secret == "client_secret"
|
||||
assert credential.oauth2.redirect_uri == "redirect_uri"
|
||||
mock_get.assert_called_once_with("openid_url", timeout=10)
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_openid_url_to_scheme_credential_no_openid_url(mock_get):
|
||||
mock_response = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
"userinfo_endpoint": "userinfo_url",
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_url_to_scheme_credential(
|
||||
"openid_url", scopes, credential_dict
|
||||
)
|
||||
|
||||
assert scheme.openIdConnectUrl == "openid_url"
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_openid_url_to_scheme_credential_request_exception(mock_get):
|
||||
mock_get.side_effect = requests.exceptions.RequestException("Test Error")
|
||||
credential_dict = {"client_id": "client_id", "client_secret": "client_secret"}
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Failed to fetch OpenID configuration from openid_url"
|
||||
):
|
||||
openid_url_to_scheme_credential("openid_url", [], credential_dict)
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_openid_url_to_scheme_credential_invalid_json(mock_get):
|
||||
mock_get.return_value.json.side_effect = ValueError("Invalid JSON")
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
credential_dict = {"client_id": "client_id", "client_secret": "client_secret"}
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Invalid JSON response from OpenID configuration endpoint openid_url"
|
||||
),
|
||||
):
|
||||
openid_url_to_scheme_credential("openid_url", [], credential_dict)
|
||||
|
||||
|
||||
def test_credential_to_param_api_key_header():
|
||||
auth_scheme = APIKey(
|
||||
**{"type": "apiKey", "in": "header", "name": "X-API-Key"}
|
||||
)
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "X-API-Key"
|
||||
assert param.param_location == "header"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "X-API-Key": "test_key"}
|
||||
|
||||
|
||||
def test_credential_to_param_api_key_query():
|
||||
auth_scheme = APIKey(**{"type": "apiKey", "in": "query", "name": "api_key"})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "api_key"
|
||||
assert param.param_location == "query"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "api_key": "test_key"}
|
||||
|
||||
|
||||
def test_credential_to_param_api_key_cookie():
|
||||
auth_scheme = APIKey(
|
||||
**{"type": "apiKey", "in": "cookie", "name": "session_id"}
|
||||
)
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "session_id"
|
||||
assert param.param_location == "cookie"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "session_id": "test_key"}
|
||||
|
||||
|
||||
def test_credential_to_param_http_bearer():
|
||||
auth_scheme = HTTPBearer(bearerFormat="JWT")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="test_token")
|
||||
),
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "Authorization"
|
||||
assert param.param_location == "header"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_credential_to_param_http_basic_not_supported():
|
||||
auth_scheme = HTTPBase(scheme="basic")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="basic",
|
||||
credentials=HttpCredentials(username="user", password="password"),
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
NotImplementedError, match="Basic Authentication is not supported."
|
||||
):
|
||||
credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
|
||||
def test_credential_to_param_http_invalid_credentials_no_http():
|
||||
auth_scheme = HTTPBase(scheme="basic")
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.HTTP)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid HTTP auth credentials"):
|
||||
credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
|
||||
def test_credential_to_param_oauth2():
|
||||
auth_scheme = OAuth2(flows={})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="test_token")
|
||||
),
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "Authorization"
|
||||
assert param.param_location == "header"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_credential_to_param_openid_connect():
|
||||
auth_scheme = OpenIdConnect(openIdConnectUrl="openid_url")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="test_token")
|
||||
),
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "Authorization"
|
||||
assert param.param_location == "header"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_credential_to_param_openid_no_credential():
|
||||
auth_scheme = OpenIdConnect(openIdConnectUrl="openid_url")
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, None)
|
||||
|
||||
assert param == None
|
||||
assert kwargs == None
|
||||
|
||||
|
||||
def test_credential_to_param_oauth2_no_credential():
|
||||
auth_scheme = OAuth2(flows={})
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, None)
|
||||
|
||||
assert param == None
|
||||
assert kwargs == None
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_api_key():
|
||||
data = {"type": "apiKey", "in": "header", "name": "X-API-Key"}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert scheme.type_ == AuthSchemeType.apiKey
|
||||
assert scheme.in_ == APIKeyIn.header
|
||||
assert scheme.name == "X-API-Key"
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_http_bearer():
|
||||
data = {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert scheme.scheme == "bearer"
|
||||
assert scheme.bearerFormat == "JWT"
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_http_base():
|
||||
data = {"type": "http", "scheme": "basic"}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, HTTPBase)
|
||||
assert scheme.scheme == "basic"
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_oauth2():
|
||||
data = {
|
||||
"type": "oauth2",
|
||||
"flows": {
|
||||
"authorizationCode": {
|
||||
"authorizationUrl": "https://example.com/auth",
|
||||
"tokenUrl": "https://example.com/token",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, OAuth2)
|
||||
assert hasattr(scheme.flows, "authorizationCode")
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_openid_connect():
|
||||
data = {
|
||||
"type": "openIdConnect",
|
||||
"openIdConnectUrl": (
|
||||
"https://example.com/.well-known/openid-configuration"
|
||||
),
|
||||
}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, OpenIdConnect)
|
||||
assert (
|
||||
scheme.openIdConnectUrl
|
||||
== "https://example.com/.well-known/openid-configuration"
|
||||
)
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_missing_type():
|
||||
data = {"in": "header", "name": "X-API-Key"}
|
||||
with pytest.raises(
|
||||
ValueError, match="Missing 'type' field in security scheme dictionary."
|
||||
):
|
||||
dict_to_auth_scheme(data)
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_invalid_type():
|
||||
data = {"type": "invalid", "in": "header", "name": "X-API-Key"}
|
||||
with pytest.raises(ValueError, match="Invalid security scheme type: invalid"):
|
||||
dict_to_auth_scheme(data)
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_invalid_data():
|
||||
data = {"type": "apiKey", "in": "header"} # Missing 'name'
|
||||
with pytest.raises(ValueError, match="Invalid security scheme data"):
|
||||
dict_to_auth_scheme(data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
436
tests/unittests/tools/openapi_tool/common/test_common.py
Normal file
436
tests/unittests/tools/openapi_tool/common/test_common.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
from fastapi.openapi.models import Response, Schema
|
||||
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
||||
from google.adk.tools.openapi_tool.common.common import PydocHelper
|
||||
from google.adk.tools.openapi_tool.common.common import rename_python_keywords
|
||||
from google.adk.tools.openapi_tool.common.common import to_snake_case
|
||||
from google.adk.tools.openapi_tool.common.common import TypeHintHelper
|
||||
import pytest
|
||||
|
||||
|
||||
def dict_to_responses(input: Dict[str, Any]) -> Dict[str, Response]:
|
||||
return {k: Response.model_validate(input[k]) for k in input}
|
||||
|
||||
|
||||
class TestToSnakeCase:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'input_str, expected_output',
|
||||
[
|
||||
('lowerCamelCase', 'lower_camel_case'),
|
||||
('UpperCamelCase', 'upper_camel_case'),
|
||||
('space separated', 'space_separated'),
|
||||
('REST API', 'rest_api'),
|
||||
('Mixed_CASE with_Spaces', 'mixed_case_with_spaces'),
|
||||
('__init__', 'init'),
|
||||
('APIKey', 'api_key'),
|
||||
('SomeLongURL', 'some_long_url'),
|
||||
('CONSTANT_CASE', 'constant_case'),
|
||||
('already_snake_case', 'already_snake_case'),
|
||||
('single', 'single'),
|
||||
('', ''),
|
||||
(' spaced ', 'spaced'),
|
||||
('with123numbers', 'with123numbers'),
|
||||
('With_Mixed_123_and_SPACES', 'with_mixed_123_and_spaces'),
|
||||
('HTMLParser', 'html_parser'),
|
||||
('HTTPResponseCode', 'http_response_code'),
|
||||
('a_b_c', 'a_b_c'),
|
||||
('A_B_C', 'a_b_c'),
|
||||
('fromAtoB', 'from_ato_b'),
|
||||
('XMLHTTPRequest', 'xmlhttp_request'),
|
||||
('_leading', 'leading'),
|
||||
('trailing_', 'trailing'),
|
||||
(' leading_and_trailing_ ', 'leading_and_trailing'),
|
||||
('Multiple___Underscores', 'multiple_underscores'),
|
||||
(' spaces_and___underscores ', 'spaces_and_underscores'),
|
||||
(' _mixed_Case ', 'mixed_case'),
|
||||
('123Start', '123_start'),
|
||||
('End123', 'end123'),
|
||||
('Mid123dle', 'mid123dle'),
|
||||
],
|
||||
)
|
||||
def test_to_snake_case(self, input_str, expected_output):
|
||||
assert to_snake_case(input_str) == expected_output
|
||||
|
||||
|
||||
class TestRenamePythonKeywords:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'input_str, expected_output',
|
||||
[
|
||||
('in', 'param_in'),
|
||||
('for', 'param_for'),
|
||||
('class', 'param_class'),
|
||||
('normal', 'normal'),
|
||||
('param_if', 'param_if'),
|
||||
('', ''),
|
||||
],
|
||||
)
|
||||
def test_rename_python_keywords(self, input_str, expected_output):
|
||||
assert rename_python_keywords(input_str) == expected_output
|
||||
|
||||
|
||||
class TestApiParameter:
|
||||
|
||||
def test_api_parameter_initialization(self):
|
||||
schema = Schema(type='string', description='A string parameter')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
description='A string description',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.original_name == 'testParam'
|
||||
assert param.param_location == 'query'
|
||||
assert param.param_schema.type == 'string'
|
||||
assert param.param_schema.description == 'A string parameter'
|
||||
assert param.py_name == 'test_param'
|
||||
assert param.type_hint == 'str'
|
||||
assert param.type_value == str
|
||||
assert param.description == 'A string description'
|
||||
|
||||
def test_api_parameter_keyword_rename(self):
|
||||
schema = Schema(type='string')
|
||||
param = ApiParameter(
|
||||
original_name='in',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.py_name == 'param_in'
|
||||
|
||||
def test_api_parameter_custom_py_name(self):
|
||||
schema = Schema(type='integer')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
py_name='custom_name',
|
||||
)
|
||||
assert param.py_name == 'custom_name'
|
||||
|
||||
def test_api_parameter_str_representation(self):
|
||||
schema = Schema(type='number')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert str(param) == 'test_param: float'
|
||||
|
||||
def test_api_parameter_to_arg_string(self):
|
||||
schema = Schema(type='boolean')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.to_arg_string() == 'test_param=test_param'
|
||||
|
||||
def test_api_parameter_to_dict_property(self):
|
||||
schema = Schema(type='string')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
param_location='path',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.to_dict_property() == '"test_param": test_param'
|
||||
|
||||
def test_api_parameter_model_serializer(self):
|
||||
schema = Schema(type='string', description='test description')
|
||||
param = ApiParameter(
|
||||
original_name='TestParam',
|
||||
param_location='path',
|
||||
param_schema=schema,
|
||||
py_name='test_param_custom',
|
||||
description='test description',
|
||||
)
|
||||
|
||||
serialized_param = param.model_dump(mode='json', exclude_none=True)
|
||||
|
||||
assert serialized_param == {
|
||||
'original_name': 'TestParam',
|
||||
'param_location': 'path',
|
||||
'param_schema': {'type': 'string', 'description': 'test description'},
|
||||
'description': 'test description',
|
||||
'py_name': 'test_param_custom',
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'schema, expected_type_value, expected_type_hint',
|
||||
[
|
||||
({'type': 'integer'}, int, 'int'),
|
||||
({'type': 'number'}, float, 'float'),
|
||||
({'type': 'boolean'}, bool, 'bool'),
|
||||
({'type': 'string'}, str, 'str'),
|
||||
(
|
||||
{'type': 'string', 'format': 'date'},
|
||||
str,
|
||||
'str',
|
||||
),
|
||||
(
|
||||
{'type': 'string', 'format': 'date-time'},
|
||||
str,
|
||||
'str',
|
||||
),
|
||||
(
|
||||
{'type': 'array', 'items': {'type': 'integer'}},
|
||||
List[int],
|
||||
'List[int]',
|
||||
),
|
||||
(
|
||||
{'type': 'array', 'items': {'type': 'string'}},
|
||||
List[str],
|
||||
'List[str]',
|
||||
),
|
||||
(
|
||||
{
|
||||
'type': 'array',
|
||||
'items': {'type': 'object'},
|
||||
},
|
||||
List[Dict[str, Any]],
|
||||
'List[Dict[str, Any]]',
|
||||
),
|
||||
({'type': 'object'}, Dict[str, Any], 'Dict[str, Any]'),
|
||||
({'type': 'unknown'}, Any, 'Any'),
|
||||
({}, Any, 'Any'),
|
||||
],
|
||||
)
|
||||
def test_api_parameter_type_hint_helper(
|
||||
self, schema, expected_type_value, expected_type_hint
|
||||
):
|
||||
param = ApiParameter(
|
||||
original_name='test', param_location='query', param_schema=schema
|
||||
)
|
||||
assert param.type_value == expected_type_value
|
||||
assert param.type_hint == expected_type_hint
|
||||
assert (
|
||||
TypeHintHelper.get_type_hint(param.param_schema) == expected_type_hint
|
||||
)
|
||||
assert (
|
||||
TypeHintHelper.get_type_value(param.param_schema) == expected_type_value
|
||||
)
|
||||
|
||||
def test_api_parameter_description(self):
|
||||
schema = Schema(type='string')
|
||||
param = ApiParameter(
|
||||
original_name='param1',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='The description',
|
||||
)
|
||||
assert param.description == 'The description'
|
||||
|
||||
def test_api_parameter_description_use_schema_fallback(self):
|
||||
schema = Schema(type='string', description='The description')
|
||||
param = ApiParameter(
|
||||
original_name='param1',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.description == 'The description'
|
||||
|
||||
|
||||
class TestTypeHintHelper:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'schema, expected_type_value, expected_type_hint',
|
||||
[
|
||||
({'type': 'integer'}, int, 'int'),
|
||||
({'type': 'number'}, float, 'float'),
|
||||
({'type': 'string'}, str, 'str'),
|
||||
(
|
||||
{
|
||||
'type': 'array',
|
||||
'items': {'type': 'string'},
|
||||
},
|
||||
List[str],
|
||||
'List[str]',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_type_value_and_hint(
|
||||
self, schema, expected_type_value, expected_type_hint
|
||||
):
|
||||
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='Test parameter',
|
||||
)
|
||||
assert (
|
||||
TypeHintHelper.get_type_value(param.param_schema) == expected_type_value
|
||||
)
|
||||
assert (
|
||||
TypeHintHelper.get_type_hint(param.param_schema) == expected_type_hint
|
||||
)
|
||||
|
||||
|
||||
class TestPydocHelper:
|
||||
|
||||
def test_generate_param_doc_simple(self):
|
||||
schema = Schema(type='string')
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='Test description',
|
||||
)
|
||||
|
||||
expected_doc = 'test_param (str): Test description'
|
||||
assert PydocHelper.generate_param_doc(param) == expected_doc
|
||||
|
||||
def test_generate_param_doc_no_description(self):
|
||||
schema = Schema(type='integer')
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
expected_doc = 'test_param (int): '
|
||||
assert PydocHelper.generate_param_doc(param) == expected_doc
|
||||
|
||||
def test_generate_param_doc_object(self):
|
||||
schema = Schema(
|
||||
type='object',
|
||||
properties={
|
||||
'prop1': {'type': 'string', 'description': 'Prop1 desc'},
|
||||
'prop2': {'type': 'integer'},
|
||||
},
|
||||
)
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='Test object parameter',
|
||||
)
|
||||
expected_doc = (
|
||||
'test_param (Dict[str, Any]): Test object parameter Object'
|
||||
' properties:\n prop1 (str): Prop1 desc\n prop2'
|
||||
' (int): \n'
|
||||
)
|
||||
assert PydocHelper.generate_param_doc(param) == expected_doc
|
||||
|
||||
def test_generate_param_doc_object_no_properties(self):
|
||||
schema = Schema(type='object', description='A test schema')
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='The description.',
|
||||
)
|
||||
expected_doc = 'test_param (Dict[str, Any]): The description.'
|
||||
assert PydocHelper.generate_param_doc(param) == expected_doc
|
||||
|
||||
def test_generate_return_doc_simple(self):
|
||||
responses = {
|
||||
'200': {
|
||||
'description': 'Successful response',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
}
|
||||
}
|
||||
expected_doc = 'Returns (str): Successful response'
|
||||
assert (
|
||||
PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
== expected_doc
|
||||
)
|
||||
|
||||
def test_generate_return_doc_no_content(self):
|
||||
responses = {'204': {'description': 'No content'}}
|
||||
assert not PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
|
||||
def test_generate_return_doc_object(self):
|
||||
responses = {
|
||||
'200': {
|
||||
'description': 'Successful object response',
|
||||
'content': {
|
||||
'application/json': {
|
||||
'schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'prop1': {
|
||||
'type': 'string',
|
||||
'description': 'Prop1 desc',
|
||||
},
|
||||
'prop2': {'type': 'integer'},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return_doc = PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
|
||||
assert 'Returns (Dict[str, Any]): Successful object response' in return_doc
|
||||
assert 'prop1 (str): Prop1 desc' in return_doc
|
||||
assert 'prop2 (int):' in return_doc
|
||||
|
||||
def test_generate_return_doc_multiple_success(self):
|
||||
responses = {
|
||||
'200': {
|
||||
'description': 'Successful response',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
},
|
||||
'400': {'description': 'Bad request'},
|
||||
}
|
||||
expected_doc = 'Returns (str): Successful response'
|
||||
assert (
|
||||
PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
== expected_doc
|
||||
)
|
||||
|
||||
def test_generate_return_doc_2xx_smallest_status_code_response(self):
|
||||
responses = {
|
||||
'201': {
|
||||
'description': '201 response',
|
||||
'content': {'application/json': {'schema': {'type': 'integer'}}},
|
||||
},
|
||||
'200': {
|
||||
'description': '200 response',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
},
|
||||
'400': {'description': 'Bad request'},
|
||||
}
|
||||
|
||||
expected_doc = 'Returns (str): 200 response'
|
||||
assert (
|
||||
PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
== expected_doc
|
||||
)
|
||||
|
||||
def test_generate_return_doc_contentful_response(self):
|
||||
responses = {
|
||||
'200': {'description': 'No content response'},
|
||||
'201': {
|
||||
'description': '201 response',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
},
|
||||
'400': {'description': 'Bad request'},
|
||||
}
|
||||
expected_doc = 'Returns (str): 201 response'
|
||||
assert (
|
||||
PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
== expected_doc
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
1367
tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml
Normal file
1367
tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,628 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
|
||||
import pytest
|
||||
|
||||
|
||||
def create_minimal_openapi_spec() -> Dict[str, Any]:
|
||||
"""Creates a minimal valid OpenAPI spec."""
|
||||
return {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Minimal API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/test": {
|
||||
"get": {
|
||||
"summary": "Test GET endpoint",
|
||||
"operationId": "testGet",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"application/json": {"schema": {"type": "string"}}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_spec_generator():
|
||||
"""Fixture for creating an OperationGenerator instance."""
|
||||
return OpenApiSpecParser()
|
||||
|
||||
|
||||
def test_parse_minimal_spec(openapi_spec_generator):
|
||||
"""Test parsing a minimal OpenAPI specification."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert op.name == "test_get"
|
||||
assert op.endpoint.path == "/test"
|
||||
assert op.endpoint.method == "get"
|
||||
assert op.return_value.type_value == str
|
||||
|
||||
|
||||
def test_parse_spec_with_no_operation_id(openapi_spec_generator):
|
||||
"""Test parsing a spec where operationId is missing (auto-generation)."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
del openapi_spec["paths"]["/test"]["get"]["operationId"] # Remove operationId
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
# Check if operationId is auto generated based on path and method.
|
||||
assert parsed_operations[0].name == "test_get"
|
||||
|
||||
|
||||
def test_parse_spec_with_multiple_methods(openapi_spec_generator):
|
||||
"""Test parsing a spec with multiple HTTP methods for the same path."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["post"] = {
|
||||
"summary": "Test POST endpoint",
|
||||
"operationId": "testPost",
|
||||
"responses": {"200": {"description": "Successful response"}},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
operation_names = {op.name for op in parsed_operations}
|
||||
|
||||
assert len(parsed_operations) == 2
|
||||
assert "test_get" in operation_names
|
||||
assert "test_post" in operation_names
|
||||
|
||||
|
||||
def test_parse_spec_with_parameters(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["get"]["parameters"] = [
|
||||
{"name": "param1", "in": "query", "schema": {"type": "string"}},
|
||||
{"name": "param2", "in": "header", "schema": {"type": "integer"}},
|
||||
]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations[0].parameters) == 2
|
||||
assert parsed_operations[0].parameters[0].original_name == "param1"
|
||||
assert parsed_operations[0].parameters[0].param_location == "query"
|
||||
assert parsed_operations[0].parameters[1].original_name == "param2"
|
||||
assert parsed_operations[0].parameters[1].param_location == "header"
|
||||
|
||||
|
||||
def test_parse_spec_with_request_body(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["post"] = {
|
||||
"summary": "Endpoint with request body",
|
||||
"operationId": "testPostWithBody",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
post_operations = [
|
||||
op for op in parsed_operations if op.endpoint.method == "post"
|
||||
]
|
||||
op = post_operations[0]
|
||||
|
||||
assert len(post_operations) == 1
|
||||
assert op.name == "test_post_with_body"
|
||||
assert len(op.parameters) == 1
|
||||
assert op.parameters[0].original_name == "name"
|
||||
assert op.parameters[0].type_value == str
|
||||
|
||||
|
||||
def test_parse_spec_with_reference(openapi_spec_generator):
|
||||
"""Test parsing a specification with $ref."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "API with Refs", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/test_ref": {
|
||||
"get": {
|
||||
"summary": "Endpoint with ref",
|
||||
"operationId": "testGetRef",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/MySchema"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"MySchema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert op.return_value.type_value.__origin__ is dict
|
||||
|
||||
|
||||
def test_parse_spec_with_circular_reference(openapi_spec_generator):
|
||||
"""Test correct handling of circular $ref (important!)."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Circular Ref API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/circular": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {"$ref": "#/components/schemas/A"}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"A": {
|
||||
"type": "object",
|
||||
"properties": {"b": {"$ref": "#/components/schemas/B"}},
|
||||
},
|
||||
"B": {
|
||||
"type": "object",
|
||||
"properties": {"a": {"$ref": "#/components/schemas/A"}},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 1
|
||||
|
||||
op = parsed_operations[0]
|
||||
assert op.return_value.type_value.__origin__ is dict
|
||||
assert op.return_value.type_hint == "Dict[str, Any]"
|
||||
|
||||
|
||||
def test_parse_no_paths(openapi_spec_generator):
|
||||
"""Test with a spec that has no paths defined."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "No Paths API", "version": "1.0.0"},
|
||||
}
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 0 # Should be empty
|
||||
|
||||
|
||||
def test_parse_empty_path_item(openapi_spec_generator):
|
||||
"""Test a path item that is present but empty."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Empty Path Item API", "version": "1.0.0"},
|
||||
"paths": {"/empty": None},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 0
|
||||
|
||||
|
||||
def test_parse_spec_with_global_auth_scheme(openapi_spec_generator):
|
||||
"""Test parsing with a global security scheme."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["security"] = [{"api_key": []}]
|
||||
openapi_spec["components"] = {
|
||||
"securitySchemes": {
|
||||
"api_key": {"type": "apiKey", "in": "header", "name": "X-API-Key"}
|
||||
}
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert op.auth_scheme is not None
|
||||
assert op.auth_scheme.type_.value == "apiKey"
|
||||
|
||||
|
||||
def test_parse_spec_with_local_auth_scheme(openapi_spec_generator):
|
||||
"""Test parsing with a local (operation-level) security scheme."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["get"]["security"] = [{"local_auth": []}]
|
||||
openapi_spec["components"] = {
|
||||
"securitySchemes": {"local_auth": {"type": "http", "scheme": "bearer"}}
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert op.auth_scheme is not None
|
||||
assert op.auth_scheme.type_.value == "http"
|
||||
assert op.auth_scheme.scheme == "bearer"
|
||||
|
||||
|
||||
def test_parse_spec_with_servers(openapi_spec_generator):
|
||||
"""Test parsing with server URLs."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["servers"] = [
|
||||
{"url": "https://api.example.com"},
|
||||
{"url": "http://localhost:8000"},
|
||||
]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].endpoint.base_url == "https://api.example.com"
|
||||
|
||||
|
||||
def test_parse_spec_with_no_servers(openapi_spec_generator):
|
||||
"""Test with no servers defined (should default to empty string)."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
if "servers" in openapi_spec:
|
||||
del openapi_spec["servers"]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].endpoint.base_url == ""
|
||||
|
||||
|
||||
def test_parse_spec_with_description(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
expected_description = "This is a test description."
|
||||
openapi_spec["paths"]["/test"]["get"]["description"] = expected_description
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].description == expected_description
|
||||
|
||||
|
||||
def test_parse_spec_with_empty_description(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["get"]["description"] = ""
|
||||
openapi_spec["paths"]["/test"]["get"]["summary"] = ""
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].description == ""
|
||||
|
||||
|
||||
def test_parse_spec_with_no_description(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
|
||||
# delete description
|
||||
if "description" in openapi_spec["paths"]["/test"]["get"]:
|
||||
del openapi_spec["paths"]["/test"]["get"]["description"]
|
||||
if "summary" in openapi_spec["paths"]["/test"]["get"]:
|
||||
del openapi_spec["paths"]["/test"]["get"]["summary"]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert (
|
||||
parsed_operations[0].description == ""
|
||||
) # it should be initialized with empty string
|
||||
|
||||
|
||||
def test_parse_invalid_openapi_spec_type(openapi_spec_generator):
|
||||
"""Test that passing a non-dict object to parse raises TypeError"""
|
||||
with pytest.raises(AttributeError):
|
||||
openapi_spec_generator.parse(123) # type: ignore
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
openapi_spec_generator.parse("openapi_spec") # type: ignore
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
openapi_spec_generator.parse([]) # type: ignore
|
||||
|
||||
|
||||
def test_parse_external_ref_raises_error(openapi_spec_generator):
|
||||
"""Check that external references (not starting with #) raise ValueError."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "External Ref API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/external": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
"external_file.json#/components/schemas/ExternalSchema"
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
|
||||
def test_parse_spec_with_multiple_paths_deep_refs(openapi_spec_generator):
|
||||
"""Test specs with multiple paths, request/response bodies using deep refs."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Multiple Paths Deep Refs API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/path1": {
|
||||
"post": {
|
||||
"operationId": "postPath1",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Request1"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Response1"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"/path2": {
|
||||
"put": {
|
||||
"operationId": "putPath2",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Request2"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Response2"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
"get": {
|
||||
"operationId": "getPath2",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Response2"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"Request1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"req1_prop1": {"$ref": "#/components/schemas/Level1_1"}
|
||||
},
|
||||
},
|
||||
"Response1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"res1_prop1": {"$ref": "#/components/schemas/Level1_2"}
|
||||
},
|
||||
},
|
||||
"Request2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"req2_prop1": {"$ref": "#/components/schemas/Level1_1"}
|
||||
},
|
||||
},
|
||||
"Response2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"res2_prop1": {"$ref": "#/components/schemas/Level1_2"}
|
||||
},
|
||||
},
|
||||
"Level1_1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level1_1_prop1": {
|
||||
"$ref": "#/components/schemas/Level2_1"
|
||||
}
|
||||
},
|
||||
},
|
||||
"Level1_2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level1_2_prop1": {
|
||||
"$ref": "#/components/schemas/Level2_2"
|
||||
}
|
||||
},
|
||||
},
|
||||
"Level2_1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level2_1_prop1": {"$ref": "#/components/schemas/Level3"}
|
||||
},
|
||||
},
|
||||
"Level2_2": {
|
||||
"type": "object",
|
||||
"properties": {"level2_2_prop1": {"type": "string"}},
|
||||
},
|
||||
"Level3": {"type": "integer"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 3
|
||||
|
||||
# Verify Path 1
|
||||
path1_ops = [op for op in parsed_operations if op.endpoint.path == "/path1"]
|
||||
assert len(path1_ops) == 1
|
||||
path1_op = path1_ops[0]
|
||||
assert path1_op.name == "post_path1"
|
||||
|
||||
assert len(path1_op.parameters) == 1
|
||||
assert path1_op.parameters[0].original_name == "req1_prop1"
|
||||
assert (
|
||||
path1_op.parameters[0]
|
||||
.param_schema.properties["level1_1_prop1"]
|
||||
.properties["level2_1_prop1"]
|
||||
.type
|
||||
== "integer"
|
||||
)
|
||||
assert (
|
||||
path1_op.return_value.param_schema.properties["res1_prop1"]
|
||||
.properties["level1_2_prop1"]
|
||||
.properties["level2_2_prop1"]
|
||||
.type
|
||||
== "string"
|
||||
)
|
||||
|
||||
# Verify Path 2
|
||||
path2_ops = [
|
||||
op
|
||||
for op in parsed_operations
|
||||
if op.endpoint.path == "/path2" and op.name == "put_path2"
|
||||
]
|
||||
path2_op = path2_ops[0]
|
||||
assert path2_op is not None
|
||||
assert len(path2_op.parameters) == 1
|
||||
assert path2_op.parameters[0].original_name == "req2_prop1"
|
||||
assert (
|
||||
path2_op.parameters[0]
|
||||
.param_schema.properties["level1_1_prop1"]
|
||||
.properties["level2_1_prop1"]
|
||||
.type
|
||||
== "integer"
|
||||
)
|
||||
assert (
|
||||
path2_op.return_value.param_schema.properties["res2_prop1"]
|
||||
.properties["level1_2_prop1"]
|
||||
.properties["level2_2_prop1"]
|
||||
.type
|
||||
== "string"
|
||||
)
|
||||
|
||||
|
||||
def test_parse_spec_with_duplicate_parameter_names(openapi_spec_generator):
|
||||
"""Test handling of duplicate parameter names (one in query, one in body).
|
||||
|
||||
The expected behavior is that both parameters should be captured but with
|
||||
different suffix, and
|
||||
their `original_name` attributes should reflect their origin (query or body).
|
||||
"""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Duplicate Parameter Names API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/duplicate": {
|
||||
"post": {
|
||||
"operationId": "createWithDuplicate",
|
||||
"parameters": [{
|
||||
"name": "name",
|
||||
"in": "query",
|
||||
"schema": {"type": "string"},
|
||||
}],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "integer"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 1
|
||||
op = parsed_operations[0]
|
||||
assert op.name == "create_with_duplicate"
|
||||
assert len(op.parameters) == 2
|
||||
|
||||
query_param = None
|
||||
body_param = None
|
||||
for param in op.parameters:
|
||||
if param.param_location == "query" and param.original_name == "name":
|
||||
query_param = param
|
||||
elif param.param_location == "body" and param.original_name == "name":
|
||||
body_param = param
|
||||
|
||||
assert query_param is not None
|
||||
assert query_param.original_name == "name"
|
||||
assert query_param.py_name == "name"
|
||||
|
||||
assert body_param is not None
|
||||
assert body_param.original_name == "name"
|
||||
assert body_param.py_name == "name_0"
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from fastapi.openapi.models import MediaType
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import ParameterInType
|
||||
from fastapi.openapi.models import SecuritySchemeType
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
|
||||
def load_spec(file_path: str) -> Dict:
|
||||
"""Loads the OpenAPI specification from a YAML file."""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_spec() -> Dict:
|
||||
"""Fixture to load the OpenAPI specification."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# Join the directory path with the filename
|
||||
yaml_path = os.path.join(current_dir, "test.yaml")
|
||||
return load_spec(yaml_path)
|
||||
|
||||
|
||||
def test_openapi_toolset_initialization_from_dict(openapi_spec: Dict):
|
||||
"""Test initialization of OpenAPIToolset with a dictionary."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
assert isinstance(toolset.tools, list)
|
||||
assert len(toolset.tools) == 5
|
||||
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
|
||||
|
||||
|
||||
def test_openapi_toolset_initialization_from_yaml_string(openapi_spec: Dict):
|
||||
"""Test initialization of OpenAPIToolset with a YAML string."""
|
||||
spec_str = yaml.dump(openapi_spec)
|
||||
toolset = OpenAPIToolset(spec_str=spec_str, spec_str_type="yaml")
|
||||
assert isinstance(toolset.tools, list)
|
||||
assert len(toolset.tools) == 5
|
||||
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
|
||||
|
||||
|
||||
def test_openapi_toolset_tool_existing(openapi_spec: Dict):
|
||||
"""Test the tool() method for an existing tool."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
tool_name = "calendar_calendars_insert" # Example operationId from the spec
|
||||
tool = toolset.get_tool(tool_name)
|
||||
assert isinstance(tool, RestApiTool)
|
||||
assert tool.name == tool_name
|
||||
assert tool.description == "Creates a secondary calendar."
|
||||
assert tool.endpoint.method == "post"
|
||||
assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
|
||||
assert tool.endpoint.path == "/calendars"
|
||||
assert tool.is_long_running is False
|
||||
assert tool.operation.operationId == "calendar.calendars.insert"
|
||||
assert tool.operation.description == "Creates a secondary calendar."
|
||||
assert isinstance(
|
||||
tool.operation.requestBody.content["application/json"], MediaType
|
||||
)
|
||||
assert len(tool.operation.responses) == 1
|
||||
response = tool.operation.responses["200"]
|
||||
assert response.description == "Successful response"
|
||||
assert isinstance(response.content["application/json"], MediaType)
|
||||
assert isinstance(tool.auth_scheme, OAuth2)
|
||||
|
||||
tool_name = "calendar_calendars_get"
|
||||
tool = toolset.get_tool(tool_name)
|
||||
assert isinstance(tool, RestApiTool)
|
||||
assert tool.name == tool_name
|
||||
assert tool.description == "Returns metadata for a calendar."
|
||||
assert tool.endpoint.method == "get"
|
||||
assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
|
||||
assert tool.endpoint.path == "/calendars/{calendarId}"
|
||||
assert tool.is_long_running is False
|
||||
assert tool.operation.operationId == "calendar.calendars.get"
|
||||
assert tool.operation.description == "Returns metadata for a calendar."
|
||||
assert len(tool.operation.parameters) == 1
|
||||
assert tool.operation.parameters[0].name == "calendarId"
|
||||
assert tool.operation.parameters[0].in_ == ParameterInType.path
|
||||
assert tool.operation.parameters[0].required is True
|
||||
assert tool.operation.parameters[0].schema_.type == "string"
|
||||
assert (
|
||||
tool.operation.parameters[0].description
|
||||
== "Calendar identifier. To retrieve calendar IDs call the"
|
||||
" calendarList.list method. If you want to access the primary calendar"
|
||||
' of the currently logged in user, use the "primary" keyword.'
|
||||
)
|
||||
assert isinstance(tool.auth_scheme, OAuth2)
|
||||
|
||||
assert isinstance(toolset.get_tool("calendar_calendars_update"), RestApiTool)
|
||||
assert isinstance(toolset.get_tool("calendar_calendars_delete"), RestApiTool)
|
||||
assert isinstance(toolset.get_tool("calendar_calendars_patch"), RestApiTool)
|
||||
|
||||
|
||||
def test_openapi_toolset_tool_non_existing(openapi_spec: Dict):
|
||||
"""Test the tool() method for a non-existing tool."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
tool = toolset.get_tool("non_existent_tool")
|
||||
assert tool is None
|
||||
|
||||
|
||||
def test_openapi_toolset_configure_auth_on_init(openapi_spec: Dict):
|
||||
"""Test configuring auth during initialization."""
|
||||
|
||||
auth_scheme = APIKey(**{
|
||||
"in": APIKeyIn.header, # Use alias name in dict
|
||||
"name": "api_key",
|
||||
"type": SecuritySchemeType.http,
|
||||
})
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
|
||||
toolset = OpenAPIToolset(
|
||||
spec_dict=openapi_spec,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
for tool in toolset.tools:
|
||||
assert tool.auth_scheme == auth_scheme
|
||||
assert tool.auth_credential == auth_credential
|
||||
@@ -0,0 +1,406 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fastapi.openapi.models import MediaType
|
||||
from fastapi.openapi.models import Operation
|
||||
from fastapi.openapi.models import Parameter
|
||||
from fastapi.openapi.models import RequestBody
|
||||
from fastapi.openapi.models import Response
|
||||
from fastapi.openapi.models import Schema
|
||||
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_operation() -> Operation:
|
||||
"""Fixture to provide a sample OpenAPI Operation object."""
|
||||
return Operation(
|
||||
operationId='test_operation',
|
||||
summary='Test Summary',
|
||||
description='Test Description',
|
||||
parameters=[
|
||||
Parameter(**{
|
||||
'name': 'param1',
|
||||
'in': 'query',
|
||||
'schema': Schema(type='string'),
|
||||
'description': 'Parameter 1',
|
||||
}),
|
||||
Parameter(**{
|
||||
'name': 'param2',
|
||||
'in': 'header',
|
||||
'schema': Schema(type='string'),
|
||||
'description': 'Parameter 2',
|
||||
}),
|
||||
],
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
'application/json': MediaType(
|
||||
schema=Schema(
|
||||
type='object',
|
||||
properties={
|
||||
'prop1': Schema(
|
||||
type='string', description='Property 1'
|
||||
),
|
||||
'prop2': Schema(
|
||||
type='integer', description='Property 2'
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
},
|
||||
description='Request body description',
|
||||
),
|
||||
responses={
|
||||
'200': Response(
|
||||
description='Success',
|
||||
content={
|
||||
'application/json': MediaType(schema=Schema(type='string'))
|
||||
},
|
||||
),
|
||||
'400': Response(description='Client Error'),
|
||||
},
|
||||
security=[{'oauth2': ['resource: read', 'resource: write']}],
|
||||
)
|
||||
|
||||
|
||||
def test_operation_parser_initialization(sample_operation):
|
||||
"""Test initialization of OperationParser."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.operation == sample_operation
|
||||
assert len(parser.params) == 4 # 2 params + 2 request body props
|
||||
assert parser.return_value is not None
|
||||
|
||||
|
||||
def test_process_operation_parameters(sample_operation):
|
||||
"""Test _process_operation_parameters method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
parser._process_operation_parameters()
|
||||
assert len(parser.params) == 2
|
||||
assert parser.params[0].original_name == 'param1'
|
||||
assert parser.params[0].param_location == 'query'
|
||||
assert parser.params[1].original_name == 'param2'
|
||||
assert parser.params[1].param_location == 'header'
|
||||
|
||||
|
||||
def test_process_request_body(sample_operation):
|
||||
"""Test _process_request_body method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
parser._process_request_body()
|
||||
assert len(parser.params) == 2 # 2 properties in request body
|
||||
assert parser.params[0].original_name == 'prop1'
|
||||
assert parser.params[0].param_location == 'body'
|
||||
assert parser.params[1].original_name == 'prop2'
|
||||
assert parser.params[1].param_location == 'body'
|
||||
|
||||
|
||||
def test_process_request_body_array():
|
||||
"""Test _process_request_body method with array schema."""
|
||||
operation = Operation(
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
'application/json': MediaType(
|
||||
schema=Schema(
|
||||
type='array',
|
||||
items=Schema(
|
||||
type='object',
|
||||
properties={
|
||||
'item_prop1': Schema(
|
||||
type='string', description='Item Property 1'
|
||||
),
|
||||
'item_prop2': Schema(
|
||||
type='integer', description='Item Property 2'
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
parser = OperationParser(operation, should_parse=False)
|
||||
parser._process_request_body()
|
||||
assert len(parser.params) == 1
|
||||
assert parser.params[0].original_name == 'array'
|
||||
assert parser.params[0].param_location == 'body'
|
||||
# Check that schema is correctly propagated and is a dictionary
|
||||
assert parser.params[0].param_schema.type == 'array'
|
||||
assert parser.params[0].param_schema.items.type == 'object'
|
||||
assert 'item_prop1' in parser.params[0].param_schema.items.properties
|
||||
assert 'item_prop2' in parser.params[0].param_schema.items.properties
|
||||
assert (
|
||||
parser.params[0].param_schema.items.properties['item_prop1'].description
|
||||
== 'Item Property 1'
|
||||
)
|
||||
assert (
|
||||
parser.params[0].param_schema.items.properties['item_prop2'].description
|
||||
== 'Item Property 2'
|
||||
)
|
||||
|
||||
|
||||
def test_process_request_body_no_name():
|
||||
"""Test _process_request_body with a schema that has no properties (unnamed)"""
|
||||
operation = Operation(
|
||||
requestBody=RequestBody(
|
||||
content={'application/json': MediaType(schema=Schema(type='string'))}
|
||||
)
|
||||
)
|
||||
parser = OperationParser(operation, should_parse=False)
|
||||
parser._process_request_body()
|
||||
assert len(parser.params) == 1
|
||||
assert parser.params[0].original_name == '' # No name
|
||||
assert parser.params[0].param_location == 'body'
|
||||
|
||||
|
||||
def test_dedupe_param_names(sample_operation):
|
||||
"""Test _dedupe_param_names method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
# Add duplicate named parameters.
|
||||
parser.params = [
|
||||
ApiParameter(original_name='test', param_location='', param_schema={}),
|
||||
ApiParameter(original_name='test', param_location='', param_schema={}),
|
||||
ApiParameter(original_name='test', param_location='', param_schema={}),
|
||||
]
|
||||
parser._dedupe_param_names()
|
||||
assert parser.params[0].py_name == 'test'
|
||||
assert parser.params[1].py_name == 'test_0'
|
||||
assert parser.params[2].py_name == 'test_1'
|
||||
|
||||
|
||||
def test_process_return_value(sample_operation):
|
||||
"""Test _process_return_value method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value is not None
|
||||
assert parser.return_value.type_hint == 'str'
|
||||
|
||||
|
||||
def test_process_return_value_no_2xx(sample_operation):
|
||||
"""Tests _process_return_value when no 2xx response exists."""
|
||||
operation_no_2xx = Operation(
|
||||
responses={'400': Response(description='Client Error')}
|
||||
)
|
||||
parser = OperationParser(operation_no_2xx, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value is not None
|
||||
assert parser.return_value.type_hint == 'Any'
|
||||
|
||||
|
||||
def test_process_return_value_multiple_2xx(sample_operation):
|
||||
"""Tests _process_return_value when multiple 2xx responses exist."""
|
||||
operation_multi_2xx = Operation(
|
||||
responses={
|
||||
'201': Response(
|
||||
description='Success',
|
||||
content={
|
||||
'application/json': MediaType(schema=Schema(type='integer'))
|
||||
},
|
||||
),
|
||||
'202': Response(
|
||||
description='Success',
|
||||
content={'text/plain': MediaType(schema=Schema(type='string'))},
|
||||
),
|
||||
'200': Response(
|
||||
description='Success',
|
||||
content={
|
||||
'application/pdf': MediaType(schema=Schema(type='boolean'))
|
||||
},
|
||||
),
|
||||
'400': Response(
|
||||
description='Failure',
|
||||
content={
|
||||
'application/xml': MediaType(schema=Schema(type='object'))
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
parser = OperationParser(operation_multi_2xx, should_parse=False)
|
||||
parser._process_return_value()
|
||||
|
||||
assert parser.return_value is not None
|
||||
# Take the content type of the 200 response since it's the smallest response
|
||||
# code
|
||||
assert parser.return_value.param_schema.type == 'boolean'
|
||||
|
||||
|
||||
def test_process_return_value_no_content(sample_operation):
|
||||
"""Test when 2xx response has no content"""
|
||||
operation_no_content = Operation(
|
||||
responses={'200': Response(description='Success', content={})}
|
||||
)
|
||||
parser = OperationParser(operation_no_content, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value.type_hint == 'Any'
|
||||
|
||||
|
||||
def test_process_return_value_no_schema(sample_operation):
|
||||
"""Tests when the 2xx response's content has no schema."""
|
||||
operation_no_schema = Operation(
|
||||
responses={
|
||||
'200': Response(
|
||||
description='Success',
|
||||
content={'application/json': MediaType(schema=None)},
|
||||
)
|
||||
}
|
||||
)
|
||||
parser = OperationParser(operation_no_schema, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value.type_hint == 'Any'
|
||||
|
||||
|
||||
def test_get_function_name(sample_operation):
|
||||
"""Test get_function_name method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_function_name() == 'test_operation'
|
||||
|
||||
|
||||
def test_get_function_name_missing_id():
|
||||
"""Tests get_function_name when operationId is missing"""
|
||||
operation = Operation() # No ID
|
||||
parser = OperationParser(operation)
|
||||
with pytest.raises(ValueError, match='Operation ID is missing'):
|
||||
parser.get_function_name()
|
||||
|
||||
|
||||
def test_get_return_type_hint(sample_operation):
|
||||
"""Test get_return_type_hint method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_return_type_hint() == 'str'
|
||||
|
||||
|
||||
def test_get_return_type_value(sample_operation):
|
||||
"""Test get_return_type_value method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_return_type_value() == str
|
||||
|
||||
|
||||
def test_get_parameters(sample_operation):
|
||||
"""Test get_parameters method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
params = parser.get_parameters()
|
||||
assert len(params) == 4 # Correct count after processing
|
||||
assert all(isinstance(p, ApiParameter) for p in params)
|
||||
|
||||
|
||||
def test_get_return_value(sample_operation):
|
||||
"""Test get_return_value method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
return_value = parser.get_return_value()
|
||||
assert isinstance(return_value, ApiParameter)
|
||||
|
||||
|
||||
def test_get_auth_scheme_name(sample_operation):
|
||||
"""Test get_auth_scheme_name method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_auth_scheme_name() == 'oauth2'
|
||||
|
||||
|
||||
def test_get_auth_scheme_name_no_security():
|
||||
"""Test get_auth_scheme_name when no security is present."""
|
||||
operation = Operation(responses={})
|
||||
parser = OperationParser(operation)
|
||||
assert parser.get_auth_scheme_name() == ''
|
||||
|
||||
|
||||
def test_get_pydoc_string(sample_operation):
|
||||
"""Test get_pydoc_string method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
pydoc_string = parser.get_pydoc_string()
|
||||
assert 'Test Summary' in pydoc_string
|
||||
assert 'Args:' in pydoc_string
|
||||
assert 'param1 (str): Parameter 1' in pydoc_string
|
||||
assert 'prop1 (str): Property 1' in pydoc_string
|
||||
assert 'Returns (str):' in pydoc_string
|
||||
assert 'Success' in pydoc_string
|
||||
|
||||
|
||||
def test_get_json_schema(sample_operation):
|
||||
"""Test get_json_schema method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
json_schema = parser.get_json_schema()
|
||||
assert json_schema['title'] == 'test_operation_Arguments'
|
||||
assert json_schema['type'] == 'object'
|
||||
assert 'param1' in json_schema['properties']
|
||||
assert 'prop1' in json_schema['properties']
|
||||
assert 'param1' in json_schema['required']
|
||||
assert 'prop1' in json_schema['required']
|
||||
|
||||
|
||||
def test_get_signature_parameters(sample_operation):
|
||||
"""Test get_signature_parameters method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
signature_params = parser.get_signature_parameters()
|
||||
assert len(signature_params) == 4
|
||||
assert signature_params[0].name == 'param1'
|
||||
assert signature_params[0].annotation == str
|
||||
assert signature_params[2].name == 'prop1'
|
||||
assert signature_params[2].annotation == str
|
||||
|
||||
|
||||
def test_get_annotations(sample_operation):
|
||||
"""Test get_annotations method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
annotations = parser.get_annotations()
|
||||
assert len(annotations) == 5 # 4 parameters + return
|
||||
assert annotations['param1'] == str
|
||||
assert annotations['prop1'] == str
|
||||
assert annotations['return'] == str
|
||||
|
||||
|
||||
def test_load():
|
||||
"""Test the load classmethod."""
|
||||
operation = Operation(operationId='my_op') # Minimal operation
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name='p1',
|
||||
param_location='',
|
||||
param_schema={'type': 'integer'},
|
||||
)
|
||||
]
|
||||
return_value = ApiParameter(
|
||||
original_name='', param_location='', param_schema={'type': 'string'}
|
||||
)
|
||||
|
||||
parser = OperationParser.load(operation, params, return_value)
|
||||
|
||||
assert isinstance(parser, OperationParser)
|
||||
assert parser.operation == operation
|
||||
assert parser.params == params
|
||||
assert parser.return_value == return_value
|
||||
assert (
|
||||
parser.get_function_name() == 'my_op'
|
||||
) # Check that the operation is loaded
|
||||
|
||||
|
||||
def test_operation_parser_with_dict():
|
||||
"""Test initialization of OperationParser with a dictionary."""
|
||||
operation_dict = {
|
||||
'operationId': 'test_dict_operation',
|
||||
'parameters': [
|
||||
{'name': 'dict_param', 'in': 'query', 'schema': {'type': 'string'}}
|
||||
],
|
||||
'responses': {
|
||||
'200': {
|
||||
'description': 'Dict Success',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
parser = OperationParser(operation_dict)
|
||||
assert parser.operation.operationId == 'test_dict_operation'
|
||||
assert len(parser.params) == 1
|
||||
assert parser.params[0].original_name == 'dict_param'
|
||||
assert parser.return_value.type_hint == 'str'
|
||||
@@ -0,0 +1,966 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.openapi.models import MediaType
|
||||
from fastapi.openapi.models import Operation
|
||||
from fastapi.openapi.models import Parameter as OpenAPIParameter
|
||||
from fastapi.openapi.models import RequestBody
|
||||
from fastapi.openapi.models import Schema as OpenAPISchema
|
||||
from google.adk.sessions.state import State
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import snake_to_lower_camel
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from google.genai.types import Schema
|
||||
from google.genai.types import Type
|
||||
import pytest
|
||||
|
||||
|
||||
class TestRestApiTool:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_context(self):
|
||||
"""Fixture for a mock OperationParser."""
|
||||
mock_context = MagicMock(spec=ToolContext)
|
||||
mock_context.state = State({}, {})
|
||||
mock_context.get_auth_response.return_value = {}
|
||||
mock_context.request_credential.return_value = {}
|
||||
return mock_context
|
||||
|
||||
@pytest.fixture
|
||||
def mock_operation_parser(self):
|
||||
"""Fixture for a mock OperationParser."""
|
||||
mock_parser = MagicMock(spec=OperationParser)
|
||||
mock_parser.get_function_name.return_value = "mock_function_name"
|
||||
mock_parser.get_json_schema.return_value = {}
|
||||
mock_parser.get_parameters.return_value = []
|
||||
mock_parser.get_return_type_hint.return_value = "str"
|
||||
mock_parser.get_pydoc_string.return_value = "Mock docstring"
|
||||
mock_parser.get_signature_parameters.return_value = []
|
||||
mock_parser.get_return_type_value.return_value = str
|
||||
mock_parser.get_annotations.return_value = {}
|
||||
return mock_parser
|
||||
|
||||
@pytest.fixture
|
||||
def sample_endpiont(self):
|
||||
return OperationEndpoint(
|
||||
base_url="https://example.com", path="/test", method="GET"
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_operation(self):
|
||||
return Operation(
|
||||
operationId="testOperation",
|
||||
description="Test operation",
|
||||
parameters=[],
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={
|
||||
"testBodyParam": OpenAPISchema(type="string")
|
||||
},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_api_parameters(self):
|
||||
return [
|
||||
ApiParameter(
|
||||
original_name="test_param",
|
||||
py_name="test_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
is_required=True,
|
||||
),
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="test_body_param",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
is_required=True,
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_return_parameter(self):
|
||||
return ApiParameter(
|
||||
original_name="test_param",
|
||||
py_name="test_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
is_required=True,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_auth_scheme(self):
|
||||
scheme, _ = token_to_scheme_credential(
|
||||
"apikey", "header", "", "sample_auth_credential_internal_test"
|
||||
)
|
||||
return scheme
|
||||
|
||||
@pytest.fixture
|
||||
def sample_auth_credential(self):
|
||||
_, credential = token_to_scheme_credential(
|
||||
"apikey", "header", "", "sample_auth_credential_internal_test"
|
||||
)
|
||||
return credential
|
||||
|
||||
def test_init(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "Test Tool"
|
||||
assert tool.endpoint == sample_endpiont
|
||||
assert tool.operation == sample_operation
|
||||
assert tool.auth_credential == sample_auth_credential
|
||||
assert tool.auth_scheme == sample_auth_scheme
|
||||
assert tool.credential_exchanger is not None
|
||||
|
||||
def test_from_parsed_operation_str(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_api_parameters,
|
||||
sample_return_parameter,
|
||||
sample_operation,
|
||||
):
|
||||
parsed_operation_str = json.dumps({
|
||||
"name": "test_operation",
|
||||
"description": "Test Description",
|
||||
"endpoint": sample_endpiont.model_dump(),
|
||||
"operation": sample_operation.model_dump(),
|
||||
"auth_scheme": None,
|
||||
"auth_credential": None,
|
||||
"parameters": [p.model_dump() for p in sample_api_parameters],
|
||||
"return_value": sample_return_parameter.model_dump(),
|
||||
})
|
||||
|
||||
tool = RestApiTool.from_parsed_operation_str(parsed_operation_str)
|
||||
assert tool.name == "test_operation"
|
||||
|
||||
def test_get_declaration(
|
||||
self, sample_endpiont, sample_operation, mock_operation_parser
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test description",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
should_parse_operation=False,
|
||||
)
|
||||
tool._operation_parser = mock_operation_parser
|
||||
|
||||
declaration = tool._get_declaration()
|
||||
assert isinstance(declaration, FunctionDeclaration)
|
||||
assert declaration.name == "test_tool"
|
||||
assert declaration.description == "Test description"
|
||||
assert isinstance(declaration.parameters, Schema)
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
|
||||
)
|
||||
def test_call_success(
|
||||
self,
|
||||
mock_request,
|
||||
mock_tool_context,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
|
||||
# Call the method
|
||||
result = tool.call(args={}, tool_context=mock_tool_context)
|
||||
|
||||
# Check the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
|
||||
)
|
||||
def test_call_auth_pending(
|
||||
self,
|
||||
mock_request,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
with patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
|
||||
) as mock_from_tool_context:
|
||||
mock_tool_auth_handler_instance = MagicMock()
|
||||
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
|
||||
"pending"
|
||||
)
|
||||
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
|
||||
|
||||
response = tool.call(args={}, tool_context=None)
|
||||
assert response == {
|
||||
"pending": True,
|
||||
"message": "Needs your authorization to access your data.",
|
||||
}
|
||||
|
||||
def test_prepare_request_params_query_body(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
# Create a mock Operation object
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
parameters=[
|
||||
OpenAPIParameter(**{
|
||||
"name": "testQueryParam",
|
||||
"in": "query",
|
||||
"schema": OpenAPISchema(type="string"),
|
||||
})
|
||||
],
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={
|
||||
"param1": OpenAPISchema(type="string"),
|
||||
"param2": OpenAPISchema(type="integer"),
|
||||
},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="param1",
|
||||
py_name="param1",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
),
|
||||
ApiParameter(
|
||||
original_name="param2",
|
||||
py_name="param2",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="integer"),
|
||||
),
|
||||
ApiParameter(
|
||||
original_name="testQueryParam",
|
||||
py_name="test_query_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
),
|
||||
]
|
||||
kwargs = {
|
||||
"param1": "value1",
|
||||
"param2": 123,
|
||||
"test_query_param": "query_value",
|
||||
}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
assert request_params["method"] == "get"
|
||||
assert request_params["url"] == "https://example.com/test"
|
||||
assert request_params["json"] == {"param1": "value1", "param2": 123}
|
||||
assert request_params["params"] == {"testQueryParam": "query_value"}
|
||||
|
||||
def test_prepare_request_params_array(
|
||||
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="array", items=OpenAPISchema(type="string")
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="array", # Match the parameter name
|
||||
py_name="array",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(
|
||||
type="array", items=OpenAPISchema(type="string")
|
||||
),
|
||||
)
|
||||
]
|
||||
kwargs = {"array": ["item1", "item2"]}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["json"] == ["item1", "item2"]
|
||||
|
||||
def test_prepare_request_params_string(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"text/plain": MediaType(schema=OpenAPISchema(type="string"))
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="input_string",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"input_string": "test_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["data"] == "test_value"
|
||||
assert request_params["headers"]["Content-Type"] == "text/plain"
|
||||
|
||||
def test_prepare_request_params_form_data(
|
||||
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/x-www-form-urlencoded": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={"key1": OpenAPISchema(type="string")},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="key1",
|
||||
py_name="key1",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"key1": "value1"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["data"] == {"key1": "value1"}
|
||||
assert (
|
||||
request_params["headers"]["Content-Type"]
|
||||
== "application/x-www-form-urlencoded"
|
||||
)
|
||||
|
||||
def test_prepare_request_params_multipart(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"multipart/form-data": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={
|
||||
"file1": OpenAPISchema(
|
||||
type="string", format="binary"
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="file1",
|
||||
py_name="file1",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string", format="binary"),
|
||||
)
|
||||
]
|
||||
kwargs = {"file1": b"file_content"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["files"] == {"file1": b"file_content"}
|
||||
assert request_params["headers"]["Content-Type"] == "multipart/form-data"
|
||||
|
||||
def test_prepare_request_params_octet_stream(
|
||||
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/octet-stream": MediaType(
|
||||
schema=OpenAPISchema(type="string", format="binary")
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="data",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string", format="binary"),
|
||||
)
|
||||
]
|
||||
kwargs = {"data": b"binary_data"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["data"] == b"binary_data"
|
||||
assert (
|
||||
request_params["headers"]["Content-Type"] == "application/octet-stream"
|
||||
)
|
||||
|
||||
def test_prepare_request_params_path_param(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
mock_operation = Operation(operationId="test_op")
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="user_id",
|
||||
py_name="user_id",
|
||||
param_location="path",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"user_id": "123"}
|
||||
endpoint_with_path = OperationEndpoint(
|
||||
base_url="https://example.com", path="/test/{user_id}", method="get"
|
||||
)
|
||||
tool.endpoint = endpoint_with_path
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert (
|
||||
request_params["url"] == "https://example.com/test/123"
|
||||
) # Path param replaced
|
||||
|
||||
def test_prepare_request_params_header_param(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="X-Custom-Header",
|
||||
py_name="x_custom_header",
|
||||
param_location="header",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"x_custom_header": "header_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["headers"]["X-Custom-Header"] == "header_value"
|
||||
|
||||
def test_prepare_request_params_cookie_param(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="session_id",
|
||||
py_name="session_id",
|
||||
param_location="cookie",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"session_id": "cookie_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["cookies"]["session_id"] == "cookie_value"
|
||||
|
||||
def test_prepare_request_params_multiple_mime_types(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
# Test what happens when multiple mime types are specified. It should take
|
||||
# the first one.
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(type="string")
|
||||
),
|
||||
"text/plain": MediaType(schema=OpenAPISchema(type="string")),
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="input",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"input": "some_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["headers"]["Content-Type"] == "application/json"
|
||||
|
||||
def test_prepare_request_params_unknown_parameter(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="known_param",
|
||||
py_name="known_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"known_param": "value", "unknown_param": "unknown"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
# Make sure unknown parameters are ignored and do not raise errors.
|
||||
assert "unknown_param" not in request_params["params"]
|
||||
|
||||
def test_prepare_request_params_base_url_handling(
|
||||
self, sample_auth_credential, sample_auth_scheme, sample_operation
|
||||
):
|
||||
# No base_url provided, should use path as is
|
||||
tool_no_base = RestApiTool(
|
||||
name="test_tool_no_base",
|
||||
description="Test Tool",
|
||||
endpoint=OperationEndpoint(base_url="", path="/no_base", method="get"),
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = []
|
||||
kwargs = {}
|
||||
|
||||
request_params_no_base = tool_no_base._prepare_request_params(
|
||||
params, kwargs
|
||||
)
|
||||
assert request_params_no_base["url"] == "/no_base"
|
||||
|
||||
tool_trailing_slash = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=OperationEndpoint(
|
||||
base_url="https://example.com/", path="/trailing", method="get"
|
||||
),
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
|
||||
request_params_trailing = tool_trailing_slash._prepare_request_params(
|
||||
params, kwargs
|
||||
)
|
||||
assert request_params_trailing["url"] == "https://example.com/trailing"
|
||||
|
||||
def test_prepare_request_params_no_unrecognized_query_parameter(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="unrecognized_param",
|
||||
py_name="unrecognized_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"unrecognized_param": None} # Explicitly passing None
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
# Query param not in sample_operation. It should be ignored.
|
||||
assert "unrecognized_param" not in request_params["params"]
|
||||
|
||||
def test_prepare_request_params_no_credential(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=None,
|
||||
auth_scheme=None,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="param_name",
|
||||
py_name="param_name",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"param_name": "aaa", "empty_param": ""}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert "param_name" in request_params["params"]
|
||||
assert "empty_param" not in request_params["params"]
|
||||
|
||||
|
||||
class TestToGeminiSchema:
|
||||
|
||||
def test_to_gemini_schema_none(self):
|
||||
assert to_gemini_schema(None) is None
|
||||
|
||||
def test_to_gemini_schema_not_dict(self):
|
||||
with pytest.raises(TypeError, match="openapi_schema must be a dictionary"):
|
||||
to_gemini_schema("not a dict")
|
||||
|
||||
def test_to_gemini_schema_empty_dict(self):
|
||||
result = to_gemini_schema({})
|
||||
assert isinstance(result, Schema)
|
||||
assert result.type == Type.OBJECT
|
||||
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
|
||||
|
||||
def test_to_gemini_schema_dict_with_only_object_type(self):
|
||||
result = to_gemini_schema({"type": "object"})
|
||||
assert isinstance(result, Schema)
|
||||
assert result.type == Type.OBJECT
|
||||
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
|
||||
|
||||
def test_to_gemini_schema_basic_types(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"is_active": {"type": "boolean"},
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert isinstance(gemini_schema, Schema)
|
||||
assert gemini_schema.type == Type.OBJECT
|
||||
assert gemini_schema.properties["name"].type == Type.STRING
|
||||
assert gemini_schema.properties["age"].type == Type.INTEGER
|
||||
assert gemini_schema.properties["is_active"].type == Type.BOOLEAN
|
||||
|
||||
def test_to_gemini_schema_nested_objects(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"street": {"type": "string"},
|
||||
"city": {"type": "string"},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.properties["address"].type == Type.OBJECT
|
||||
assert (
|
||||
gemini_schema.properties["address"].properties["street"].type
|
||||
== Type.STRING
|
||||
)
|
||||
assert (
|
||||
gemini_schema.properties["address"].properties["city"].type
|
||||
== Type.STRING
|
||||
)
|
||||
|
||||
def test_to_gemini_schema_array(self):
|
||||
openapi_schema = {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.type == Type.ARRAY
|
||||
assert gemini_schema.items.type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_nested_array(self):
|
||||
openapi_schema = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.items.properties["name"].type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_any_of(self):
|
||||
openapi_schema = {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}],
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert len(gemini_schema.any_of) == 2
|
||||
assert gemini_schema.any_of[0].type == Type.STRING
|
||||
assert gemini_schema.any_of[1].type == Type.INTEGER
|
||||
|
||||
def test_to_gemini_schema_general_list(self):
|
||||
openapi_schema = {
|
||||
"type": "array",
|
||||
"properties": {
|
||||
"list_field": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.properties["list_field"].type == Type.ARRAY
|
||||
assert gemini_schema.properties["list_field"].items.type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_enum(self):
|
||||
openapi_schema = {"type": "string", "enum": ["a", "b", "c"]}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.enum == ["a", "b", "c"]
|
||||
|
||||
def test_to_gemini_schema_required(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"required": ["name"],
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.required == ["name"]
|
||||
|
||||
def test_to_gemini_schema_nested_dict(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"properties": {"metadata": {"key1": "value1", "key2": 123}},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
# Since metadata is not properties nor item, it will call to_gemini_schema recursively.
|
||||
assert isinstance(gemini_schema.properties["metadata"], Schema)
|
||||
assert (
|
||||
gemini_schema.properties["metadata"].type == Type.OBJECT
|
||||
) # add object type by default
|
||||
assert gemini_schema.properties["metadata"].properties == {
|
||||
"dummy_DO_NOT_GENERATE": Schema(type="string")
|
||||
}
|
||||
|
||||
def test_to_gemini_schema_ignore_title_default_format(self):
|
||||
openapi_schema = {
|
||||
"type": "string",
|
||||
"title": "Test Title",
|
||||
"default": "default_value",
|
||||
"format": "date",
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
|
||||
assert gemini_schema.title is None
|
||||
assert gemini_schema.default is None
|
||||
assert gemini_schema.format is None
|
||||
|
||||
def test_to_gemini_schema_property_ordering(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"propertyOrdering": ["name", "age"],
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.property_ordering == ["name", "age"]
|
||||
|
||||
def test_to_gemini_schema_converts_property_dict(self):
|
||||
openapi_schema = {
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "The property key"},
|
||||
"value": {"type": "string", "description": "The property value"},
|
||||
},
|
||||
"type": "object",
|
||||
"description": "A single property entry in the Properties message.",
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.type == Type.OBJECT
|
||||
assert gemini_schema.properties["name"].type == Type.STRING
|
||||
assert gemini_schema.properties["value"].type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_remove_unrecognized_fields(self):
|
||||
openapi_schema = {
|
||||
"type": "string",
|
||||
"description": "A single date string.",
|
||||
"format": "date",
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.type == Type.STRING
|
||||
assert not gemini_schema.format
|
||||
|
||||
|
||||
def test_snake_to_lower_camel():
|
||||
assert snake_to_lower_camel("single") == "single"
|
||||
assert snake_to_lower_camel("two_words") == "twoWords"
|
||||
assert snake_to_lower_camel("three_word_example") == "threeWordExample"
|
||||
assert not snake_to_lower_camel("")
|
||||
assert snake_to_lower_camel("alreadyCamelCase") == "alreadyCamelCase"
|
||||
@@ -0,0 +1,201 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import OAuth2CredentialExchanger
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolContextCredentialStore
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
import pytest
|
||||
|
||||
|
||||
# Helper function to create a mock ToolContext
|
||||
def create_mock_tool_context():
|
||||
return ToolContext(
|
||||
function_call_id='test-fc-id',
|
||||
invocation_context=InvocationContext(
|
||||
agent=LlmAgent(name='test'),
|
||||
session=Session(app_name='test', user_id='123', id='123'),
|
||||
invocation_id='123',
|
||||
session_service=InMemorySessionService(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Test cases for OpenID Connect
|
||||
class MockOpenIdConnectCredentialExchanger(OAuth2CredentialExchanger):
|
||||
|
||||
def __init__(
|
||||
self, expected_scheme, expected_credential, expected_access_token
|
||||
):
|
||||
self.expected_scheme = expected_scheme
|
||||
self.expected_credential = expected_credential
|
||||
self.expected_access_token = expected_access_token
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
if auth_credential.oauth2 and (
|
||||
auth_credential.oauth2.auth_response_uri
|
||||
or auth_credential.oauth2.auth_code
|
||||
):
|
||||
auth_code = (
|
||||
auth_credential.oauth2.auth_response_uri
|
||||
if auth_credential.oauth2.auth_response_uri
|
||||
else auth_credential.oauth2.auth_code
|
||||
)
|
||||
# Simulate the token exchange
|
||||
updated_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
|
||||
http=HttpAuth(
|
||||
scheme='bearer',
|
||||
credentials=HttpCredentials(
|
||||
token=auth_code + self.expected_access_token
|
||||
),
|
||||
),
|
||||
)
|
||||
return updated_credential
|
||||
|
||||
# simulate the case of getting auth_uri
|
||||
return None
|
||||
|
||||
|
||||
def get_mock_openid_scheme_credential():
|
||||
config_dict = {
|
||||
'authorization_endpoint': 'test.com',
|
||||
'token_endpoint': 'test.com',
|
||||
}
|
||||
scopes = ['test_scope']
|
||||
credential_dict = {
|
||||
'client_id': '123',
|
||||
'client_secret': '456',
|
||||
'redirect_uri': 'test.com',
|
||||
}
|
||||
return openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
|
||||
|
||||
|
||||
# Fixture for the OpenID Connect security scheme
|
||||
@pytest.fixture
|
||||
def openid_connect_scheme():
|
||||
scheme, _ = get_mock_openid_scheme_credential()
|
||||
return scheme
|
||||
|
||||
|
||||
# Fixture for a base OpenID Connect credential
|
||||
@pytest.fixture
|
||||
def openid_connect_credential():
|
||||
_, credential = get_mock_openid_scheme_credential()
|
||||
return credential
|
||||
|
||||
|
||||
def test_openid_connect_no_auth_response(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
):
|
||||
# Setup Mock exchanger
|
||||
mock_exchanger = MockOpenIdConnectCredentialExchanger(
|
||||
openid_connect_scheme, openid_connect_credential, None
|
||||
)
|
||||
tool_context = create_mock_tool_context()
|
||||
credential_store = ToolContextCredentialStore(tool_context=tool_context)
|
||||
handler = ToolAuthHandler(
|
||||
tool_context,
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
credential_exchanger=mock_exchanger,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
result = handler.prepare_auth_credentials()
|
||||
assert result.state == 'pending'
|
||||
assert result.auth_credential == openid_connect_credential
|
||||
|
||||
|
||||
def test_openid_connect_with_auth_response(
|
||||
openid_connect_scheme, openid_connect_credential, monkeypatch
|
||||
):
|
||||
mock_exchanger = MockOpenIdConnectCredentialExchanger(
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
'test_access_token',
|
||||
)
|
||||
tool_context = create_mock_tool_context()
|
||||
|
||||
mock_auth_handler = MagicMock()
|
||||
mock_auth_handler.get_auth_response.return_value = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
|
||||
oauth2=OAuth2Auth(auth_response_uri='test_auth_response_uri'),
|
||||
)
|
||||
mock_auth_handler_path = 'google.adk.tools.tool_context.AuthHandler'
|
||||
monkeypatch.setattr(
|
||||
mock_auth_handler_path, lambda *args, **kwargs: mock_auth_handler
|
||||
)
|
||||
|
||||
credential_store = ToolContextCredentialStore(tool_context=tool_context)
|
||||
handler = ToolAuthHandler(
|
||||
tool_context,
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
credential_exchanger=mock_exchanger,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
result = handler.prepare_auth_credentials()
|
||||
assert result.state == 'done'
|
||||
assert result.auth_credential.auth_type == AuthCredentialTypes.HTTP
|
||||
assert 'test_access_token' in result.auth_credential.http.credentials.token
|
||||
# Verify that the credential was stored:
|
||||
stored_credential = credential_store.get_credential(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
)
|
||||
assert stored_credential == result.auth_credential
|
||||
mock_auth_handler.get_auth_response.assert_called_once()
|
||||
|
||||
|
||||
def test_openid_connect_existing_token(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
):
|
||||
_, existing_credential = token_to_scheme_credential(
|
||||
'oauth2Token', 'header', 'bearer', '123123123'
|
||||
)
|
||||
tool_context = create_mock_tool_context()
|
||||
# Store the credential to simulate existing credential
|
||||
credential_store = ToolContextCredentialStore(tool_context=tool_context)
|
||||
key = credential_store.get_credential_key(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
)
|
||||
credential_store.store_credential(key, existing_credential)
|
||||
|
||||
handler = ToolAuthHandler(
|
||||
tool_context,
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
result = handler.prepare_auth_credentials()
|
||||
assert result.state == 'done'
|
||||
assert result.auth_credential == existing_credential
|
||||
Reference in New Issue
Block a user