mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 11:22:22 -06:00
Moves unittests to root folder and adds github action to run unit tests. (#72)
* Move unit tests to root package. * Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py * Adds github workflow * minor fix in lite_llm.py for python 3.9. * format pyproject.toml
This commit is contained in:
499
tests/unittests/tools/apihub_tool/clients/test_apihub_client.py
Normal file
499
tests/unittests/tools/apihub_tool/clients/test_apihub_client.py
Normal file
@@ -0,0 +1,499 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
from google.adk.tools.apihub_tool.clients.apihub_client import APIHubClient
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
# Mock data for API responses
|
||||
MOCK_API_LIST = {
|
||||
"apis": [
|
||||
{"name": "projects/test-project/locations/us-central1/apis/api1"},
|
||||
{"name": "projects/test-project/locations/us-central1/apis/api2"},
|
||||
]
|
||||
}
|
||||
MOCK_API_DETAIL = {
|
||||
"name": "projects/test-project/locations/us-central1/apis/api1",
|
||||
"versions": [
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
],
|
||||
}
|
||||
MOCK_API_VERSION = {
|
||||
"name": "projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
"specs": [
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
],
|
||||
}
|
||||
MOCK_SPEC_CONTENT = {"contents": base64.b64encode(b"spec content").decode()}
|
||||
|
||||
|
||||
# Test cases
|
||||
class TestAPIHubClient:
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
return APIHubClient(access_token="mocked_token")
|
||||
|
||||
@pytest.fixture
|
||||
def service_account_config(self):
|
||||
return json.dumps({
|
||||
"type": "service_account",
|
||||
"project_id": "test",
|
||||
"token_uri": "test.com",
|
||||
"client_email": "test@example.com",
|
||||
"private_key": "1234",
|
||||
})
|
||||
|
||||
@patch("requests.get")
|
||||
def test_list_apis(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_API_LIST
|
||||
mock_get.return_value.status_code = 200
|
||||
|
||||
apis = client.list_apis("test-project", "us-central1")
|
||||
assert apis == MOCK_API_LIST["apis"]
|
||||
mock_get.assert_called_once_with(
|
||||
"https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis",
|
||||
headers={
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": "Bearer mocked_token",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_list_apis_empty(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = {"apis": []}
|
||||
mock_get.return_value.status_code = 200
|
||||
|
||||
apis = client.list_apis("test-project", "us-central1")
|
||||
assert apis == []
|
||||
|
||||
@patch("requests.get")
|
||||
def test_list_apis_error(self, mock_get, client):
|
||||
mock_get.return_value.raise_for_status.side_effect = HTTPError
|
||||
|
||||
with pytest.raises(HTTPError):
|
||||
client.list_apis("test-project", "us-central1")
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_api(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_API_DETAIL
|
||||
mock_get.return_value.status_code = 200
|
||||
api = client.get_api(
|
||||
"projects/test-project/locations/us-central1/apis/api1"
|
||||
)
|
||||
assert api == MOCK_API_DETAIL
|
||||
mock_get.assert_called_once_with(
|
||||
"https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis/api1",
|
||||
headers={
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": "Bearer mocked_token",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_api_error(self, mock_get, client):
|
||||
mock_get.return_value.raise_for_status.side_effect = HTTPError
|
||||
with pytest.raises(HTTPError):
|
||||
client.get_api("projects/test-project/locations/us-central1/apis/api1")
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_api_version(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_API_VERSION
|
||||
mock_get.return_value.status_code = 200
|
||||
api_version = client.get_api_version(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
)
|
||||
assert api_version == MOCK_API_VERSION
|
||||
mock_get.assert_called_once_with(
|
||||
"https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
headers={
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": "Bearer mocked_token",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_api_version_error(self, mock_get, client):
|
||||
mock_get.return_value.raise_for_status.side_effect = HTTPError
|
||||
with pytest.raises(HTTPError):
|
||||
client.get_api_version(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_SPEC_CONTENT
|
||||
mock_get.return_value.status_code = 200
|
||||
spec_content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
)
|
||||
assert spec_content == "spec content"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1:contents",
|
||||
headers={
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": "Bearer mocked_token",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_empty(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = {"contents": ""}
|
||||
mock_get.return_value.status_code = 200
|
||||
spec_content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
)
|
||||
assert spec_content == ""
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_error(self, mock_get, client):
|
||||
mock_get.return_value.raise_for_status.side_effect = HTTPError
|
||||
with pytest.raises(HTTPError):
|
||||
client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url_or_path, expected",
|
||||
[
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
None,
|
||||
None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1",
|
||||
),
|
||||
),
|
||||
(
|
||||
"https://console.cloud.google.com/apigee/api-hub/projects/test-project/locations/us-central1/apis/api1/versions/v1?project=test-project",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"https://console.cloud.google.com/apigee/api-hub/projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1?project=test-project",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1",
|
||||
),
|
||||
),
|
||||
(
|
||||
"/projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
None,
|
||||
),
|
||||
),
|
||||
( # Added trailing slashes
|
||||
"projects/test-project/locations/us-central1/apis/api1/",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
None,
|
||||
None,
|
||||
),
|
||||
),
|
||||
( # case location name
|
||||
"projects/test-project/locations/LOCATION/apis/api1/",
|
||||
(
|
||||
"projects/test-project/locations/LOCATION/apis/api1",
|
||||
None,
|
||||
None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"projects/p1/locations/l1/apis/a1/versions/v1/specs/s1",
|
||||
(
|
||||
"projects/p1/locations/l1/apis/a1",
|
||||
"projects/p1/locations/l1/apis/a1/versions/v1",
|
||||
"projects/p1/locations/l1/apis/a1/versions/v1/specs/s1",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_resource_name(self, client, url_or_path, expected):
|
||||
result = client._extract_resource_name(url_or_path)
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url_or_path, expected_error_message",
|
||||
[
|
||||
(
|
||||
"invalid-path",
|
||||
"Project ID not found in URL or path in APIHubClient.",
|
||||
),
|
||||
(
|
||||
"projects/test-project",
|
||||
"Location not found in URL or path in APIHubClient.",
|
||||
),
|
||||
(
|
||||
"projects/test-project/locations/us-central1",
|
||||
"API id not found in URL or path in APIHubClient.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_resource_name_invalid(
|
||||
self, client, url_or_path, expected_error_message
|
||||
):
|
||||
with pytest.raises(ValueError, match=expected_error_message):
|
||||
client._extract_resource_name(url_or_path)
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.service_account.Credentials.from_service_account_info"
|
||||
)
|
||||
def test_get_access_token_use_default_credential(
|
||||
self,
|
||||
mock_from_service_account_info,
|
||||
mock_default_service_credential,
|
||||
):
|
||||
mock_credential = MagicMock()
|
||||
mock_credential.token = "default_token"
|
||||
mock_default_service_credential.return_value = (
|
||||
mock_credential,
|
||||
"project_id",
|
||||
)
|
||||
mock_config_credential = MagicMock()
|
||||
mock_config_credential.token = "config_token"
|
||||
mock_from_service_account_info.return_value = mock_config_credential
|
||||
|
||||
client = APIHubClient()
|
||||
token = client._get_access_token()
|
||||
assert token == "default_token"
|
||||
mock_credential.refresh.assert_called_once()
|
||||
assert client.credential_cache == mock_credential
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.service_account.Credentials.from_service_account_info"
|
||||
)
|
||||
def test_get_access_token_use_configured_service_account(
|
||||
self,
|
||||
mock_from_service_account_info,
|
||||
mock_default_service_credential,
|
||||
service_account_config,
|
||||
):
|
||||
mock_credential = MagicMock()
|
||||
mock_credential.token = "default_token"
|
||||
mock_default_service_credential.return_value = (
|
||||
mock_credential,
|
||||
"project_id",
|
||||
)
|
||||
mock_config_credential = MagicMock()
|
||||
mock_config_credential.token = "config_token"
|
||||
mock_from_service_account_info.return_value = mock_config_credential
|
||||
|
||||
client = APIHubClient(service_account_json=service_account_config)
|
||||
token = client._get_access_token()
|
||||
|
||||
assert token == "config_token"
|
||||
mock_from_service_account_info.assert_called_once_with(
|
||||
json.loads(service_account_config),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
mock_config_credential.refresh.assert_called_once()
|
||||
assert client.credential_cache == mock_config_credential
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
def test_get_access_token_not_expired_use_cached_token(
|
||||
self, mock_default_credential
|
||||
):
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "default_service_account_token"
|
||||
mock_default_credential.return_value = (mock_credentials, "")
|
||||
|
||||
client = APIHubClient()
|
||||
# Call #1: Setup cache
|
||||
token = client._get_access_token()
|
||||
assert token == "default_service_account_token"
|
||||
mock_default_credential.assert_called_once()
|
||||
|
||||
# Call #2: Reuse cache
|
||||
mock_credentials.reset_mock()
|
||||
mock_credentials.expired = False
|
||||
token = client._get_access_token()
|
||||
assert token == "default_service_account_token"
|
||||
mock_credentials.refresh.assert_not_called()
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
def test_get_access_token_expired_refresh(self, mock_default_credential):
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "default_service_account_token"
|
||||
mock_default_credential.return_value = (mock_credentials, "")
|
||||
client = APIHubClient()
|
||||
|
||||
# Call #1: Setup cache
|
||||
token = client._get_access_token()
|
||||
assert token == "default_service_account_token"
|
||||
mock_default_credential.assert_called_once()
|
||||
|
||||
# Call #2: Cache expired
|
||||
mock_credentials.reset_mock()
|
||||
mock_credentials.expired = True
|
||||
token = client._get_access_token()
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
assert token == "default_service_account_token"
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
def test_get_access_token_no_credentials(
|
||||
self, mock_default_service_credential
|
||||
):
|
||||
mock_default_service_credential.return_value = (None, None)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Please provide a service account or an access token to API Hub"
|
||||
" client."
|
||||
),
|
||||
):
|
||||
# no service account client
|
||||
APIHubClient()._get_access_token()
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_api_level(self, mock_get, client):
|
||||
mock_get.side_effect = [
|
||||
MagicMock(status_code=200, json=lambda: MOCK_API_DETAIL), # For get_api
|
||||
MagicMock(
|
||||
status_code=200, json=lambda: MOCK_API_VERSION
|
||||
), # For get_api_version
|
||||
MagicMock(
|
||||
status_code=200, json=lambda: MOCK_SPEC_CONTENT
|
||||
), # For get_spec_content
|
||||
]
|
||||
|
||||
content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1"
|
||||
)
|
||||
assert content == "spec content"
|
||||
# Check calls - get_api, get_api_version, then get_spec_content
|
||||
assert mock_get.call_count == 3
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_version_level(self, mock_get, client):
|
||||
mock_get.side_effect = [
|
||||
MagicMock(
|
||||
status_code=200, json=lambda: MOCK_API_VERSION
|
||||
), # For get_api_version
|
||||
MagicMock(
|
||||
status_code=200, json=lambda: MOCK_SPEC_CONTENT
|
||||
), # For get_spec_content
|
||||
]
|
||||
|
||||
content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
)
|
||||
assert content == "spec content"
|
||||
assert mock_get.call_count == 2 # get_api_version and get_spec_content
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_spec_level(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_SPEC_CONTENT
|
||||
mock_get.return_value.status_code = 200
|
||||
|
||||
content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
)
|
||||
assert content == "spec content"
|
||||
mock_get.assert_called_once() # Only get_spec_content should be called
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_no_versions(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = {
|
||||
"name": "projects/test-project/locations/us-central1/apis/api1",
|
||||
"versions": [],
|
||||
} # No versions
|
||||
mock_get.return_value.status_code = 200
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"No versions found in API Hub resource:"
|
||||
" projects/test-project/locations/us-central1/apis/api1"
|
||||
),
|
||||
):
|
||||
client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1"
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_no_specs(self, mock_get, client):
|
||||
mock_get.side_effect = [
|
||||
MagicMock(status_code=200, json=lambda: MOCK_API_DETAIL),
|
||||
MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {
|
||||
"name": (
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
),
|
||||
"specs": [],
|
||||
},
|
||||
), # No specs
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"No specs found in API Hub version:"
|
||||
" projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
),
|
||||
):
|
||||
client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_invalid_path(self, mock_get, client):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Project ID not found in URL or path in APIHubClient. Input"
|
||||
" path is 'invalid-path'."
|
||||
),
|
||||
):
|
||||
client.get_spec_content("invalid-path")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user