mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 11:22:22 -06:00
Moves unittests to root folder and adds github action to run unit tests. (#72)
* Move unit tests to root package. * Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py * Adds github workflow * minor fix in lite_llm.py for python 3.9. * format pyproject.toml
This commit is contained in:
14
tests/unittests/tools/__init__.py
Normal file
14
tests/unittests/tools/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
499
tests/unittests/tools/apihub_tool/clients/test_apihub_client.py
Normal file
499
tests/unittests/tools/apihub_tool/clients/test_apihub_client.py
Normal file
@@ -0,0 +1,499 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
from google.adk.tools.apihub_tool.clients.apihub_client import APIHubClient
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
# Mock data for API responses
|
||||
MOCK_API_LIST = {
|
||||
"apis": [
|
||||
{"name": "projects/test-project/locations/us-central1/apis/api1"},
|
||||
{"name": "projects/test-project/locations/us-central1/apis/api2"},
|
||||
]
|
||||
}
|
||||
MOCK_API_DETAIL = {
|
||||
"name": "projects/test-project/locations/us-central1/apis/api1",
|
||||
"versions": [
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
],
|
||||
}
|
||||
MOCK_API_VERSION = {
|
||||
"name": "projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
"specs": [
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
],
|
||||
}
|
||||
MOCK_SPEC_CONTENT = {"contents": base64.b64encode(b"spec content").decode()}
|
||||
|
||||
|
||||
# Test cases
|
||||
class TestAPIHubClient:
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
return APIHubClient(access_token="mocked_token")
|
||||
|
||||
@pytest.fixture
|
||||
def service_account_config(self):
|
||||
return json.dumps({
|
||||
"type": "service_account",
|
||||
"project_id": "test",
|
||||
"token_uri": "test.com",
|
||||
"client_email": "test@example.com",
|
||||
"private_key": "1234",
|
||||
})
|
||||
|
||||
@patch("requests.get")
|
||||
def test_list_apis(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_API_LIST
|
||||
mock_get.return_value.status_code = 200
|
||||
|
||||
apis = client.list_apis("test-project", "us-central1")
|
||||
assert apis == MOCK_API_LIST["apis"]
|
||||
mock_get.assert_called_once_with(
|
||||
"https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis",
|
||||
headers={
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": "Bearer mocked_token",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_list_apis_empty(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = {"apis": []}
|
||||
mock_get.return_value.status_code = 200
|
||||
|
||||
apis = client.list_apis("test-project", "us-central1")
|
||||
assert apis == []
|
||||
|
||||
@patch("requests.get")
|
||||
def test_list_apis_error(self, mock_get, client):
|
||||
mock_get.return_value.raise_for_status.side_effect = HTTPError
|
||||
|
||||
with pytest.raises(HTTPError):
|
||||
client.list_apis("test-project", "us-central1")
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_api(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_API_DETAIL
|
||||
mock_get.return_value.status_code = 200
|
||||
api = client.get_api(
|
||||
"projects/test-project/locations/us-central1/apis/api1"
|
||||
)
|
||||
assert api == MOCK_API_DETAIL
|
||||
mock_get.assert_called_once_with(
|
||||
"https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis/api1",
|
||||
headers={
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": "Bearer mocked_token",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_api_error(self, mock_get, client):
|
||||
mock_get.return_value.raise_for_status.side_effect = HTTPError
|
||||
with pytest.raises(HTTPError):
|
||||
client.get_api("projects/test-project/locations/us-central1/apis/api1")
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_api_version(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_API_VERSION
|
||||
mock_get.return_value.status_code = 200
|
||||
api_version = client.get_api_version(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
)
|
||||
assert api_version == MOCK_API_VERSION
|
||||
mock_get.assert_called_once_with(
|
||||
"https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
headers={
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": "Bearer mocked_token",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_api_version_error(self, mock_get, client):
|
||||
mock_get.return_value.raise_for_status.side_effect = HTTPError
|
||||
with pytest.raises(HTTPError):
|
||||
client.get_api_version(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_SPEC_CONTENT
|
||||
mock_get.return_value.status_code = 200
|
||||
spec_content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
)
|
||||
assert spec_content == "spec content"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1:contents",
|
||||
headers={
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": "Bearer mocked_token",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_empty(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = {"contents": ""}
|
||||
mock_get.return_value.status_code = 200
|
||||
spec_content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
)
|
||||
assert spec_content == ""
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_error(self, mock_get, client):
|
||||
mock_get.return_value.raise_for_status.side_effect = HTTPError
|
||||
with pytest.raises(HTTPError):
|
||||
client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url_or_path, expected",
|
||||
[
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
None,
|
||||
None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1",
|
||||
),
|
||||
),
|
||||
(
|
||||
"https://console.cloud.google.com/apigee/api-hub/projects/test-project/locations/us-central1/apis/api1/versions/v1?project=test-project",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"https://console.cloud.google.com/apigee/api-hub/projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1?project=test-project",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1",
|
||||
),
|
||||
),
|
||||
(
|
||||
"/projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
None,
|
||||
),
|
||||
),
|
||||
( # Added trailing slashes
|
||||
"projects/test-project/locations/us-central1/apis/api1/",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
None,
|
||||
None,
|
||||
),
|
||||
),
|
||||
( # case location name
|
||||
"projects/test-project/locations/LOCATION/apis/api1/",
|
||||
(
|
||||
"projects/test-project/locations/LOCATION/apis/api1",
|
||||
None,
|
||||
None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"projects/p1/locations/l1/apis/a1/versions/v1/specs/s1",
|
||||
(
|
||||
"projects/p1/locations/l1/apis/a1",
|
||||
"projects/p1/locations/l1/apis/a1/versions/v1",
|
||||
"projects/p1/locations/l1/apis/a1/versions/v1/specs/s1",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_resource_name(self, client, url_or_path, expected):
|
||||
result = client._extract_resource_name(url_or_path)
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url_or_path, expected_error_message",
|
||||
[
|
||||
(
|
||||
"invalid-path",
|
||||
"Project ID not found in URL or path in APIHubClient.",
|
||||
),
|
||||
(
|
||||
"projects/test-project",
|
||||
"Location not found in URL or path in APIHubClient.",
|
||||
),
|
||||
(
|
||||
"projects/test-project/locations/us-central1",
|
||||
"API id not found in URL or path in APIHubClient.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_resource_name_invalid(
|
||||
self, client, url_or_path, expected_error_message
|
||||
):
|
||||
with pytest.raises(ValueError, match=expected_error_message):
|
||||
client._extract_resource_name(url_or_path)
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.service_account.Credentials.from_service_account_info"
|
||||
)
|
||||
def test_get_access_token_use_default_credential(
|
||||
self,
|
||||
mock_from_service_account_info,
|
||||
mock_default_service_credential,
|
||||
):
|
||||
mock_credential = MagicMock()
|
||||
mock_credential.token = "default_token"
|
||||
mock_default_service_credential.return_value = (
|
||||
mock_credential,
|
||||
"project_id",
|
||||
)
|
||||
mock_config_credential = MagicMock()
|
||||
mock_config_credential.token = "config_token"
|
||||
mock_from_service_account_info.return_value = mock_config_credential
|
||||
|
||||
client = APIHubClient()
|
||||
token = client._get_access_token()
|
||||
assert token == "default_token"
|
||||
mock_credential.refresh.assert_called_once()
|
||||
assert client.credential_cache == mock_credential
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.service_account.Credentials.from_service_account_info"
|
||||
)
|
||||
def test_get_access_token_use_configured_service_account(
|
||||
self,
|
||||
mock_from_service_account_info,
|
||||
mock_default_service_credential,
|
||||
service_account_config,
|
||||
):
|
||||
mock_credential = MagicMock()
|
||||
mock_credential.token = "default_token"
|
||||
mock_default_service_credential.return_value = (
|
||||
mock_credential,
|
||||
"project_id",
|
||||
)
|
||||
mock_config_credential = MagicMock()
|
||||
mock_config_credential.token = "config_token"
|
||||
mock_from_service_account_info.return_value = mock_config_credential
|
||||
|
||||
client = APIHubClient(service_account_json=service_account_config)
|
||||
token = client._get_access_token()
|
||||
|
||||
assert token == "config_token"
|
||||
mock_from_service_account_info.assert_called_once_with(
|
||||
json.loads(service_account_config),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
mock_config_credential.refresh.assert_called_once()
|
||||
assert client.credential_cache == mock_config_credential
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
def test_get_access_token_not_expired_use_cached_token(
|
||||
self, mock_default_credential
|
||||
):
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "default_service_account_token"
|
||||
mock_default_credential.return_value = (mock_credentials, "")
|
||||
|
||||
client = APIHubClient()
|
||||
# Call #1: Setup cache
|
||||
token = client._get_access_token()
|
||||
assert token == "default_service_account_token"
|
||||
mock_default_credential.assert_called_once()
|
||||
|
||||
# Call #2: Reuse cache
|
||||
mock_credentials.reset_mock()
|
||||
mock_credentials.expired = False
|
||||
token = client._get_access_token()
|
||||
assert token == "default_service_account_token"
|
||||
mock_credentials.refresh.assert_not_called()
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
def test_get_access_token_expired_refresh(self, mock_default_credential):
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "default_service_account_token"
|
||||
mock_default_credential.return_value = (mock_credentials, "")
|
||||
client = APIHubClient()
|
||||
|
||||
# Call #1: Setup cache
|
||||
token = client._get_access_token()
|
||||
assert token == "default_service_account_token"
|
||||
mock_default_credential.assert_called_once()
|
||||
|
||||
# Call #2: Cache expired
|
||||
mock_credentials.reset_mock()
|
||||
mock_credentials.expired = True
|
||||
token = client._get_access_token()
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
assert token == "default_service_account_token"
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
def test_get_access_token_no_credentials(
|
||||
self, mock_default_service_credential
|
||||
):
|
||||
mock_default_service_credential.return_value = (None, None)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Please provide a service account or an access token to API Hub"
|
||||
" client."
|
||||
),
|
||||
):
|
||||
# no service account client
|
||||
APIHubClient()._get_access_token()
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_api_level(self, mock_get, client):
|
||||
mock_get.side_effect = [
|
||||
MagicMock(status_code=200, json=lambda: MOCK_API_DETAIL), # For get_api
|
||||
MagicMock(
|
||||
status_code=200, json=lambda: MOCK_API_VERSION
|
||||
), # For get_api_version
|
||||
MagicMock(
|
||||
status_code=200, json=lambda: MOCK_SPEC_CONTENT
|
||||
), # For get_spec_content
|
||||
]
|
||||
|
||||
content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1"
|
||||
)
|
||||
assert content == "spec content"
|
||||
# Check calls - get_api, get_api_version, then get_spec_content
|
||||
assert mock_get.call_count == 3
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_version_level(self, mock_get, client):
|
||||
mock_get.side_effect = [
|
||||
MagicMock(
|
||||
status_code=200, json=lambda: MOCK_API_VERSION
|
||||
), # For get_api_version
|
||||
MagicMock(
|
||||
status_code=200, json=lambda: MOCK_SPEC_CONTENT
|
||||
), # For get_spec_content
|
||||
]
|
||||
|
||||
content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
)
|
||||
assert content == "spec content"
|
||||
assert mock_get.call_count == 2 # get_api_version and get_spec_content
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_spec_level(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_SPEC_CONTENT
|
||||
mock_get.return_value.status_code = 200
|
||||
|
||||
content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
)
|
||||
assert content == "spec content"
|
||||
mock_get.assert_called_once() # Only get_spec_content should be called
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_no_versions(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = {
|
||||
"name": "projects/test-project/locations/us-central1/apis/api1",
|
||||
"versions": [],
|
||||
} # No versions
|
||||
mock_get.return_value.status_code = 200
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"No versions found in API Hub resource:"
|
||||
" projects/test-project/locations/us-central1/apis/api1"
|
||||
),
|
||||
):
|
||||
client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1"
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_no_specs(self, mock_get, client):
|
||||
mock_get.side_effect = [
|
||||
MagicMock(status_code=200, json=lambda: MOCK_API_DETAIL),
|
||||
MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {
|
||||
"name": (
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
),
|
||||
"specs": [],
|
||||
},
|
||||
), # No specs
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"No specs found in API Hub version:"
|
||||
" projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
),
|
||||
):
|
||||
client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_invalid_path(self, mock_get, client):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Project ID not found in URL or path in APIHubClient. Input"
|
||||
" path is 'invalid-path'."
|
||||
),
|
||||
):
|
||||
client.get_spec_content("invalid-path")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
204
tests/unittests/tools/apihub_tool/test_apihub_toolset.py
Normal file
204
tests/unittests/tools/apihub_tool/test_apihub_toolset.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.tools.apihub_tool.apihub_toolset import APIHubToolset
|
||||
from google.adk.tools.apihub_tool.clients.apihub_client import BaseAPIHubClient
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
|
||||
class MockAPIHubClient(BaseAPIHubClient):
|
||||
|
||||
def get_spec_content(self, apihub_resource_name: str) -> str:
|
||||
return """
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
version: 1.0.0
|
||||
title: Mock API
|
||||
description: Mock API Description
|
||||
paths:
|
||||
/test:
|
||||
get:
|
||||
summary: Test GET endpoint
|
||||
operationId: testGet
|
||||
responses:
|
||||
'200':
|
||||
description: Successful response
|
||||
"""
|
||||
|
||||
|
||||
# Fixture for a basic APIHubToolset
|
||||
@pytest.fixture
|
||||
def basic_apihub_toolset():
|
||||
apihub_client = MockAPIHubClient()
|
||||
tool = APIHubToolset(
|
||||
apihub_resource_name='test_resource', apihub_client=apihub_client
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
# Fixture for an APIHubToolset with lazy loading
|
||||
@pytest.fixture
|
||||
def lazy_apihub_toolset():
|
||||
apihub_client = MockAPIHubClient()
|
||||
tool = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
lazy_load_spec=True,
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
# Fixture for auth scheme
|
||||
@pytest.fixture
|
||||
def mock_auth_scheme():
|
||||
return MagicMock(spec=AuthScheme)
|
||||
|
||||
|
||||
# Fixture for auth credential
|
||||
@pytest.fixture
|
||||
def mock_auth_credential():
|
||||
return MagicMock(spec=AuthCredential)
|
||||
|
||||
|
||||
# Test cases
|
||||
def test_apihub_toolset_initialization(basic_apihub_toolset):
|
||||
assert basic_apihub_toolset.name == 'mock_api'
|
||||
assert basic_apihub_toolset.description == 'Mock API Description'
|
||||
assert basic_apihub_toolset.apihub_resource_name == 'test_resource'
|
||||
assert not basic_apihub_toolset.lazy_load_spec
|
||||
assert len(basic_apihub_toolset.generated_tools) == 1
|
||||
assert 'test_get' in basic_apihub_toolset.generated_tools
|
||||
|
||||
|
||||
def test_apihub_toolset_lazy_loading(lazy_apihub_toolset):
|
||||
assert lazy_apihub_toolset.lazy_load_spec
|
||||
assert not lazy_apihub_toolset.generated_tools
|
||||
|
||||
tools = lazy_apihub_toolset.get_tools()
|
||||
assert len(tools) == 1
|
||||
assert lazy_apihub_toolset.get_tool('test_get') == tools[0]
|
||||
|
||||
|
||||
def test_apihub_toolset_no_title_in_spec(basic_apihub_toolset):
|
||||
spec = """
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
version: 1.0.0
|
||||
paths:
|
||||
/empty_desc_test:
|
||||
delete:
|
||||
summary: Test DELETE endpoint
|
||||
operationId: emptyDescTest
|
||||
responses:
|
||||
'200':
|
||||
description: Successful response
|
||||
"""
|
||||
|
||||
class MockAPIHubClientEmptySpec(BaseAPIHubClient):
|
||||
|
||||
def get_spec_content(self, apihub_resource_name: str) -> str:
|
||||
return spec
|
||||
|
||||
apihub_client = MockAPIHubClientEmptySpec()
|
||||
toolset = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
)
|
||||
|
||||
assert toolset.name == 'unnamed'
|
||||
|
||||
|
||||
def test_apihub_toolset_empty_description_in_spec():
|
||||
spec = """
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
version: 1.0.0
|
||||
title: Empty Description API
|
||||
paths:
|
||||
/empty_desc_test:
|
||||
delete:
|
||||
summary: Test DELETE endpoint
|
||||
operationId: emptyDescTest
|
||||
responses:
|
||||
'200':
|
||||
description: Successful response
|
||||
"""
|
||||
|
||||
class MockAPIHubClientEmptySpec(BaseAPIHubClient):
|
||||
|
||||
def get_spec_content(self, apihub_resource_name: str) -> str:
|
||||
return spec
|
||||
|
||||
apihub_client = MockAPIHubClientEmptySpec()
|
||||
toolset = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
)
|
||||
|
||||
assert toolset.name == 'empty_description_api'
|
||||
assert toolset.description == ''
|
||||
|
||||
|
||||
def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential):
|
||||
apihub_client = MockAPIHubClient()
|
||||
tool = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
auth_scheme=mock_auth_scheme,
|
||||
auth_credential=mock_auth_credential,
|
||||
)
|
||||
tools = tool.get_tools()
|
||||
assert len(tools) == 1
|
||||
|
||||
|
||||
def test_apihub_toolset_get_tools_lazy_load_empty_spec():
|
||||
|
||||
class MockAPIHubClientEmptySpec(BaseAPIHubClient):
|
||||
|
||||
def get_spec_content(self, apihub_resource_name: str) -> str:
|
||||
return ''
|
||||
|
||||
apihub_client = MockAPIHubClientEmptySpec()
|
||||
tool = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
lazy_load_spec=True,
|
||||
)
|
||||
tools = tool.get_tools()
|
||||
assert not tools
|
||||
|
||||
|
||||
def test_apihub_toolset_get_tools_invalid_yaml():
|
||||
|
||||
class MockAPIHubClientInvalidYAML(BaseAPIHubClient):
|
||||
|
||||
def get_spec_content(self, apihub_resource_name: str) -> str:
|
||||
return '{invalid yaml' # Return invalid YAML
|
||||
|
||||
with pytest.raises(yaml.YAMLError):
|
||||
apihub_client = MockAPIHubClientInvalidYAML()
|
||||
tool = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
)
|
||||
tool.get_tools()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
@@ -0,0 +1,600 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
|
||||
import google.auth
|
||||
import pytest
|
||||
import requests
|
||||
from requests import exceptions
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project():
|
||||
return "test-project"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def location():
|
||||
return "us-central1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_name():
|
||||
return "test-connection"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credentials():
|
||||
creds = mock.create_autospec(google.auth.credentials.Credentials)
|
||||
creds.token = "test_token"
|
||||
creds.expired = False
|
||||
return creds
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_request():
|
||||
return mock.create_autospec(google.auth.transport.requests.Request)
|
||||
|
||||
|
||||
class TestConnectionsClient:
|
||||
|
||||
def test_initialization(self, project, location, connection_name):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(
|
||||
project, location, connection_name, json.dumps(credentials)
|
||||
)
|
||||
assert client.project == project
|
||||
assert client.location == location
|
||||
assert client.connection == connection_name
|
||||
assert client.connector_url == "https://connectors.googleapis.com"
|
||||
assert client.service_account_json == json.dumps(credentials)
|
||||
assert client.credential_cache is None
|
||||
|
||||
def test_execute_api_call_success(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.return_value = {"data": "test"}
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_get_access_token", return_value=mock_credentials.token
|
||||
), mock.patch("requests.get", return_value=mock_response):
|
||||
response = client._execute_api_call("https://test.url")
|
||||
assert response.json() == {"data": "test"}
|
||||
requests.get.assert_called_once_with(
|
||||
"https://test.url",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {mock_credentials.token}",
|
||||
},
|
||||
)
|
||||
|
||||
def test_execute_api_call_credential_error(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
with mock.patch.object(
|
||||
client,
|
||||
"_get_access_token",
|
||||
side_effect=google.auth.exceptions.DefaultCredentialsError("Test"),
|
||||
):
|
||||
with pytest.raises(PermissionError, match="Credentials error: Test"):
|
||||
client._execute_api_call("https://test.url")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code, response_text",
|
||||
[(404, "Not Found"), (400, "Bad Request")],
|
||||
)
|
||||
def test_execute_api_call_request_error_not_found_or_bad_request(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
connection_name,
|
||||
mock_credentials,
|
||||
status_code,
|
||||
response_text,
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.raise_for_status.side_effect = exceptions.HTTPError(
|
||||
f"HTTP error {status_code}: {response_text}"
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_get_access_token", return_value=mock_credentials.token
|
||||
), mock.patch("requests.get", return_value=mock_response):
|
||||
with pytest.raises(
|
||||
ValueError, match="Invalid request. Please check the provided"
|
||||
):
|
||||
client._execute_api_call("https://test.url")
|
||||
|
||||
def test_execute_api_call_other_request_error(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.raise_for_status.side_effect = exceptions.HTTPError(
|
||||
"Internal Server Error"
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_get_access_token", return_value=mock_credentials.token
|
||||
), mock.patch("requests.get", return_value=mock_response):
|
||||
with pytest.raises(ValueError, match="Request error: "):
|
||||
client._execute_api_call("https://test.url")
|
||||
|
||||
def test_execute_api_call_unexpected_error(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
with mock.patch.object(
|
||||
client, "_get_access_token", return_value=mock_credentials.token
|
||||
), mock.patch(
|
||||
"requests.get", side_effect=Exception("Something went wrong")
|
||||
):
|
||||
with pytest.raises(
|
||||
Exception, match="An unexpected error occurred: Something went wrong"
|
||||
):
|
||||
client._execute_api_call("https://test.url")
|
||||
|
||||
def test_get_connection_details_success_with_host(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"serviceDirectory": "test_service",
|
||||
"host": "test.host",
|
||||
"tlsServiceDirectory": "tls_test_service",
|
||||
"authOverrideEnabled": True,
|
||||
}
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", return_value=mock_response
|
||||
):
|
||||
details = client.get_connection_details()
|
||||
assert details == {
|
||||
"serviceName": "tls_test_service",
|
||||
"host": "test.host",
|
||||
"authOverrideEnabled": True,
|
||||
}
|
||||
|
||||
def test_get_connection_details_success_without_host(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"serviceDirectory": "test_service",
|
||||
"authOverrideEnabled": False,
|
||||
}
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", return_value=mock_response
|
||||
):
|
||||
details = client.get_connection_details()
|
||||
assert details == {
|
||||
"serviceName": "test_service",
|
||||
"host": "",
|
||||
"authOverrideEnabled": False,
|
||||
}
|
||||
|
||||
def test_get_connection_details_error(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", side_effect=ValueError("Request error")
|
||||
):
|
||||
with pytest.raises(ValueError, match="Request error"):
|
||||
client.get_connection_details()
|
||||
|
||||
def test_get_entity_schema_and_operations_success(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_execute_response_initial = mock.MagicMock()
|
||||
mock_execute_response_initial.status_code = 200
|
||||
mock_execute_response_initial.json.return_value = {
|
||||
"name": "operations/test_op"
|
||||
}
|
||||
|
||||
mock_execute_response_poll_done = mock.MagicMock()
|
||||
mock_execute_response_poll_done.status_code = 200
|
||||
mock_execute_response_poll_done.json.return_value = {
|
||||
"done": True,
|
||||
"response": {
|
||||
"jsonSchema": {"type": "object"},
|
||||
"operations": ["LIST", "GET"],
|
||||
},
|
||||
}
|
||||
|
||||
with mock.patch.object(
|
||||
client,
|
||||
"_execute_api_call",
|
||||
side_effect=[
|
||||
mock_execute_response_initial,
|
||||
mock_execute_response_poll_done,
|
||||
],
|
||||
):
|
||||
schema, operations = client.get_entity_schema_and_operations("entity1")
|
||||
assert schema == {"type": "object"}
|
||||
assert operations == ["LIST", "GET"]
|
||||
assert (
|
||||
mock.call(
|
||||
f"https://connectors.googleapis.com/v1/projects/{project}/locations/{location}/connections/{connection_name}/connectionSchemaMetadata:getEntityType?entityId=entity1"
|
||||
)
|
||||
in client._execute_api_call.mock_calls
|
||||
)
|
||||
assert (
|
||||
mock.call(f"https://connectors.googleapis.com/v1/operations/test_op")
|
||||
in client._execute_api_call.mock_calls
|
||||
)
|
||||
|
||||
def test_get_entity_schema_and_operations_no_operation_id(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_execute_response = mock.MagicMock()
|
||||
mock_execute_response.status_code = 200
|
||||
mock_execute_response.json.return_value = {}
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", return_value=mock_execute_response
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Failed to get entity schema and operations for entity: entity1"
|
||||
),
|
||||
):
|
||||
client.get_entity_schema_and_operations("entity1")
|
||||
|
||||
def test_get_entity_schema_and_operations_execute_api_call_error(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", side_effect=ValueError("Request error")
|
||||
):
|
||||
with pytest.raises(ValueError, match="Request error"):
|
||||
client.get_entity_schema_and_operations("entity1")
|
||||
|
||||
def test_get_action_schema_success(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_execute_response_initial = mock.MagicMock()
|
||||
mock_execute_response_initial.status_code = 200
|
||||
mock_execute_response_initial.json.return_value = {
|
||||
"name": "operations/test_op"
|
||||
}
|
||||
|
||||
mock_execute_response_poll_done = mock.MagicMock()
|
||||
mock_execute_response_poll_done.status_code = 200
|
||||
mock_execute_response_poll_done.json.return_value = {
|
||||
"done": True,
|
||||
"response": {
|
||||
"inputJsonSchema": {
|
||||
"type": "object",
|
||||
"properties": {"input": {"type": "string"}},
|
||||
},
|
||||
"outputJsonSchema": {
|
||||
"type": "object",
|
||||
"properties": {"output": {"type": "string"}},
|
||||
},
|
||||
"description": "Test Action Description",
|
||||
"displayName": "TestAction",
|
||||
},
|
||||
}
|
||||
|
||||
with mock.patch.object(
|
||||
client,
|
||||
"_execute_api_call",
|
||||
side_effect=[
|
||||
mock_execute_response_initial,
|
||||
mock_execute_response_poll_done,
|
||||
],
|
||||
):
|
||||
schema = client.get_action_schema("action1")
|
||||
assert schema == {
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"input": {"type": "string"}},
|
||||
},
|
||||
"outputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"output": {"type": "string"}},
|
||||
},
|
||||
"description": "Test Action Description",
|
||||
"displayName": "TestAction",
|
||||
}
|
||||
assert (
|
||||
mock.call(
|
||||
f"https://connectors.googleapis.com/v1/projects/{project}/locations/{location}/connections/{connection_name}/connectionSchemaMetadata:getAction?actionId=action1"
|
||||
)
|
||||
in client._execute_api_call.mock_calls
|
||||
)
|
||||
assert (
|
||||
mock.call(f"https://connectors.googleapis.com/v1/operations/test_op")
|
||||
in client._execute_api_call.mock_calls
|
||||
)
|
||||
|
||||
def test_get_action_schema_no_operation_id(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_execute_response = mock.MagicMock()
|
||||
mock_execute_response.status_code = 200
|
||||
mock_execute_response.json.return_value = {}
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", return_value=mock_execute_response
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError, match="Failed to get action schema for action: action1"
|
||||
):
|
||||
client.get_action_schema("action1")
|
||||
|
||||
def test_get_action_schema_execute_api_call_error(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", side_effect=ValueError("Request error")
|
||||
):
|
||||
with pytest.raises(ValueError, match="Request error"):
|
||||
client.get_action_schema("action1")
|
||||
|
||||
def test_get_connector_base_spec(self):
|
||||
spec = ConnectionsClient.get_connector_base_spec()
|
||||
assert "openapi" in spec
|
||||
assert spec["info"]["title"] == "ExecuteConnection"
|
||||
assert "components" in spec
|
||||
assert "schemas" in spec["components"]
|
||||
assert "operation" in spec["components"]["schemas"]
|
||||
|
||||
def test_get_action_operation(self):
|
||||
operation = ConnectionsClient.get_action_operation(
|
||||
"TestAction", "EXECUTE_ACTION", "TestActionDisplayName", "test_tool"
|
||||
)
|
||||
assert "post" in operation
|
||||
assert operation["post"]["summary"] == "TestActionDisplayName"
|
||||
assert "operationId" in operation["post"]
|
||||
assert operation["post"]["operationId"] == "test_tool_TestActionDisplayName"
|
||||
|
||||
def test_list_operation(self):
|
||||
operation = ConnectionsClient.list_operation(
|
||||
"Entity1", '{"type": "object"}', "test_tool"
|
||||
)
|
||||
assert "post" in operation
|
||||
assert operation["post"]["summary"] == "List Entity1"
|
||||
assert "operationId" in operation["post"]
|
||||
assert operation["post"]["operationId"] == "test_tool_list_Entity1"
|
||||
|
||||
def test_get_operation_static(self):
|
||||
operation = ConnectionsClient.get_operation(
|
||||
"Entity1", '{"type": "object"}', "test_tool"
|
||||
)
|
||||
assert "post" in operation
|
||||
assert operation["post"]["summary"] == "Get Entity1"
|
||||
assert "operationId" in operation["post"]
|
||||
assert operation["post"]["operationId"] == "test_tool_get_Entity1"
|
||||
|
||||
def test_create_operation(self):
|
||||
operation = ConnectionsClient.create_operation("Entity1", "test_tool")
|
||||
assert "post" in operation
|
||||
assert operation["post"]["summary"] == "Create Entity1"
|
||||
assert "operationId" in operation["post"]
|
||||
assert operation["post"]["operationId"] == "test_tool_create_Entity1"
|
||||
|
||||
def test_update_operation(self):
|
||||
operation = ConnectionsClient.update_operation("Entity1", "test_tool")
|
||||
assert "post" in operation
|
||||
assert operation["post"]["summary"] == "Update Entity1"
|
||||
assert "operationId" in operation["post"]
|
||||
assert operation["post"]["operationId"] == "test_tool_update_Entity1"
|
||||
|
||||
def test_delete_operation(self):
|
||||
operation = ConnectionsClient.delete_operation("Entity1", "test_tool")
|
||||
assert "post" in operation
|
||||
assert operation["post"]["summary"] == "Delete Entity1"
|
||||
assert operation["post"]["operationId"] == "test_tool_delete_Entity1"
|
||||
|
||||
def test_create_operation_request(self):
|
||||
schema = ConnectionsClient.create_operation_request("Entity1")
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "connectorInputPayload" in schema["properties"]
|
||||
|
||||
def test_update_operation_request(self):
|
||||
schema = ConnectionsClient.update_operation_request("Entity1")
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "entityId" in schema["properties"]
|
||||
|
||||
def test_get_operation_request_static(self):
|
||||
schema = ConnectionsClient.get_operation_request()
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "entityId" in schema["properties"]
|
||||
|
||||
def test_delete_operation_request(self):
|
||||
schema = ConnectionsClient.delete_operation_request()
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "entityId" in schema["properties"]
|
||||
|
||||
def test_list_operation_request(self):
|
||||
schema = ConnectionsClient.list_operation_request()
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "filterClause" in schema["properties"]
|
||||
|
||||
def test_action_request(self):
|
||||
schema = ConnectionsClient.action_request("TestAction")
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "connectorInputPayload" in schema["properties"]
|
||||
|
||||
def test_action_response(self):
|
||||
schema = ConnectionsClient.action_response("TestAction")
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "connectorOutputPayload" in schema["properties"]
|
||||
|
||||
def test_execute_custom_query_request(self):
|
||||
schema = ConnectionsClient.execute_custom_query_request()
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "query" in schema["properties"]
|
||||
|
||||
def test_connector_payload(self):
|
||||
client = ConnectionsClient("test-project", "us-central1", "test-connection")
|
||||
schema = client.connector_payload(
|
||||
json_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {
|
||||
"type": ["null", "string"],
|
||||
"description": "description",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
assert schema == {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {
|
||||
"type": "string",
|
||||
"nullable": True,
|
||||
"description": "description",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def test_get_access_token_uses_cached_token(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
client.credential_cache = mock_credentials
|
||||
token = client._get_access_token()
|
||||
assert token == "test_token"
|
||||
|
||||
def test_get_access_token_with_service_account_credentials(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
service_account_json = json.dumps({
|
||||
"client_email": "test@example.com",
|
||||
"private_key": "test_key",
|
||||
})
|
||||
client = ConnectionsClient(
|
||||
project, location, connection_name, service_account_json
|
||||
)
|
||||
mock_creds = mock.create_autospec(google.oauth2.service_account.Credentials)
|
||||
mock_creds.token = "sa_token"
|
||||
mock_creds.expired = False
|
||||
|
||||
with mock.patch(
|
||||
"google.oauth2.service_account.Credentials.from_service_account_info",
|
||||
return_value=mock_creds,
|
||||
), mock.patch.object(mock_creds, "refresh", return_value=None):
|
||||
token = client._get_access_token()
|
||||
assert token == "sa_token"
|
||||
google.oauth2.service_account.Credentials.from_service_account_info.assert_called_once_with(
|
||||
json.loads(service_account_json),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
mock_creds.refresh.assert_called_once()
|
||||
|
||||
def test_get_access_token_with_default_credentials(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
client = ConnectionsClient(project, location, connection_name, None)
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
|
||||
return_value=(mock_credentials, "test_project_id"),
|
||||
), mock.patch.object(mock_credentials, "refresh", return_value=None):
|
||||
token = client._get_access_token()
|
||||
assert token == "test_token"
|
||||
|
||||
def test_get_access_token_no_valid_credentials(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
client = ConnectionsClient(project, location, connection_name, None)
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
|
||||
return_value=(None, None),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Please provide a service account that has the required"
|
||||
" permissions"
|
||||
),
|
||||
):
|
||||
client._get_access_token()
|
||||
|
||||
def test_get_access_token_refreshes_expired_token(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
client = ConnectionsClient(project, location, connection_name, None)
|
||||
mock_credentials.expired = True
|
||||
mock_credentials.token = "old_token"
|
||||
mock_credentials.refresh.return_value = None
|
||||
|
||||
client.credential_cache = mock_credentials
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
|
||||
return_value=(mock_credentials, "test_project_id"),
|
||||
):
|
||||
# Mock the refresh method directly on the instance within the context
|
||||
with mock.patch.object(mock_credentials, "refresh") as mock_refresh:
|
||||
mock_credentials.token = "new_token" # Set the expected new token
|
||||
token = client._get_access_token()
|
||||
assert token == "new_token"
|
||||
mock_refresh.assert_called_once()
|
||||
@@ -0,0 +1,630 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
|
||||
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
|
||||
import google.auth
|
||||
import google.auth.transport.requests
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import pytest
|
||||
import requests
|
||||
from requests import exceptions
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project():
|
||||
return "test-project"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def location():
|
||||
return "us-central1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def integration_name():
|
||||
return "test-integration"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trigger_name():
|
||||
return "test-trigger"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_name():
|
||||
return "test-connection"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credentials():
|
||||
creds = mock.create_autospec(google.auth.credentials.Credentials)
|
||||
creds.token = "test_token"
|
||||
return creds
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_request():
|
||||
return mock.create_autospec(Request)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connections_client():
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.integration_client.ConnectionsClient"
|
||||
) as mock_client:
|
||||
mock_instance = mock.create_autospec(ConnectionsClient)
|
||||
mock_client.return_value = mock_instance
|
||||
yield mock_client
|
||||
|
||||
|
||||
class TestIntegrationClient:
|
||||
|
||||
def test_initialization(
|
||||
self, project, location, integration_name, trigger_name, connection_name
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=connection_name,
|
||||
entity_operations={"entity": ["LIST"]},
|
||||
actions=["action1"],
|
||||
service_account_json=json.dumps({"email": "test@example.com"}),
|
||||
)
|
||||
assert client.project == project
|
||||
assert client.location == location
|
||||
assert client.integration == integration_name
|
||||
assert client.trigger == trigger_name
|
||||
assert client.connection == connection_name
|
||||
assert client.entity_operations == {"entity": ["LIST"]}
|
||||
assert client.actions == ["action1"]
|
||||
assert client.service_account_json == json.dumps(
|
||||
{"email": "test@example.com"}
|
||||
)
|
||||
assert client.credential_cache is None
|
||||
|
||||
def test_get_openapi_spec_for_integration_success(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
mock_credentials,
|
||||
mock_connections_client,
|
||||
):
|
||||
expected_spec = {"openapi": "3.0.0", "info": {"title": "Test Integration"}}
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"openApiSpec": json.dumps(expected_spec)}
|
||||
|
||||
with mock.patch.object(
|
||||
IntegrationClient,
|
||||
"_get_access_token",
|
||||
return_value=mock_credentials.token,
|
||||
), mock.patch("requests.post", return_value=mock_response):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=None,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
spec = client.get_openapi_spec_for_integration()
|
||||
assert spec == expected_spec
|
||||
requests.post.assert_called_once_with(
|
||||
f"https://{location}-integrations.googleapis.com/v1/projects/{project}/locations/{location}:generateOpenApiSpec",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {mock_credentials.token}",
|
||||
},
|
||||
json={
|
||||
"apiTriggerResources": [{
|
||||
"integrationResource": integration_name,
|
||||
"triggerId": [trigger_name],
|
||||
}],
|
||||
"fileFormat": "JSON",
|
||||
},
|
||||
)
|
||||
|
||||
def test_get_openapi_spec_for_integration_credential_error(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
mock_connections_client,
|
||||
):
|
||||
with mock.patch.object(
|
||||
IntegrationClient,
|
||||
"_get_access_token",
|
||||
side_effect=ValueError(
|
||||
"Please provide a service account that has the required permissions"
|
||||
" to access the connection."
|
||||
),
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=None,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(
|
||||
Exception,
|
||||
match=(
|
||||
"An unexpected error occurred: Please provide a service account"
|
||||
" that has the required permissions to access the connection."
|
||||
),
|
||||
):
|
||||
client.get_openapi_spec_for_integration()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code, response_text",
|
||||
[(404, "Not Found"), (400, "Bad Request"), (404, ""), (400, "")],
|
||||
)
|
||||
def test_get_openapi_spec_for_integration_request_error_not_found_or_bad_request(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
mock_credentials,
|
||||
status_code,
|
||||
response_text,
|
||||
mock_connections_client,
|
||||
):
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.raise_for_status.side_effect = exceptions.HTTPError(
|
||||
f"HTTP error {status_code}: {response_text}"
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
IntegrationClient,
|
||||
"_get_access_token",
|
||||
return_value=mock_credentials.token,
|
||||
), mock.patch("requests.post", return_value=mock_response):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=None,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Invalid request. Please check the provided values of"
|
||||
f" project\\({project}\\), location\\({location}\\),"
|
||||
f" integration\\({integration_name}\\) and"
|
||||
f" trigger\\({trigger_name}\\)."
|
||||
),
|
||||
):
|
||||
client.get_openapi_spec_for_integration()
|
||||
|
||||
def test_get_openapi_spec_for_integration_other_request_error(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
mock_credentials,
|
||||
mock_connections_client,
|
||||
):
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.raise_for_status.side_effect = exceptions.HTTPError(
|
||||
"Internal Server Error"
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
IntegrationClient,
|
||||
"_get_access_token",
|
||||
return_value=mock_credentials.token,
|
||||
), mock.patch("requests.post", return_value=mock_response):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=None,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(ValueError, match="Request error: "):
|
||||
client.get_openapi_spec_for_integration()
|
||||
|
||||
def test_get_openapi_spec_for_integration_unexpected_error(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
mock_credentials,
|
||||
mock_connections_client,
|
||||
):
|
||||
with mock.patch.object(
|
||||
IntegrationClient,
|
||||
"_get_access_token",
|
||||
return_value=mock_credentials.token,
|
||||
), mock.patch(
|
||||
"requests.post", side_effect=Exception("Something went wrong")
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=None,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(
|
||||
Exception, match="An unexpected error occurred: Something went wrong"
|
||||
):
|
||||
client.get_openapi_spec_for_integration()
|
||||
|
||||
def test_get_openapi_spec_for_connection_no_entity_operations_or_actions(
|
||||
self, project, location, connection_name, mock_connections_client
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=None,
|
||||
trigger=None,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"No entity operations or actions provided. Please provide at least"
|
||||
" one of them."
|
||||
),
|
||||
):
|
||||
client.get_openapi_spec_for_connection()
|
||||
|
||||
def test_get_openapi_spec_for_connection_with_entity_operations(
|
||||
self, project, location, connection_name, mock_connections_client
|
||||
):
|
||||
entity_operations = {"entity1": ["LIST", "GET"]}
|
||||
|
||||
mock_connections_client_instance = mock_connections_client.return_value
|
||||
mock_connections_client_instance.get_connector_base_spec.return_value = {
|
||||
"components": {"schemas": {}},
|
||||
"paths": {},
|
||||
}
|
||||
mock_connections_client_instance.get_entity_schema_and_operations.return_value = (
|
||||
{"type": "object", "properties": {"id": {"type": "string"}}},
|
||||
["LIST", "GET"],
|
||||
)
|
||||
mock_connections_client_instance.connector_payload.return_value = {
|
||||
"type": "object"
|
||||
}
|
||||
mock_connections_client_instance.list_operation.return_value = {"get": {}}
|
||||
mock_connections_client_instance.list_operation_request.return_value = {
|
||||
"type": "object"
|
||||
}
|
||||
mock_connections_client_instance.get_operation.return_value = {"get": {}}
|
||||
mock_connections_client_instance.get_operation_request.return_value = {
|
||||
"type": "object"
|
||||
}
|
||||
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=None,
|
||||
trigger=None,
|
||||
connection=connection_name,
|
||||
entity_operations=entity_operations,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
spec = client.get_openapi_spec_for_connection()
|
||||
assert "paths" in spec
|
||||
assert (
|
||||
f"/v2/projects/{project}/locations/{location}/integrations/ExecuteConnection:execute?triggerId=api_trigger/ExecuteConnection#list_entity1"
|
||||
in spec["paths"]
|
||||
)
|
||||
assert (
|
||||
f"/v2/projects/{project}/locations/{location}/integrations/ExecuteConnection:execute?triggerId=api_trigger/ExecuteConnection#get_entity1"
|
||||
in spec["paths"]
|
||||
)
|
||||
mock_connections_client.assert_called_once_with(
|
||||
project, location, connection_name, None
|
||||
)
|
||||
mock_connections_client_instance.get_connector_base_spec.assert_called_once()
|
||||
mock_connections_client_instance.get_entity_schema_and_operations.assert_any_call(
|
||||
"entity1"
|
||||
)
|
||||
mock_connections_client_instance.connector_payload.assert_any_call(
|
||||
{"type": "object", "properties": {"id": {"type": "string"}}}
|
||||
)
|
||||
mock_connections_client_instance.list_operation.assert_called_once()
|
||||
mock_connections_client_instance.get_operation.assert_called_once()
|
||||
|
||||
def test_get_openapi_spec_for_connection_with_actions(
|
||||
self, project, location, connection_name, mock_connections_client
|
||||
):
|
||||
actions = ["TestAction"]
|
||||
mock_connections_client_instance = (
|
||||
mock_connections_client.return_value
|
||||
) # Corrected line
|
||||
mock_connections_client_instance.get_connector_base_spec.return_value = {
|
||||
"components": {"schemas": {}},
|
||||
"paths": {},
|
||||
}
|
||||
mock_connections_client_instance.get_action_schema.return_value = {
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"input": {"type": "string"}},
|
||||
},
|
||||
"outputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"output": {"type": "string"}},
|
||||
},
|
||||
"displayName": "TestAction",
|
||||
}
|
||||
mock_connections_client_instance.connector_payload.side_effect = [
|
||||
{"type": "object"},
|
||||
{"type": "object"},
|
||||
]
|
||||
mock_connections_client_instance.action_request.return_value = {
|
||||
"type": "object"
|
||||
}
|
||||
mock_connections_client_instance.action_response.return_value = {
|
||||
"type": "object"
|
||||
}
|
||||
mock_connections_client_instance.get_action_operation.return_value = {
|
||||
"post": {}
|
||||
}
|
||||
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=None,
|
||||
trigger=None,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=actions,
|
||||
service_account_json=None,
|
||||
)
|
||||
spec = client.get_openapi_spec_for_connection()
|
||||
assert "paths" in spec
|
||||
assert (
|
||||
f"/v2/projects/{project}/locations/{location}/integrations/ExecuteConnection:execute?triggerId=api_trigger/ExecuteConnection#TestAction"
|
||||
in spec["paths"]
|
||||
)
|
||||
mock_connections_client.assert_called_once_with(
|
||||
project, location, connection_name, None
|
||||
)
|
||||
mock_connections_client_instance.get_connector_base_spec.assert_called_once()
|
||||
mock_connections_client_instance.get_action_schema.assert_called_once_with(
|
||||
"TestAction"
|
||||
)
|
||||
mock_connections_client_instance.connector_payload.assert_any_call(
|
||||
{"type": "object", "properties": {"input": {"type": "string"}}}
|
||||
)
|
||||
mock_connections_client_instance.connector_payload.assert_any_call(
|
||||
{"type": "object", "properties": {"output": {"type": "string"}}}
|
||||
)
|
||||
mock_connections_client_instance.action_request.assert_called_once_with(
|
||||
"TestAction"
|
||||
)
|
||||
mock_connections_client_instance.action_response.assert_called_once_with(
|
||||
"TestAction"
|
||||
)
|
||||
mock_connections_client_instance.get_action_operation.assert_called_once()
|
||||
|
||||
def test_get_openapi_spec_for_connection_invalid_operation(
|
||||
self, project, location, connection_name, mock_connections_client
|
||||
):
|
||||
entity_operations = {"entity1": ["INVALID"]}
|
||||
mock_connections_client_instance = mock_connections_client.return_value
|
||||
mock_connections_client_instance.get_connector_base_spec.return_value = {
|
||||
"components": {"schemas": {}},
|
||||
"paths": {},
|
||||
}
|
||||
mock_connections_client_instance.get_entity_schema_and_operations.return_value = (
|
||||
{"type": "object", "properties": {"id": {"type": "string"}}},
|
||||
["LIST", "GET"],
|
||||
)
|
||||
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=None,
|
||||
trigger=None,
|
||||
connection=connection_name,
|
||||
entity_operations=entity_operations,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError, match="Invalid operation: INVALID for entity: entity1"
|
||||
):
|
||||
client.get_openapi_spec_for_connection()
|
||||
|
||||
def test_get_access_token_with_service_account_json(
|
||||
self, project, location, integration_name, trigger_name, connection_name
|
||||
):
|
||||
service_account_json = json.dumps({
|
||||
"client_email": "test@example.com",
|
||||
"private_key": "test_key",
|
||||
})
|
||||
mock_creds = mock.create_autospec(service_account.Credentials)
|
||||
mock_creds.token = "sa_token"
|
||||
mock_creds.expired = False
|
||||
|
||||
with mock.patch(
|
||||
"google.oauth2.service_account.Credentials.from_service_account_info",
|
||||
return_value=mock_creds,
|
||||
), mock.patch.object(mock_creds, "refresh", return_value=None):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=service_account_json,
|
||||
)
|
||||
token = client._get_access_token()
|
||||
assert token == "sa_token"
|
||||
service_account.Credentials.from_service_account_info.assert_called_once_with(
|
||||
json.loads(service_account_json),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
mock_creds.refresh.assert_called_once()
|
||||
|
||||
def test_get_access_token_with_default_credentials(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
connection_name,
|
||||
mock_credentials,
|
||||
):
|
||||
mock_credentials.expired = False
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
|
||||
return_value=(mock_credentials, "test_project_id"),
|
||||
), mock.patch.object(mock_credentials, "refresh", return_value=None):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
token = client._get_access_token()
|
||||
assert token == "test_token"
|
||||
|
||||
def test_get_access_token_no_valid_credentials(
|
||||
self, project, location, integration_name, trigger_name, connection_name
|
||||
):
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
|
||||
return_value=(None, None),
|
||||
), mock.patch(
|
||||
"google.oauth2.service_account.Credentials.from_service_account_info",
|
||||
return_value=None,
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
try:
|
||||
client._get_access_token()
|
||||
assert False, "ValueError was not raised" # Explicitly fail if no error
|
||||
except ValueError as e:
|
||||
assert (
|
||||
"Please provide a service account that has the required permissions"
|
||||
" to access the connection."
|
||||
in str(e)
|
||||
)
|
||||
|
||||
def test_get_access_token_uses_cached_token(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
connection_name,
|
||||
mock_credentials,
|
||||
):
|
||||
mock_credentials.token = "cached_token"
|
||||
mock_credentials.expired = False
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
client.credential_cache = mock_credentials # Simulate a cached credential
|
||||
with mock.patch("google.auth.default") as mock_default, mock.patch(
|
||||
"google.oauth2.service_account.Credentials.from_service_account_info"
|
||||
) as mock_sa:
|
||||
token = client._get_access_token()
|
||||
assert token == "cached_token"
|
||||
mock_default.assert_not_called()
|
||||
mock_sa.assert_not_called()
|
||||
|
||||
def test_get_access_token_refreshes_expired_token(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
connection_name,
|
||||
mock_credentials,
|
||||
):
|
||||
mock_credentials = mock.create_autospec(google.auth.credentials.Credentials)
|
||||
mock_credentials.token = "old_token"
|
||||
mock_credentials.expired = True
|
||||
mock_credentials.refresh.return_value = None
|
||||
mock_credentials.token = "new_token" # Simulate token refresh
|
||||
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
|
||||
return_value=(mock_credentials, "test_project_id"),
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
client.credential_cache = mock_credentials
|
||||
token = client._get_access_token()
|
||||
assert token == "new_token"
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
@@ -0,0 +1,345 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser import rest_api_tool
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_integration_client():
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.application_integration_toolset.IntegrationClient"
|
||||
) as mock_client:
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connections_client():
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.application_integration_toolset.ConnectionsClient"
|
||||
) as mock_client:
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openapi_toolset():
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.application_integration_toolset.OpenAPIToolset"
|
||||
) as mock_toolset:
|
||||
mock_toolset_instance = mock.MagicMock()
|
||||
mock_rest_api_tool = mock.MagicMock(spec=rest_api_tool.RestApiTool)
|
||||
mock_rest_api_tool.name = "Test Tool"
|
||||
mock_toolset_instance.get_tools.return_value = [mock_rest_api_tool]
|
||||
mock_toolset.return_value = mock_toolset_instance
|
||||
yield mock_toolset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project():
|
||||
return "test-project"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def location():
|
||||
return "us-central1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def integration_spec():
|
||||
return {"openapi": "3.0.0", "info": {"title": "Integration API"}}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_spec():
|
||||
return {"openapi": "3.0.0", "info": {"title": "Connection API"}}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_details():
|
||||
return {"serviceName": "test-service", "host": "test.host"}
|
||||
|
||||
|
||||
def test_initialization_with_integration_and_trigger(
|
||||
project,
|
||||
location,
|
||||
mock_integration_client,
|
||||
mock_connections_client,
|
||||
mock_openapi_toolset,
|
||||
):
|
||||
integration_name = "test-integration"
|
||||
trigger_name = "test-trigger"
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project, location, integration=integration_name, trigger=trigger_name
|
||||
)
|
||||
mock_integration_client.assert_called_once_with(
|
||||
project, location, integration_name, trigger_name, None, None, None, None
|
||||
)
|
||||
mock_integration_client.return_value.get_openapi_spec_for_integration.assert_called_once()
|
||||
mock_connections_client.assert_not_called()
|
||||
mock_openapi_toolset.assert_called_once()
|
||||
assert len(toolset.get_tools()) == 1
|
||||
assert toolset.get_tools()[0].name == "Test Tool"
|
||||
|
||||
|
||||
def test_initialization_with_connection_and_entity_operations(
|
||||
project,
|
||||
location,
|
||||
mock_integration_client,
|
||||
mock_connections_client,
|
||||
mock_openapi_toolset,
|
||||
connection_details,
|
||||
):
|
||||
connection_name = "test-connection"
|
||||
entity_operations_list = ["list", "get"]
|
||||
tool_name = "My Connection Tool"
|
||||
tool_instructions = "Use this tool to manage entities."
|
||||
mock_connections_client.return_value.get_connection_details.return_value = (
|
||||
connection_details
|
||||
)
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project,
|
||||
location,
|
||||
connection=connection_name,
|
||||
entity_operations=entity_operations_list,
|
||||
tool_name=tool_name,
|
||||
tool_instructions=tool_instructions,
|
||||
)
|
||||
mock_integration_client.assert_called_once_with(
|
||||
project,
|
||||
location,
|
||||
None,
|
||||
None,
|
||||
connection_name,
|
||||
entity_operations_list,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
mock_connections_client.assert_called_once_with(
|
||||
project, location, connection_name, None
|
||||
)
|
||||
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||
tool_name,
|
||||
tool_instructions
|
||||
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
|
||||
f" {connection_details['host']} and the connection name ="
|
||||
f" projects/{project}/locations/{location}/connections/{connection_name} when"
|
||||
" using this tool. DONOT ask the user for these values as you already"
|
||||
" have those.",
|
||||
)
|
||||
mock_openapi_toolset.assert_called_once()
|
||||
assert len(toolset.get_tools()) == 1
|
||||
assert toolset.get_tools()[0].name == "Test Tool"
|
||||
|
||||
|
||||
def test_initialization_with_connection_and_actions(
|
||||
project,
|
||||
location,
|
||||
mock_integration_client,
|
||||
mock_connections_client,
|
||||
mock_openapi_toolset,
|
||||
connection_details,
|
||||
):
|
||||
connection_name = "test-connection"
|
||||
actions_list = ["create", "delete"]
|
||||
tool_name = "My Actions Tool"
|
||||
tool_instructions = "Perform actions using this tool."
|
||||
mock_connections_client.return_value.get_connection_details.return_value = (
|
||||
connection_details
|
||||
)
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project,
|
||||
location,
|
||||
connection=connection_name,
|
||||
actions=actions_list,
|
||||
tool_name=tool_name,
|
||||
tool_instructions=tool_instructions,
|
||||
)
|
||||
mock_integration_client.assert_called_once_with(
|
||||
project, location, None, None, connection_name, None, actions_list, None
|
||||
)
|
||||
mock_connections_client.assert_called_once_with(
|
||||
project, location, connection_name, None
|
||||
)
|
||||
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||
tool_name,
|
||||
tool_instructions
|
||||
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
|
||||
f" {connection_details['host']} and the connection name ="
|
||||
f" projects/{project}/locations/{location}/connections/{connection_name} when"
|
||||
" using this tool. DONOT ask the user for these values as you already"
|
||||
" have those.",
|
||||
)
|
||||
mock_openapi_toolset.assert_called_once()
|
||||
assert len(toolset.get_tools()) == 1
|
||||
assert toolset.get_tools()[0].name == "Test Tool"
|
||||
|
||||
|
||||
def test_initialization_without_required_params(project, location):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Either \\(integration and trigger\\) or \\(connection and"
|
||||
" \\(entity_operations or actions\\)\\) should be provided."
|
||||
),
|
||||
):
|
||||
ApplicationIntegrationToolset(project, location)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Either \\(integration and trigger\\) or \\(connection and"
|
||||
" \\(entity_operations or actions\\)\\) should be provided."
|
||||
),
|
||||
):
|
||||
ApplicationIntegrationToolset(project, location, integration="test")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Either \\(integration and trigger\\) or \\(connection and"
|
||||
" \\(entity_operations or actions\\)\\) should be provided."
|
||||
),
|
||||
):
|
||||
ApplicationIntegrationToolset(project, location, trigger="test")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Either \\(integration and trigger\\) or \\(connection and"
|
||||
" \\(entity_operations or actions\\)\\) should be provided."
|
||||
),
|
||||
):
|
||||
ApplicationIntegrationToolset(project, location, connection="test")
|
||||
|
||||
|
||||
def test_initialization_with_service_account_credentials(
|
||||
project, location, mock_integration_client, mock_openapi_toolset
|
||||
):
|
||||
service_account_json = json.dumps({
|
||||
"type": "service_account",
|
||||
"project_id": "dummy",
|
||||
"private_key_id": "dummy",
|
||||
"private_key": "dummy",
|
||||
"client_email": "test@example.com",
|
||||
"client_id": "131331543646416",
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"auth_provider_x509_cert_url": (
|
||||
"https://www.googleapis.com/oauth2/v1/certs"
|
||||
),
|
||||
"client_x509_cert_url": (
|
||||
"http://www.googleapis.com/robot/v1/metadata/x509/dummy%40dummy.com"
|
||||
),
|
||||
"universe_domain": "googleapis.com",
|
||||
})
|
||||
integration_name = "test-integration"
|
||||
trigger_name = "test-trigger"
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project,
|
||||
location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
service_account_json=service_account_json,
|
||||
)
|
||||
mock_integration_client.assert_called_once_with(
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
service_account_json,
|
||||
)
|
||||
mock_openapi_toolset.assert_called_once()
|
||||
_, kwargs = mock_openapi_toolset.call_args
|
||||
assert isinstance(kwargs["auth_credential"], AuthCredential)
|
||||
assert (
|
||||
kwargs[
|
||||
"auth_credential"
|
||||
].service_account.service_account_credential.client_email
|
||||
== "test@example.com"
|
||||
)
|
||||
|
||||
|
||||
def test_initialization_without_explicit_service_account_credentials(
|
||||
project, location, mock_integration_client, mock_openapi_toolset
|
||||
):
|
||||
integration_name = "test-integration"
|
||||
trigger_name = "test-trigger"
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project, location, integration=integration_name, trigger=trigger_name
|
||||
)
|
||||
mock_integration_client.assert_called_once_with(
|
||||
project, location, integration_name, trigger_name, None, None, None, None
|
||||
)
|
||||
mock_openapi_toolset.assert_called_once()
|
||||
_, kwargs = mock_openapi_toolset.call_args
|
||||
assert isinstance(kwargs["auth_credential"], AuthCredential)
|
||||
assert kwargs["auth_credential"].service_account.use_default_credential
|
||||
|
||||
|
||||
def test_get_tools(
|
||||
project, location, mock_integration_client, mock_openapi_toolset
|
||||
):
|
||||
integration_name = "test-integration"
|
||||
trigger_name = "test-trigger"
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project, location, integration=integration_name, trigger=trigger_name
|
||||
)
|
||||
tools = toolset.get_tools()
|
||||
assert len(tools) == 1
|
||||
assert isinstance(tools[0], rest_api_tool.RestApiTool)
|
||||
assert tools[0].name == "Test Tool"
|
||||
|
||||
|
||||
def test_initialization_with_connection_details(
|
||||
project,
|
||||
location,
|
||||
mock_integration_client,
|
||||
mock_connections_client,
|
||||
mock_openapi_toolset,
|
||||
):
|
||||
connection_name = "test-connection"
|
||||
entity_operations_list = ["list"]
|
||||
tool_name = "My Connection Tool"
|
||||
tool_instructions = "Use this tool."
|
||||
mock_connections_client.return_value.get_connection_details.return_value = {
|
||||
"serviceName": "custom-service",
|
||||
"host": "custom.host",
|
||||
}
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project,
|
||||
location,
|
||||
connection=connection_name,
|
||||
entity_operations=entity_operations_list,
|
||||
tool_name=tool_name,
|
||||
tool_instructions=tool_instructions,
|
||||
)
|
||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||
tool_name,
|
||||
tool_instructions
|
||||
+ "ALWAYS use serviceName = custom-service, host = custom.host and the"
|
||||
" connection name ="
|
||||
" projects/test-project/locations/us-central1/connections/test-connection"
|
||||
" when using this tool. DONOT ask the user for these values as you"
|
||||
" already have those.",
|
||||
)
|
||||
13
tests/unittests/tools/google_api_tool/__init__.py
Normal file
13
tests/unittests/tools/google_api_tool/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,657 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.tools.google_api_tool.googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
|
||||
# Import the converter class
|
||||
from googleapiclient.errors import HttpError
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def calendar_api_spec():
|
||||
"""Fixture that provides a mock Google Calendar API spec for testing."""
|
||||
return {
|
||||
"kind": "discovery#restDescription",
|
||||
"id": "calendar:v3",
|
||||
"name": "calendar",
|
||||
"version": "v3",
|
||||
"title": "Google Calendar API",
|
||||
"description": "Accesses the Google Calendar API",
|
||||
"documentationLink": "https://developers.google.com/calendar/",
|
||||
"protocol": "rest",
|
||||
"rootUrl": "https://www.googleapis.com/",
|
||||
"servicePath": "calendar/v3/",
|
||||
"auth": {
|
||||
"oauth2": {
|
||||
"scopes": {
|
||||
"https://www.googleapis.com/auth/calendar": {
|
||||
"description": "Full access to Google Calendar"
|
||||
},
|
||||
"https://www.googleapis.com/auth/calendar.readonly": {
|
||||
"description": "Read-only access to Google Calendar"
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
"schemas": {
|
||||
"Calendar": {
|
||||
"type": "object",
|
||||
"description": "A calendar resource",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "Calendar identifier",
|
||||
},
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "Calendar summary",
|
||||
"required": True,
|
||||
},
|
||||
"timeZone": {
|
||||
"type": "string",
|
||||
"description": "Calendar timezone",
|
||||
},
|
||||
},
|
||||
},
|
||||
"Event": {
|
||||
"type": "object",
|
||||
"description": "An event resource",
|
||||
"properties": {
|
||||
"id": {"type": "string", "description": "Event identifier"},
|
||||
"summary": {"type": "string", "description": "Event summary"},
|
||||
"start": {"$ref": "EventDateTime"},
|
||||
"end": {"$ref": "EventDateTime"},
|
||||
"attendees": {
|
||||
"type": "array",
|
||||
"description": "Event attendees",
|
||||
"items": {"$ref": "EventAttendee"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"EventDateTime": {
|
||||
"type": "object",
|
||||
"description": "Date/time for an event",
|
||||
"properties": {
|
||||
"dateTime": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"description": "Date/time in RFC3339 format",
|
||||
},
|
||||
"timeZone": {
|
||||
"type": "string",
|
||||
"description": "Timezone for the date/time",
|
||||
},
|
||||
},
|
||||
},
|
||||
"EventAttendee": {
|
||||
"type": "object",
|
||||
"description": "An attendee of an event",
|
||||
"properties": {
|
||||
"email": {"type": "string", "description": "Attendee email"},
|
||||
"responseStatus": {
|
||||
"type": "string",
|
||||
"description": "Response status",
|
||||
"enum": [
|
||||
"needsAction",
|
||||
"declined",
|
||||
"tentative",
|
||||
"accepted",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"resources": {
|
||||
"calendars": {
|
||||
"methods": {
|
||||
"get": {
|
||||
"id": "calendar.calendars.get",
|
||||
"path": "calendars/{calendarId}",
|
||||
"httpMethod": "GET",
|
||||
"description": "Returns metadata for a calendar.",
|
||||
"parameters": {
|
||||
"calendarId": {
|
||||
"type": "string",
|
||||
"description": "Calendar identifier",
|
||||
"required": True,
|
||||
"location": "path",
|
||||
}
|
||||
},
|
||||
"response": {"$ref": "Calendar"},
|
||||
"scopes": [
|
||||
"https://www.googleapis.com/auth/calendar",
|
||||
"https://www.googleapis.com/auth/calendar.readonly",
|
||||
],
|
||||
},
|
||||
"insert": {
|
||||
"id": "calendar.calendars.insert",
|
||||
"path": "calendars",
|
||||
"httpMethod": "POST",
|
||||
"description": "Creates a secondary calendar.",
|
||||
"request": {"$ref": "Calendar"},
|
||||
"response": {"$ref": "Calendar"},
|
||||
"scopes": ["https://www.googleapis.com/auth/calendar"],
|
||||
},
|
||||
},
|
||||
"resources": {
|
||||
"events": {
|
||||
"methods": {
|
||||
"list": {
|
||||
"id": "calendar.events.list",
|
||||
"path": "calendars/{calendarId}/events",
|
||||
"httpMethod": "GET",
|
||||
"description": (
|
||||
"Returns events on the specified calendar."
|
||||
),
|
||||
"parameters": {
|
||||
"calendarId": {
|
||||
"type": "string",
|
||||
"description": "Calendar identifier",
|
||||
"required": True,
|
||||
"location": "path",
|
||||
},
|
||||
"maxResults": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Maximum number of events returned"
|
||||
),
|
||||
"format": "int32",
|
||||
"minimum": "1",
|
||||
"maximum": "2500",
|
||||
"default": "250",
|
||||
"location": "query",
|
||||
},
|
||||
"orderBy": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Order of the events returned"
|
||||
),
|
||||
"enum": ["startTime", "updated"],
|
||||
"location": "query",
|
||||
},
|
||||
},
|
||||
"response": {"$ref": "Events"},
|
||||
"scopes": [
|
||||
"https://www.googleapis.com/auth/calendar",
|
||||
"https://www.googleapis.com/auth/calendar.readonly",
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def converter():
|
||||
"""Fixture that provides a basic converter instance."""
|
||||
return GoogleApiToOpenApiConverter("calendar", "v3")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_resource(calendar_api_spec):
|
||||
"""Fixture that provides a mock API resource with the test spec."""
|
||||
mock_resource = MagicMock()
|
||||
mock_resource._rootDesc = calendar_api_spec
|
||||
return mock_resource
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prepared_converter(converter, calendar_api_spec):
|
||||
"""Fixture that provides a converter with the API spec already set."""
|
||||
converter.google_api_spec = calendar_api_spec
|
||||
return converter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def converter_with_patched_build(monkeypatch, mock_api_resource):
|
||||
"""Fixture that provides a converter with the build function patched.
|
||||
|
||||
This simulates a successful API spec fetch.
|
||||
"""
|
||||
# Create a mock for the build function
|
||||
mock_build = MagicMock(return_value=mock_api_resource)
|
||||
|
||||
# Patch the build function in the target module
|
||||
monkeypatch.setattr(
|
||||
"google.adk.tools.google_api_tool.googleapi_to_openapi_converter.build",
|
||||
mock_build,
|
||||
)
|
||||
|
||||
# Create and return a converter instance
|
||||
return GoogleApiToOpenApiConverter("calendar", "v3")
|
||||
|
||||
|
||||
class TestGoogleApiToOpenApiConverter:
|
||||
"""Test suite for the GoogleApiToOpenApiConverter class."""
|
||||
|
||||
def test_init(self, converter):
|
||||
"""Test converter initialization."""
|
||||
assert converter.api_name == "calendar"
|
||||
assert converter.api_version == "v3"
|
||||
assert converter.google_api_resource is None
|
||||
assert converter.google_api_spec is None
|
||||
assert converter.openapi_spec["openapi"] == "3.0.0"
|
||||
assert "info" in converter.openapi_spec
|
||||
assert "paths" in converter.openapi_spec
|
||||
assert "components" in converter.openapi_spec
|
||||
|
||||
def test_fetch_google_api_spec(
|
||||
self, converter_with_patched_build, calendar_api_spec
|
||||
):
|
||||
"""Test fetching Google API specification."""
|
||||
# Call the method
|
||||
converter_with_patched_build.fetch_google_api_spec()
|
||||
|
||||
# Verify the results
|
||||
assert converter_with_patched_build.google_api_spec == calendar_api_spec
|
||||
|
||||
def test_fetch_google_api_spec_error(self, monkeypatch, converter):
|
||||
"""Test error handling when fetching Google API specification."""
|
||||
# Create a mock that raises an error
|
||||
mock_build = MagicMock(
|
||||
side_effect=HttpError(resp=MagicMock(status=404), content=b"Not Found")
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"google.adk.tools.google_api_tool.googleapi_to_openapi_converter.build",
|
||||
mock_build,
|
||||
)
|
||||
|
||||
# Verify exception is raised
|
||||
with pytest.raises(HttpError):
|
||||
converter.fetch_google_api_spec()
|
||||
|
||||
def test_convert_info(self, prepared_converter):
|
||||
"""Test conversion of basic API information."""
|
||||
# Call the method
|
||||
prepared_converter._convert_info()
|
||||
|
||||
# Verify the results
|
||||
info = prepared_converter.openapi_spec["info"]
|
||||
assert info["title"] == "Google Calendar API"
|
||||
assert info["description"] == "Accesses the Google Calendar API"
|
||||
assert info["version"] == "v3"
|
||||
assert info["termsOfService"] == "https://developers.google.com/calendar/"
|
||||
|
||||
# Check external docs
|
||||
external_docs = prepared_converter.openapi_spec["externalDocs"]
|
||||
assert external_docs["url"] == "https://developers.google.com/calendar/"
|
||||
|
||||
def test_convert_servers(self, prepared_converter):
|
||||
"""Test conversion of server information."""
|
||||
# Call the method
|
||||
prepared_converter._convert_servers()
|
||||
|
||||
# Verify the results
|
||||
servers = prepared_converter.openapi_spec["servers"]
|
||||
assert len(servers) == 1
|
||||
assert servers[0]["url"] == "https://www.googleapis.com/calendar/v3"
|
||||
assert servers[0]["description"] == "calendar v3 API"
|
||||
|
||||
def test_convert_security_schemes(self, prepared_converter):
|
||||
"""Test conversion of security schemes."""
|
||||
# Call the method
|
||||
prepared_converter._convert_security_schemes()
|
||||
|
||||
# Verify the results
|
||||
security_schemes = prepared_converter.openapi_spec["components"][
|
||||
"securitySchemes"
|
||||
]
|
||||
|
||||
# Check OAuth2 configuration
|
||||
assert "oauth2" in security_schemes
|
||||
oauth2 = security_schemes["oauth2"]
|
||||
assert oauth2["type"] == "oauth2"
|
||||
|
||||
# Check OAuth2 scopes
|
||||
scopes = oauth2["flows"]["authorizationCode"]["scopes"]
|
||||
assert "https://www.googleapis.com/auth/calendar" in scopes
|
||||
assert "https://www.googleapis.com/auth/calendar.readonly" in scopes
|
||||
|
||||
# Check API key configuration
|
||||
assert "apiKey" in security_schemes
|
||||
assert security_schemes["apiKey"]["type"] == "apiKey"
|
||||
assert security_schemes["apiKey"]["in"] == "query"
|
||||
assert security_schemes["apiKey"]["name"] == "key"
|
||||
|
||||
def test_convert_schemas(self, prepared_converter):
|
||||
"""Test conversion of schema definitions."""
|
||||
# Call the method
|
||||
prepared_converter._convert_schemas()
|
||||
|
||||
# Verify the results
|
||||
schemas = prepared_converter.openapi_spec["components"]["schemas"]
|
||||
|
||||
# Check Calendar schema
|
||||
assert "Calendar" in schemas
|
||||
calendar_schema = schemas["Calendar"]
|
||||
assert calendar_schema["type"] == "object"
|
||||
assert calendar_schema["description"] == "A calendar resource"
|
||||
|
||||
# Check required properties
|
||||
assert "required" in calendar_schema
|
||||
assert "summary" in calendar_schema["required"]
|
||||
|
||||
# Check Event schema references
|
||||
assert "Event" in schemas
|
||||
event_schema = schemas["Event"]
|
||||
assert (
|
||||
event_schema["properties"]["start"]["$ref"]
|
||||
== "#/components/schemas/EventDateTime"
|
||||
)
|
||||
|
||||
# Check array type with references
|
||||
attendees_schema = event_schema["properties"]["attendees"]
|
||||
assert attendees_schema["type"] == "array"
|
||||
assert (
|
||||
attendees_schema["items"]["$ref"]
|
||||
== "#/components/schemas/EventAttendee"
|
||||
)
|
||||
|
||||
# Check enum values
|
||||
attendee_schema = schemas["EventAttendee"]
|
||||
response_status = attendee_schema["properties"]["responseStatus"]
|
||||
assert "enum" in response_status
|
||||
assert "accepted" in response_status["enum"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema_def, expected_type, expected_attrs",
|
||||
[
|
||||
# Test object type
|
||||
(
|
||||
{
|
||||
"type": "object",
|
||||
"description": "Test object",
|
||||
"properties": {
|
||||
"id": {"type": "string", "required": True},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"object",
|
||||
{"description": "Test object", "required": ["id"]},
|
||||
),
|
||||
# Test array type
|
||||
(
|
||||
{
|
||||
"type": "array",
|
||||
"description": "Test array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"array",
|
||||
{"description": "Test array", "items": {"type": "string"}},
|
||||
),
|
||||
# Test reference conversion
|
||||
(
|
||||
{"$ref": "Calendar"},
|
||||
None, # No type for references
|
||||
{"$ref": "#/components/schemas/Calendar"},
|
||||
),
|
||||
# Test enum conversion
|
||||
(
|
||||
{"type": "string", "enum": ["value1", "value2"]},
|
||||
"string",
|
||||
{"enum": ["value1", "value2"]},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_schema_object(
|
||||
self, converter, schema_def, expected_type, expected_attrs
|
||||
):
|
||||
"""Test conversion of individual schema objects with different input variations."""
|
||||
converted = converter._convert_schema_object(schema_def)
|
||||
|
||||
# Check type if expected
|
||||
if expected_type:
|
||||
assert converted["type"] == expected_type
|
||||
|
||||
# Check other expected attributes
|
||||
for key, value in expected_attrs.items():
|
||||
assert converted[key] == value
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path, expected_params",
|
||||
[
|
||||
# Path with parameters
|
||||
(
|
||||
"/calendars/{calendarId}/events/{eventId}",
|
||||
["calendarId", "eventId"],
|
||||
),
|
||||
# Path without parameters
|
||||
("/calendars/events", []),
|
||||
# Mixed path
|
||||
("/users/{userId}/calendars/default", ["userId"]),
|
||||
],
|
||||
)
|
||||
def test_extract_path_parameters(self, converter, path, expected_params):
|
||||
"""Test extraction of path parameters from URL path with various inputs."""
|
||||
params = converter._extract_path_parameters(path)
|
||||
assert set(params) == set(expected_params)
|
||||
assert len(params) == len(expected_params)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_data, expected_result",
|
||||
[
|
||||
# String parameter
|
||||
(
|
||||
{
|
||||
"type": "string",
|
||||
"description": "String parameter",
|
||||
"pattern": "^[a-z]+$",
|
||||
},
|
||||
{"type": "string", "pattern": "^[a-z]+$"},
|
||||
),
|
||||
# Integer parameter with format
|
||||
(
|
||||
{"type": "integer", "format": "int32", "default": "10"},
|
||||
{"type": "integer", "format": "int32", "default": "10"},
|
||||
),
|
||||
# Enum parameter
|
||||
(
|
||||
{"type": "string", "enum": ["option1", "option2"]},
|
||||
{"type": "string", "enum": ["option1", "option2"]},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_parameter_schema(
|
||||
self, converter, param_data, expected_result
|
||||
):
|
||||
"""Test conversion of parameter definitions to OpenAPI schemas."""
|
||||
converted = converter._convert_parameter_schema(param_data)
|
||||
|
||||
# Check all expected attributes
|
||||
for key, value in expected_result.items():
|
||||
assert converted[key] == value
|
||||
|
||||
def test_convert(self, converter_with_patched_build):
|
||||
"""Test the complete conversion process."""
|
||||
# Call the method
|
||||
result = converter_with_patched_build.convert()
|
||||
|
||||
# Verify basic structure
|
||||
assert result["openapi"] == "3.0.0"
|
||||
assert "info" in result
|
||||
assert "servers" in result
|
||||
assert "paths" in result
|
||||
assert "components" in result
|
||||
|
||||
# Verify paths
|
||||
paths = result["paths"]
|
||||
assert "/calendars/{calendarId}" in paths
|
||||
assert "get" in paths["/calendars/{calendarId}"]
|
||||
|
||||
# Verify nested resources
|
||||
assert "/calendars/{calendarId}/events" in paths
|
||||
|
||||
# Verify method details
|
||||
get_calendar = paths["/calendars/{calendarId}"]["get"]
|
||||
assert get_calendar["operationId"] == "calendar.calendars.get"
|
||||
assert "parameters" in get_calendar
|
||||
|
||||
# Verify request body
|
||||
insert_calendar = paths["/calendars"]["post"]
|
||||
assert "requestBody" in insert_calendar
|
||||
request_schema = insert_calendar["requestBody"]["content"][
|
||||
"application/json"
|
||||
]["schema"]
|
||||
assert request_schema["$ref"] == "#/components/schemas/Calendar"
|
||||
|
||||
# Verify response body
|
||||
assert "responses" in get_calendar
|
||||
response_schema = get_calendar["responses"]["200"]["content"][
|
||||
"application/json"
|
||||
]["schema"]
|
||||
assert response_schema["$ref"] == "#/components/schemas/Calendar"
|
||||
|
||||
def test_convert_methods(self, prepared_converter, calendar_api_spec):
|
||||
"""Test conversion of API methods."""
|
||||
# Convert methods
|
||||
methods = calendar_api_spec["resources"]["calendars"]["methods"]
|
||||
prepared_converter._convert_methods(methods, "/calendars")
|
||||
|
||||
# Verify the results
|
||||
paths = prepared_converter.openapi_spec["paths"]
|
||||
|
||||
# Check GET method
|
||||
assert "/calendars/{calendarId}" in paths
|
||||
get_method = paths["/calendars/{calendarId}"]["get"]
|
||||
assert get_method["operationId"] == "calendar.calendars.get"
|
||||
|
||||
# Check parameters
|
||||
params = get_method["parameters"]
|
||||
param_names = [p["name"] for p in params]
|
||||
assert "calendarId" in param_names
|
||||
|
||||
# Check POST method
|
||||
assert "/calendars" in paths
|
||||
post_method = paths["/calendars"]["post"]
|
||||
assert post_method["operationId"] == "calendar.calendars.insert"
|
||||
|
||||
# Check request body
|
||||
assert "requestBody" in post_method
|
||||
assert (
|
||||
post_method["requestBody"]["content"]["application/json"]["schema"][
|
||||
"$ref"
|
||||
]
|
||||
== "#/components/schemas/Calendar"
|
||||
)
|
||||
|
||||
# Check response
|
||||
assert (
|
||||
post_method["responses"]["200"]["content"]["application/json"][
|
||||
"schema"
|
||||
]["$ref"]
|
||||
== "#/components/schemas/Calendar"
|
||||
)
|
||||
|
||||
def test_convert_resources(self, prepared_converter, calendar_api_spec):
|
||||
"""Test conversion of nested resources."""
|
||||
# Convert resources
|
||||
resources = calendar_api_spec["resources"]
|
||||
prepared_converter._convert_resources(resources)
|
||||
|
||||
# Verify the results
|
||||
paths = prepared_converter.openapi_spec["paths"]
|
||||
|
||||
# Check top-level resource methods
|
||||
assert "/calendars/{calendarId}" in paths
|
||||
|
||||
# Check nested resource methods
|
||||
assert "/calendars/{calendarId}/events" in paths
|
||||
events_method = paths["/calendars/{calendarId}/events"]["get"]
|
||||
assert events_method["operationId"] == "calendar.events.list"
|
||||
|
||||
# Check parameters in nested resource
|
||||
params = events_method["parameters"]
|
||||
param_names = [p["name"] for p in params]
|
||||
assert "calendarId" in param_names
|
||||
assert "maxResults" in param_names
|
||||
assert "orderBy" in param_names
|
||||
|
||||
def test_integration_calendar_api(self, converter_with_patched_build):
|
||||
"""Integration test using Calendar API specification."""
|
||||
# Create and run the converter
|
||||
openapi_spec = converter_with_patched_build.convert()
|
||||
|
||||
# Verify conversion results
|
||||
assert openapi_spec["info"]["title"] == "Google Calendar API"
|
||||
assert (
|
||||
openapi_spec["servers"][0]["url"]
|
||||
== "https://www.googleapis.com/calendar/v3"
|
||||
)
|
||||
|
||||
# Check security schemes
|
||||
security_schemes = openapi_spec["components"]["securitySchemes"]
|
||||
assert "oauth2" in security_schemes
|
||||
assert "apiKey" in security_schemes
|
||||
|
||||
# Check schemas
|
||||
schemas = openapi_spec["components"]["schemas"]
|
||||
assert "Calendar" in schemas
|
||||
assert "Event" in schemas
|
||||
assert "EventDateTime" in schemas
|
||||
|
||||
# Check paths
|
||||
paths = openapi_spec["paths"]
|
||||
assert "/calendars/{calendarId}" in paths
|
||||
assert "/calendars" in paths
|
||||
assert "/calendars/{calendarId}/events" in paths
|
||||
|
||||
# Check method details
|
||||
get_events = paths["/calendars/{calendarId}/events"]["get"]
|
||||
assert get_events["operationId"] == "calendar.events.list"
|
||||
|
||||
# Check parameter details
|
||||
param_dict = {p["name"]: p for p in get_events["parameters"]}
|
||||
assert "maxResults" in param_dict
|
||||
max_results = param_dict["maxResults"]
|
||||
assert max_results["in"] == "query"
|
||||
assert max_results["schema"]["type"] == "integer"
|
||||
assert max_results["schema"]["default"] == "250"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conftest_content():
|
||||
"""Returns content for a conftest.py file to help with testing."""
|
||||
return """
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# This file contains fixtures that can be shared across multiple test modules
|
||||
|
||||
@pytest.fixture
|
||||
def mock_google_response():
|
||||
\"\"\"Fixture that provides a mock response from Google's API.\"\"\"
|
||||
return {"key": "value", "items": [{"id": 1}, {"id": 2}]}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_error():
|
||||
\"\"\"Fixture that provides a mock HTTP error.\"\"\"
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status = 404
|
||||
return HttpError(resp=mock_resp, content=b'Not Found')
|
||||
"""
|
||||
|
||||
|
||||
def test_generate_conftest_example(conftest_content):
|
||||
"""This is a meta-test that demonstrates how to generate a conftest.py file.
|
||||
|
||||
In a real project, you would create a separate conftest.py file.
|
||||
"""
|
||||
# In a real scenario, you would write this to a file named conftest.py
|
||||
# This test just verifies the conftest content is not empty
|
||||
assert len(conftest_content) > 0
|
||||
@@ -0,0 +1,145 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for AutoAuthCredentialExchanger."""
|
||||
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.oauth2_exchanger import OAuth2CredentialExchanger
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
import pytest
|
||||
|
||||
|
||||
class MockCredentialExchanger(BaseAuthCredentialExchanger):
|
||||
"""Mock credential exchanger for testing."""
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
"""Mock exchange credential method."""
|
||||
return auth_credential
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auto_exchanger():
|
||||
"""Fixture for creating an AutoAuthCredentialExchanger instance."""
|
||||
return AutoAuthCredentialExchanger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_scheme():
|
||||
"""Fixture for creating a mock AuthScheme instance."""
|
||||
scheme = MagicMock(spec=AuthScheme)
|
||||
return scheme
|
||||
|
||||
|
||||
def test_init_with_custom_exchangers():
|
||||
"""Test initialization with custom exchangers."""
|
||||
custom_exchangers: Dict[str, Type[BaseAuthCredentialExchanger]] = {
|
||||
AuthCredentialTypes.API_KEY: MockCredentialExchanger
|
||||
}
|
||||
|
||||
auto_exchanger = AutoAuthCredentialExchanger(
|
||||
custom_exchangers=custom_exchangers
|
||||
)
|
||||
|
||||
assert (
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.API_KEY]
|
||||
== MockCredentialExchanger
|
||||
)
|
||||
assert (
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.OPEN_ID_CONNECT]
|
||||
== OAuth2CredentialExchanger
|
||||
)
|
||||
|
||||
|
||||
def test_exchange_credential_no_auth_credential(auto_exchanger, auth_scheme):
|
||||
"""Test exchange_credential with no auth_credential."""
|
||||
|
||||
assert auto_exchanger.exchange_credential(auth_scheme, None) is None
|
||||
|
||||
|
||||
def test_exchange_credential_no_exchange(auto_exchanger, auth_scheme):
|
||||
"""Test exchange_credential with NoExchangeCredentialExchanger."""
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
|
||||
|
||||
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
|
||||
assert result == auth_credential
|
||||
|
||||
|
||||
def test_exchange_credential_open_id_connect(auto_exchanger, auth_scheme):
|
||||
"""Test exchange_credential with OpenID Connect scheme."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT
|
||||
)
|
||||
mock_exchanger = MagicMock(spec=OAuth2CredentialExchanger)
|
||||
mock_exchanger.exchange_credential.return_value = "exchanged_credential"
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.OPEN_ID_CONNECT] = (
|
||||
lambda: mock_exchanger
|
||||
)
|
||||
|
||||
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
|
||||
assert result == "exchanged_credential"
|
||||
mock_exchanger.exchange_credential.assert_called_once_with(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
|
||||
def test_exchange_credential_service_account(auto_exchanger, auth_scheme):
|
||||
"""Test exchange_credential with Service Account scheme."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT
|
||||
)
|
||||
mock_exchanger = MagicMock(spec=ServiceAccountCredentialExchanger)
|
||||
mock_exchanger.exchange_credential.return_value = "exchanged_credential_sa"
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.SERVICE_ACCOUNT] = (
|
||||
lambda: mock_exchanger
|
||||
)
|
||||
|
||||
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
|
||||
assert result == "exchanged_credential_sa"
|
||||
mock_exchanger.exchange_credential.assert_called_once_with(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
|
||||
def test_exchange_credential_custom_exchanger(auto_exchanger, auth_scheme):
|
||||
"""Test that exchange_credential calls the correct (custom) exchanger."""
|
||||
# Use a custom exchanger via the initialization
|
||||
mock_exchanger = MagicMock(spec=MockCredentialExchanger)
|
||||
mock_exchanger.exchange_credential.return_value = "custom_credential"
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.API_KEY] = (
|
||||
lambda: mock_exchanger
|
||||
)
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
|
||||
|
||||
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
|
||||
assert result == "custom_credential"
|
||||
mock_exchanger.exchange_credential.assert_called_once_with(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
@@ -0,0 +1,68 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the BaseAuthCredentialExchanger class."""
|
||||
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
import pytest
|
||||
|
||||
|
||||
class MockAuthCredentialExchanger(BaseAuthCredentialExchanger):
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
return AuthCredential(token="some-token")
|
||||
|
||||
|
||||
class TestBaseAuthCredentialExchanger:
|
||||
"""Tests for the BaseAuthCredentialExchanger class."""
|
||||
|
||||
@pytest.fixture
|
||||
def base_exchanger(self):
|
||||
return BaseAuthCredentialExchanger()
|
||||
|
||||
@pytest.fixture
|
||||
def auth_scheme(self):
|
||||
scheme = MagicMock(spec=AuthScheme)
|
||||
scheme.type = "apiKey"
|
||||
scheme.name = "x-api-key"
|
||||
return scheme
|
||||
|
||||
def test_exchange_credential_not_implemented(
|
||||
self, base_exchanger, auth_scheme
|
||||
):
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, token="some-token"
|
||||
)
|
||||
with pytest.raises(NotImplementedError) as exc_info:
|
||||
base_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "Subclasses must implement exchange_credential." in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
def test_auth_credential_missing_error(self):
|
||||
error_message = "Test missing credential"
|
||||
error = AuthCredentialMissingError(error_message)
|
||||
# assert error.message == error_message
|
||||
assert str(error) == error_message
|
||||
@@ -0,0 +1,153 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for OAuth2CredentialExchanger."""
|
||||
|
||||
import copy
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers import OAuth2CredentialExchanger
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_exchanger():
|
||||
return OAuth2CredentialExchanger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_scheme():
|
||||
openid_config = OpenIdConnectWithConfig(
|
||||
type_=AuthSchemeType.openIdConnect,
|
||||
authorization_endpoint="https://example.com/auth",
|
||||
token_endpoint="https://example.com/token",
|
||||
scopes=["openid", "profile"],
|
||||
)
|
||||
return openid_config
|
||||
|
||||
|
||||
def test_check_scheme_credential_type_success(oauth2_exchanger, auth_scheme):
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
redirect_uri="http://localhost:8080",
|
||||
),
|
||||
)
|
||||
# Check that the method does not raise an exception
|
||||
oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
|
||||
|
||||
|
||||
def test_check_scheme_credential_type_missing_credential(
|
||||
oauth2_exchanger, auth_scheme
|
||||
):
|
||||
# Test case: auth_credential is None
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth2_exchanger._check_scheme_credential_type(auth_scheme, None)
|
||||
assert "auth_credential is empty" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_check_scheme_credential_type_invalid_scheme_type(
|
||||
oauth2_exchanger, auth_scheme: OpenIdConnectWithConfig
|
||||
):
|
||||
"""Test case: Invalid AuthSchemeType."""
|
||||
# Test case: Invalid AuthSchemeType
|
||||
invalid_scheme = copy.deepcopy(auth_scheme)
|
||||
invalid_scheme.type_ = AuthSchemeType.apiKey
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
redirect_uri="http://localhost:8080",
|
||||
),
|
||||
)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth2_exchanger._check_scheme_credential_type(
|
||||
invalid_scheme, auth_credential
|
||||
)
|
||||
assert "Invalid security scheme" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_check_scheme_credential_type_missing_openid_connect(
|
||||
oauth2_exchanger, auth_scheme
|
||||
):
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
|
||||
assert "auth_credential is not configured with oauth2" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_generate_auth_token_success(
|
||||
oauth2_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test case: Successful generation of access token."""
|
||||
# Test case: Successful generation of access token
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
redirect_uri="http://localhost:8080",
|
||||
auth_response_uri="https://example.com/callback?code=test_code",
|
||||
token={"access_token": "test_access_token"},
|
||||
),
|
||||
)
|
||||
updated_credential = oauth2_exchanger.generate_auth_token(auth_credential)
|
||||
|
||||
assert updated_credential.auth_type == AuthCredentialTypes.HTTP
|
||||
assert updated_credential.http.scheme == "bearer"
|
||||
assert updated_credential.http.credentials.token == "test_access_token"
|
||||
|
||||
|
||||
def test_exchange_credential_generate_auth_token(
|
||||
oauth2_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test exchange_credential when auth_response_uri is present."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
redirect_uri="http://localhost:8080",
|
||||
auth_response_uri="https://example.com/callback?code=test_code",
|
||||
token={"access_token": "test_access_token"},
|
||||
),
|
||||
)
|
||||
|
||||
updated_credential = oauth2_exchanger.exchange_credential(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
assert updated_credential.auth_type == AuthCredentialTypes.HTTP
|
||||
assert updated_credential.http.scheme == "bearer"
|
||||
assert updated_credential.http.credentials.token == "test_access_token"
|
||||
|
||||
|
||||
def test_exchange_credential_auth_missing(oauth2_exchanger, auth_scheme):
|
||||
"""Test exchange_credential when auth_credential is missing."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth2_exchanger.exchange_credential(auth_scheme, None)
|
||||
assert "auth_credential is empty. Please create AuthCredential using" in str(
|
||||
exc_info.value
|
||||
)
|
||||
@@ -0,0 +1,196 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the service account credential exchanger."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import ServiceAccount
|
||||
from google.adk.auth.auth_credential import ServiceAccountCredential
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
import google.auth
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service_account_exchanger():
|
||||
return ServiceAccountCredentialExchanger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_scheme():
|
||||
scheme = MagicMock(spec=AuthScheme)
|
||||
scheme.type_ = AuthSchemeType.oauth2
|
||||
scheme.description = "Google Service Account"
|
||||
return scheme
|
||||
|
||||
|
||||
def test_exchange_credential_success(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test successful exchange of service account credentials."""
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "mock_access_token"
|
||||
|
||||
# Mock the from_service_account_info method
|
||||
mock_from_service_account_info = MagicMock(return_value=mock_credentials)
|
||||
target_path = (
|
||||
"google.adk.tools.openapi_tool.auth.credential_exchangers."
|
||||
"service_account_exchanger.service_account.Credentials."
|
||||
"from_service_account_info"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
target_path,
|
||||
mock_from_service_account_info,
|
||||
)
|
||||
|
||||
# Mock the refresh method
|
||||
mock_credentials.refresh = MagicMock()
|
||||
|
||||
# Create a valid AuthCredential with service account info
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
service_account_credential=ServiceAccountCredential(
|
||||
type_="service_account",
|
||||
project_id="your_project_id",
|
||||
private_key_id="your_private_key_id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----...",
|
||||
client_email="...@....iam.gserviceaccount.com",
|
||||
client_id="your_client_id",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url=(
|
||||
"https://www.googleapis.com/oauth2/v1/certs"
|
||||
),
|
||||
client_x509_cert_url=(
|
||||
"https://www.googleapis.com/robot/v1/metadata/x509/..."
|
||||
),
|
||||
universe_domain="googleapis.com",
|
||||
),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
),
|
||||
)
|
||||
|
||||
result = service_account_exchanger.exchange_credential(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
assert result.auth_type == AuthCredentialTypes.HTTP
|
||||
assert result.http.scheme == "bearer"
|
||||
assert result.http.credentials.token == "mock_access_token"
|
||||
mock_from_service_account_info.assert_called_once()
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
|
||||
|
||||
def test_exchange_credential_use_default_credential_success(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test successful exchange of service account credentials using default credential."""
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "mock_access_token"
|
||||
mock_google_auth_default = MagicMock(
|
||||
return_value=(mock_credentials, "test_project")
|
||||
)
|
||||
monkeypatch.setattr(google.auth, "default", mock_google_auth_default)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
use_default_credential=True,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
),
|
||||
)
|
||||
|
||||
result = service_account_exchanger.exchange_credential(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
assert result.auth_type == AuthCredentialTypes.HTTP
|
||||
assert result.http.scheme == "bearer"
|
||||
assert result.http.credentials.token == "mock_access_token"
|
||||
mock_google_auth_default.assert_called_once()
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
|
||||
|
||||
def test_exchange_credential_missing_auth_credential(
|
||||
service_account_exchanger, auth_scheme
|
||||
):
|
||||
"""Test missing auth credential during exchange."""
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, None)
|
||||
assert "Service account credentials are missing" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_exchange_credential_missing_service_account_info(
|
||||
service_account_exchanger, auth_scheme
|
||||
):
|
||||
"""Test missing service account info during exchange."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
)
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "Service account credentials are missing" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_exchange_credential_exchange_failure(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test failure during service account token exchange."""
|
||||
mock_from_service_account_info = MagicMock(
|
||||
side_effect=Exception("Failed to load credentials")
|
||||
)
|
||||
target_path = (
|
||||
"google.adk.tools.openapi_tool.auth.credential_exchangers."
|
||||
"service_account_exchanger.service_account.Credentials."
|
||||
"from_service_account_info"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
target_path,
|
||||
mock_from_service_account_info,
|
||||
)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
service_account_credential=ServiceAccountCredential(
|
||||
type_="service_account",
|
||||
project_id="your_project_id",
|
||||
private_key_id="your_private_key_id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----...",
|
||||
client_email="...@....iam.gserviceaccount.com",
|
||||
client_id="your_client_id",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url=(
|
||||
"https://www.googleapis.com/oauth2/v1/certs"
|
||||
),
|
||||
client_x509_cert_url=(
|
||||
"https://www.googleapis.com/robot/v1/metadata/x509/..."
|
||||
),
|
||||
universe_domain="googleapis.com",
|
||||
),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
),
|
||||
)
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "Failed to exchange service account token" in str(exc_info.value)
|
||||
mock_from_service_account_info.assert_called_once()
|
||||
573
tests/unittests/tools/openapi_tool/auth/test_auth_helper.py
Normal file
573
tests/unittests/tools/openapi_tool/auth/test_auth_helper.py
Normal file
@@ -0,0 +1,573 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from fastapi.openapi.models import HTTPBase
|
||||
from fastapi.openapi.models import HTTPBearer
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OpenIdConnect
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.auth.auth_credential import ServiceAccount
|
||||
from google.adk.auth.auth_credential import ServiceAccountCredential
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import credential_to_param
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import dict_to_auth_scheme
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import INTERNAL_AUTH_PREFIX
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_url_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_dict_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_api_key_header():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"apikey", "header", "X-API-Key", "test_key"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert scheme.type_ == AuthSchemeType.apiKey
|
||||
assert scheme.in_ == APIKeyIn.header
|
||||
assert scheme.name == "X-API-Key"
|
||||
assert credential == AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_api_key_query():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"apikey", "query", "api_key", "test_key"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert scheme.type_ == AuthSchemeType.apiKey
|
||||
assert scheme.in_ == APIKeyIn.query
|
||||
assert scheme.name == "api_key"
|
||||
assert credential == AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_api_key_cookie():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"apikey", "cookie", "session_id", "test_key"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert scheme.type_ == AuthSchemeType.apiKey
|
||||
assert scheme.in_ == APIKeyIn.cookie
|
||||
assert scheme.name == "session_id"
|
||||
assert credential == AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_api_key_no_credential():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"apikey", "cookie", "session_id"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert credential is None
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_oauth2_token():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"oauth2Token", "header", "Authorization", "test_token"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert scheme.bearerFormat == "JWT"
|
||||
assert credential == AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="test_token")
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_oauth2_no_credential():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"oauth2Token", "header", "Authorization"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert credential is None
|
||||
|
||||
|
||||
def test_service_account_dict_to_scheme_credential():
|
||||
config = {
|
||||
"type": "service_account",
|
||||
"project_id": "project_id",
|
||||
"private_key_id": "private_key_id",
|
||||
"private_key": "private_key",
|
||||
"client_email": "client_email",
|
||||
"client_id": "client_id",
|
||||
"auth_uri": "auth_uri",
|
||||
"token_uri": "token_uri",
|
||||
"auth_provider_x509_cert_url": "auth_provider_x509_cert_url",
|
||||
"client_x509_cert_url": "client_x509_cert_url",
|
||||
"universe_domain": "universe_domain",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = service_account_dict_to_scheme_credential(config, scopes)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert scheme.bearerFormat == "JWT"
|
||||
assert credential.auth_type == AuthCredentialTypes.SERVICE_ACCOUNT
|
||||
assert credential.service_account.scopes == scopes
|
||||
assert (
|
||||
credential.service_account.service_account_credential.project_id
|
||||
== "project_id"
|
||||
)
|
||||
|
||||
|
||||
def test_service_account_scheme_credential():
|
||||
config = ServiceAccount(
|
||||
service_account_credential=ServiceAccountCredential(
|
||||
type="service_account",
|
||||
project_id="project_id",
|
||||
private_key_id="private_key_id",
|
||||
private_key="private_key",
|
||||
client_email="client_email",
|
||||
client_id="client_id",
|
||||
auth_uri="auth_uri",
|
||||
token_uri="token_uri",
|
||||
auth_provider_x509_cert_url="auth_provider_x509_cert_url",
|
||||
client_x509_cert_url="client_x509_cert_url",
|
||||
universe_domain="universe_domain",
|
||||
),
|
||||
scopes=["scope1", "scope2"],
|
||||
)
|
||||
|
||||
scheme, credential = service_account_scheme_credential(config)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert scheme.bearerFormat == "JWT"
|
||||
assert credential.auth_type == AuthCredentialTypes.SERVICE_ACCOUNT
|
||||
assert credential.service_account == config
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential():
|
||||
config_dict = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
"openIdConnectUrl": "openid_url",
|
||||
}
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_dict_to_scheme_credential(
|
||||
config_dict, scopes, credential_dict
|
||||
)
|
||||
|
||||
assert isinstance(scheme, OpenIdConnectWithConfig)
|
||||
assert scheme.authorization_endpoint == "auth_url"
|
||||
assert scheme.token_endpoint == "token_url"
|
||||
assert scheme.scopes == scopes
|
||||
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
|
||||
assert credential.oauth2.client_id == "client_id"
|
||||
assert credential.oauth2.client_secret == "client_secret"
|
||||
assert credential.oauth2.redirect_uri == "redirect_uri"
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential_no_openid_url():
|
||||
config_dict = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
}
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_dict_to_scheme_credential(
|
||||
config_dict, scopes, credential_dict
|
||||
)
|
||||
|
||||
assert scheme.openIdConnectUrl == ""
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential_google_oauth_credential():
|
||||
config_dict = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
"openIdConnectUrl": "openid_url",
|
||||
}
|
||||
credential_dict = {
|
||||
"web": {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_dict_to_scheme_credential(
|
||||
config_dict, scopes, credential_dict
|
||||
)
|
||||
|
||||
assert isinstance(scheme, OpenIdConnectWithConfig)
|
||||
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
|
||||
assert credential.oauth2.client_id == "client_id"
|
||||
assert credential.oauth2.client_secret == "client_secret"
|
||||
assert credential.oauth2.redirect_uri == "redirect_uri"
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential_invalid_config():
|
||||
config_dict = {
|
||||
"invalid_field": "value",
|
||||
}
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid OpenID Connect configuration"):
|
||||
openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential_missing_credential_fields():
|
||||
config_dict = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
}
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Missing required fields in credential_dict: client_secret",
|
||||
):
|
||||
openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_openid_url_to_scheme_credential(mock_get):
|
||||
mock_response = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
"userinfo_endpoint": "userinfo_url",
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_url_to_scheme_credential(
|
||||
"openid_url", scopes, credential_dict
|
||||
)
|
||||
|
||||
assert isinstance(scheme, OpenIdConnectWithConfig)
|
||||
assert scheme.authorization_endpoint == "auth_url"
|
||||
assert scheme.token_endpoint == "token_url"
|
||||
assert scheme.scopes == scopes
|
||||
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
|
||||
assert credential.oauth2.client_id == "client_id"
|
||||
assert credential.oauth2.client_secret == "client_secret"
|
||||
assert credential.oauth2.redirect_uri == "redirect_uri"
|
||||
mock_get.assert_called_once_with("openid_url", timeout=10)
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_openid_url_to_scheme_credential_no_openid_url(mock_get):
|
||||
mock_response = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
"userinfo_endpoint": "userinfo_url",
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_url_to_scheme_credential(
|
||||
"openid_url", scopes, credential_dict
|
||||
)
|
||||
|
||||
assert scheme.openIdConnectUrl == "openid_url"
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_openid_url_to_scheme_credential_request_exception(mock_get):
|
||||
mock_get.side_effect = requests.exceptions.RequestException("Test Error")
|
||||
credential_dict = {"client_id": "client_id", "client_secret": "client_secret"}
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Failed to fetch OpenID configuration from openid_url"
|
||||
):
|
||||
openid_url_to_scheme_credential("openid_url", [], credential_dict)
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_openid_url_to_scheme_credential_invalid_json(mock_get):
|
||||
mock_get.return_value.json.side_effect = ValueError("Invalid JSON")
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
credential_dict = {"client_id": "client_id", "client_secret": "client_secret"}
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Invalid JSON response from OpenID configuration endpoint openid_url"
|
||||
),
|
||||
):
|
||||
openid_url_to_scheme_credential("openid_url", [], credential_dict)
|
||||
|
||||
|
||||
def test_credential_to_param_api_key_header():
|
||||
auth_scheme = APIKey(
|
||||
**{"type": "apiKey", "in": "header", "name": "X-API-Key"}
|
||||
)
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "X-API-Key"
|
||||
assert param.param_location == "header"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "X-API-Key": "test_key"}
|
||||
|
||||
|
||||
def test_credential_to_param_api_key_query():
|
||||
auth_scheme = APIKey(**{"type": "apiKey", "in": "query", "name": "api_key"})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "api_key"
|
||||
assert param.param_location == "query"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "api_key": "test_key"}
|
||||
|
||||
|
||||
def test_credential_to_param_api_key_cookie():
|
||||
auth_scheme = APIKey(
|
||||
**{"type": "apiKey", "in": "cookie", "name": "session_id"}
|
||||
)
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "session_id"
|
||||
assert param.param_location == "cookie"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "session_id": "test_key"}
|
||||
|
||||
|
||||
def test_credential_to_param_http_bearer():
|
||||
auth_scheme = HTTPBearer(bearerFormat="JWT")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="test_token")
|
||||
),
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "Authorization"
|
||||
assert param.param_location == "header"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_credential_to_param_http_basic_not_supported():
|
||||
auth_scheme = HTTPBase(scheme="basic")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="basic",
|
||||
credentials=HttpCredentials(username="user", password="password"),
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
NotImplementedError, match="Basic Authentication is not supported."
|
||||
):
|
||||
credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
|
||||
def test_credential_to_param_http_invalid_credentials_no_http():
|
||||
auth_scheme = HTTPBase(scheme="basic")
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.HTTP)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid HTTP auth credentials"):
|
||||
credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
|
||||
def test_credential_to_param_oauth2():
|
||||
auth_scheme = OAuth2(flows={})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="test_token")
|
||||
),
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "Authorization"
|
||||
assert param.param_location == "header"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_credential_to_param_openid_connect():
|
||||
auth_scheme = OpenIdConnect(openIdConnectUrl="openid_url")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="test_token")
|
||||
),
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "Authorization"
|
||||
assert param.param_location == "header"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_credential_to_param_openid_no_credential():
|
||||
auth_scheme = OpenIdConnect(openIdConnectUrl="openid_url")
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, None)
|
||||
|
||||
assert param == None
|
||||
assert kwargs == None
|
||||
|
||||
|
||||
def test_credential_to_param_oauth2_no_credential():
|
||||
auth_scheme = OAuth2(flows={})
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, None)
|
||||
|
||||
assert param == None
|
||||
assert kwargs == None
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_api_key():
|
||||
data = {"type": "apiKey", "in": "header", "name": "X-API-Key"}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert scheme.type_ == AuthSchemeType.apiKey
|
||||
assert scheme.in_ == APIKeyIn.header
|
||||
assert scheme.name == "X-API-Key"
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_http_bearer():
|
||||
data = {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert scheme.scheme == "bearer"
|
||||
assert scheme.bearerFormat == "JWT"
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_http_base():
|
||||
data = {"type": "http", "scheme": "basic"}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, HTTPBase)
|
||||
assert scheme.scheme == "basic"
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_oauth2():
|
||||
data = {
|
||||
"type": "oauth2",
|
||||
"flows": {
|
||||
"authorizationCode": {
|
||||
"authorizationUrl": "https://example.com/auth",
|
||||
"tokenUrl": "https://example.com/token",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, OAuth2)
|
||||
assert hasattr(scheme.flows, "authorizationCode")
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_openid_connect():
|
||||
data = {
|
||||
"type": "openIdConnect",
|
||||
"openIdConnectUrl": (
|
||||
"https://example.com/.well-known/openid-configuration"
|
||||
),
|
||||
}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, OpenIdConnect)
|
||||
assert (
|
||||
scheme.openIdConnectUrl
|
||||
== "https://example.com/.well-known/openid-configuration"
|
||||
)
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_missing_type():
|
||||
data = {"in": "header", "name": "X-API-Key"}
|
||||
with pytest.raises(
|
||||
ValueError, match="Missing 'type' field in security scheme dictionary."
|
||||
):
|
||||
dict_to_auth_scheme(data)
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_invalid_type():
|
||||
data = {"type": "invalid", "in": "header", "name": "X-API-Key"}
|
||||
with pytest.raises(ValueError, match="Invalid security scheme type: invalid"):
|
||||
dict_to_auth_scheme(data)
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_invalid_data():
|
||||
data = {"type": "apiKey", "in": "header"} # Missing 'name'
|
||||
with pytest.raises(ValueError, match="Invalid security scheme data"):
|
||||
dict_to_auth_scheme(data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
436
tests/unittests/tools/openapi_tool/common/test_common.py
Normal file
436
tests/unittests/tools/openapi_tool/common/test_common.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
from fastapi.openapi.models import Response, Schema
|
||||
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
||||
from google.adk.tools.openapi_tool.common.common import PydocHelper
|
||||
from google.adk.tools.openapi_tool.common.common import rename_python_keywords
|
||||
from google.adk.tools.openapi_tool.common.common import to_snake_case
|
||||
from google.adk.tools.openapi_tool.common.common import TypeHintHelper
|
||||
import pytest
|
||||
|
||||
|
||||
def dict_to_responses(input: Dict[str, Any]) -> Dict[str, Response]:
|
||||
return {k: Response.model_validate(input[k]) for k in input}
|
||||
|
||||
|
||||
class TestToSnakeCase:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'input_str, expected_output',
|
||||
[
|
||||
('lowerCamelCase', 'lower_camel_case'),
|
||||
('UpperCamelCase', 'upper_camel_case'),
|
||||
('space separated', 'space_separated'),
|
||||
('REST API', 'rest_api'),
|
||||
('Mixed_CASE with_Spaces', 'mixed_case_with_spaces'),
|
||||
('__init__', 'init'),
|
||||
('APIKey', 'api_key'),
|
||||
('SomeLongURL', 'some_long_url'),
|
||||
('CONSTANT_CASE', 'constant_case'),
|
||||
('already_snake_case', 'already_snake_case'),
|
||||
('single', 'single'),
|
||||
('', ''),
|
||||
(' spaced ', 'spaced'),
|
||||
('with123numbers', 'with123numbers'),
|
||||
('With_Mixed_123_and_SPACES', 'with_mixed_123_and_spaces'),
|
||||
('HTMLParser', 'html_parser'),
|
||||
('HTTPResponseCode', 'http_response_code'),
|
||||
('a_b_c', 'a_b_c'),
|
||||
('A_B_C', 'a_b_c'),
|
||||
('fromAtoB', 'from_ato_b'),
|
||||
('XMLHTTPRequest', 'xmlhttp_request'),
|
||||
('_leading', 'leading'),
|
||||
('trailing_', 'trailing'),
|
||||
(' leading_and_trailing_ ', 'leading_and_trailing'),
|
||||
('Multiple___Underscores', 'multiple_underscores'),
|
||||
(' spaces_and___underscores ', 'spaces_and_underscores'),
|
||||
(' _mixed_Case ', 'mixed_case'),
|
||||
('123Start', '123_start'),
|
||||
('End123', 'end123'),
|
||||
('Mid123dle', 'mid123dle'),
|
||||
],
|
||||
)
|
||||
def test_to_snake_case(self, input_str, expected_output):
|
||||
assert to_snake_case(input_str) == expected_output
|
||||
|
||||
|
||||
class TestRenamePythonKeywords:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'input_str, expected_output',
|
||||
[
|
||||
('in', 'param_in'),
|
||||
('for', 'param_for'),
|
||||
('class', 'param_class'),
|
||||
('normal', 'normal'),
|
||||
('param_if', 'param_if'),
|
||||
('', ''),
|
||||
],
|
||||
)
|
||||
def test_rename_python_keywords(self, input_str, expected_output):
|
||||
assert rename_python_keywords(input_str) == expected_output
|
||||
|
||||
|
||||
class TestApiParameter:
|
||||
|
||||
def test_api_parameter_initialization(self):
|
||||
schema = Schema(type='string', description='A string parameter')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
description='A string description',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.original_name == 'testParam'
|
||||
assert param.param_location == 'query'
|
||||
assert param.param_schema.type == 'string'
|
||||
assert param.param_schema.description == 'A string parameter'
|
||||
assert param.py_name == 'test_param'
|
||||
assert param.type_hint == 'str'
|
||||
assert param.type_value == str
|
||||
assert param.description == 'A string description'
|
||||
|
||||
def test_api_parameter_keyword_rename(self):
|
||||
schema = Schema(type='string')
|
||||
param = ApiParameter(
|
||||
original_name='in',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.py_name == 'param_in'
|
||||
|
||||
def test_api_parameter_custom_py_name(self):
|
||||
schema = Schema(type='integer')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
py_name='custom_name',
|
||||
)
|
||||
assert param.py_name == 'custom_name'
|
||||
|
||||
def test_api_parameter_str_representation(self):
|
||||
schema = Schema(type='number')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert str(param) == 'test_param: float'
|
||||
|
||||
def test_api_parameter_to_arg_string(self):
|
||||
schema = Schema(type='boolean')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.to_arg_string() == 'test_param=test_param'
|
||||
|
||||
def test_api_parameter_to_dict_property(self):
|
||||
schema = Schema(type='string')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
param_location='path',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.to_dict_property() == '"test_param": test_param'
|
||||
|
||||
def test_api_parameter_model_serializer(self):
|
||||
schema = Schema(type='string', description='test description')
|
||||
param = ApiParameter(
|
||||
original_name='TestParam',
|
||||
param_location='path',
|
||||
param_schema=schema,
|
||||
py_name='test_param_custom',
|
||||
description='test description',
|
||||
)
|
||||
|
||||
serialized_param = param.model_dump(mode='json', exclude_none=True)
|
||||
|
||||
assert serialized_param == {
|
||||
'original_name': 'TestParam',
|
||||
'param_location': 'path',
|
||||
'param_schema': {'type': 'string', 'description': 'test description'},
|
||||
'description': 'test description',
|
||||
'py_name': 'test_param_custom',
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'schema, expected_type_value, expected_type_hint',
|
||||
[
|
||||
({'type': 'integer'}, int, 'int'),
|
||||
({'type': 'number'}, float, 'float'),
|
||||
({'type': 'boolean'}, bool, 'bool'),
|
||||
({'type': 'string'}, str, 'str'),
|
||||
(
|
||||
{'type': 'string', 'format': 'date'},
|
||||
str,
|
||||
'str',
|
||||
),
|
||||
(
|
||||
{'type': 'string', 'format': 'date-time'},
|
||||
str,
|
||||
'str',
|
||||
),
|
||||
(
|
||||
{'type': 'array', 'items': {'type': 'integer'}},
|
||||
List[int],
|
||||
'List[int]',
|
||||
),
|
||||
(
|
||||
{'type': 'array', 'items': {'type': 'string'}},
|
||||
List[str],
|
||||
'List[str]',
|
||||
),
|
||||
(
|
||||
{
|
||||
'type': 'array',
|
||||
'items': {'type': 'object'},
|
||||
},
|
||||
List[Dict[str, Any]],
|
||||
'List[Dict[str, Any]]',
|
||||
),
|
||||
({'type': 'object'}, Dict[str, Any], 'Dict[str, Any]'),
|
||||
({'type': 'unknown'}, Any, 'Any'),
|
||||
({}, Any, 'Any'),
|
||||
],
|
||||
)
|
||||
def test_api_parameter_type_hint_helper(
|
||||
self, schema, expected_type_value, expected_type_hint
|
||||
):
|
||||
param = ApiParameter(
|
||||
original_name='test', param_location='query', param_schema=schema
|
||||
)
|
||||
assert param.type_value == expected_type_value
|
||||
assert param.type_hint == expected_type_hint
|
||||
assert (
|
||||
TypeHintHelper.get_type_hint(param.param_schema) == expected_type_hint
|
||||
)
|
||||
assert (
|
||||
TypeHintHelper.get_type_value(param.param_schema) == expected_type_value
|
||||
)
|
||||
|
||||
def test_api_parameter_description(self):
|
||||
schema = Schema(type='string')
|
||||
param = ApiParameter(
|
||||
original_name='param1',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='The description',
|
||||
)
|
||||
assert param.description == 'The description'
|
||||
|
||||
def test_api_parameter_description_use_schema_fallback(self):
|
||||
schema = Schema(type='string', description='The description')
|
||||
param = ApiParameter(
|
||||
original_name='param1',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.description == 'The description'
|
||||
|
||||
|
||||
class TestTypeHintHelper:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'schema, expected_type_value, expected_type_hint',
|
||||
[
|
||||
({'type': 'integer'}, int, 'int'),
|
||||
({'type': 'number'}, float, 'float'),
|
||||
({'type': 'string'}, str, 'str'),
|
||||
(
|
||||
{
|
||||
'type': 'array',
|
||||
'items': {'type': 'string'},
|
||||
},
|
||||
List[str],
|
||||
'List[str]',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_type_value_and_hint(
|
||||
self, schema, expected_type_value, expected_type_hint
|
||||
):
|
||||
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='Test parameter',
|
||||
)
|
||||
assert (
|
||||
TypeHintHelper.get_type_value(param.param_schema) == expected_type_value
|
||||
)
|
||||
assert (
|
||||
TypeHintHelper.get_type_hint(param.param_schema) == expected_type_hint
|
||||
)
|
||||
|
||||
|
||||
class TestPydocHelper:
|
||||
|
||||
def test_generate_param_doc_simple(self):
|
||||
schema = Schema(type='string')
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='Test description',
|
||||
)
|
||||
|
||||
expected_doc = 'test_param (str): Test description'
|
||||
assert PydocHelper.generate_param_doc(param) == expected_doc
|
||||
|
||||
def test_generate_param_doc_no_description(self):
|
||||
schema = Schema(type='integer')
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
expected_doc = 'test_param (int): '
|
||||
assert PydocHelper.generate_param_doc(param) == expected_doc
|
||||
|
||||
def test_generate_param_doc_object(self):
|
||||
schema = Schema(
|
||||
type='object',
|
||||
properties={
|
||||
'prop1': {'type': 'string', 'description': 'Prop1 desc'},
|
||||
'prop2': {'type': 'integer'},
|
||||
},
|
||||
)
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='Test object parameter',
|
||||
)
|
||||
expected_doc = (
|
||||
'test_param (Dict[str, Any]): Test object parameter Object'
|
||||
' properties:\n prop1 (str): Prop1 desc\n prop2'
|
||||
' (int): \n'
|
||||
)
|
||||
assert PydocHelper.generate_param_doc(param) == expected_doc
|
||||
|
||||
def test_generate_param_doc_object_no_properties(self):
|
||||
schema = Schema(type='object', description='A test schema')
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='The description.',
|
||||
)
|
||||
expected_doc = 'test_param (Dict[str, Any]): The description.'
|
||||
assert PydocHelper.generate_param_doc(param) == expected_doc
|
||||
|
||||
def test_generate_return_doc_simple(self):
|
||||
responses = {
|
||||
'200': {
|
||||
'description': 'Successful response',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
}
|
||||
}
|
||||
expected_doc = 'Returns (str): Successful response'
|
||||
assert (
|
||||
PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
== expected_doc
|
||||
)
|
||||
|
||||
def test_generate_return_doc_no_content(self):
|
||||
responses = {'204': {'description': 'No content'}}
|
||||
assert not PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
|
||||
def test_generate_return_doc_object(self):
|
||||
responses = {
|
||||
'200': {
|
||||
'description': 'Successful object response',
|
||||
'content': {
|
||||
'application/json': {
|
||||
'schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'prop1': {
|
||||
'type': 'string',
|
||||
'description': 'Prop1 desc',
|
||||
},
|
||||
'prop2': {'type': 'integer'},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return_doc = PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
|
||||
assert 'Returns (Dict[str, Any]): Successful object response' in return_doc
|
||||
assert 'prop1 (str): Prop1 desc' in return_doc
|
||||
assert 'prop2 (int):' in return_doc
|
||||
|
||||
def test_generate_return_doc_multiple_success(self):
|
||||
responses = {
|
||||
'200': {
|
||||
'description': 'Successful response',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
},
|
||||
'400': {'description': 'Bad request'},
|
||||
}
|
||||
expected_doc = 'Returns (str): Successful response'
|
||||
assert (
|
||||
PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
== expected_doc
|
||||
)
|
||||
|
||||
def test_generate_return_doc_2xx_smallest_status_code_response(self):
|
||||
responses = {
|
||||
'201': {
|
||||
'description': '201 response',
|
||||
'content': {'application/json': {'schema': {'type': 'integer'}}},
|
||||
},
|
||||
'200': {
|
||||
'description': '200 response',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
},
|
||||
'400': {'description': 'Bad request'},
|
||||
}
|
||||
|
||||
expected_doc = 'Returns (str): 200 response'
|
||||
assert (
|
||||
PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
== expected_doc
|
||||
)
|
||||
|
||||
def test_generate_return_doc_contentful_response(self):
|
||||
responses = {
|
||||
'200': {'description': 'No content response'},
|
||||
'201': {
|
||||
'description': '201 response',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
},
|
||||
'400': {'description': 'Bad request'},
|
||||
}
|
||||
expected_doc = 'Returns (str): 201 response'
|
||||
assert (
|
||||
PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
== expected_doc
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
1367
tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml
Normal file
1367
tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,628 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
|
||||
import pytest
|
||||
|
||||
|
||||
def create_minimal_openapi_spec() -> Dict[str, Any]:
|
||||
"""Creates a minimal valid OpenAPI spec."""
|
||||
return {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Minimal API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/test": {
|
||||
"get": {
|
||||
"summary": "Test GET endpoint",
|
||||
"operationId": "testGet",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"application/json": {"schema": {"type": "string"}}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_spec_generator():
|
||||
"""Fixture for creating an OperationGenerator instance."""
|
||||
return OpenApiSpecParser()
|
||||
|
||||
|
||||
def test_parse_minimal_spec(openapi_spec_generator):
|
||||
"""Test parsing a minimal OpenAPI specification."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert op.name == "test_get"
|
||||
assert op.endpoint.path == "/test"
|
||||
assert op.endpoint.method == "get"
|
||||
assert op.return_value.type_value == str
|
||||
|
||||
|
||||
def test_parse_spec_with_no_operation_id(openapi_spec_generator):
|
||||
"""Test parsing a spec where operationId is missing (auto-generation)."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
del openapi_spec["paths"]["/test"]["get"]["operationId"] # Remove operationId
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
# Check if operationId is auto generated based on path and method.
|
||||
assert parsed_operations[0].name == "test_get"
|
||||
|
||||
|
||||
def test_parse_spec_with_multiple_methods(openapi_spec_generator):
|
||||
"""Test parsing a spec with multiple HTTP methods for the same path."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["post"] = {
|
||||
"summary": "Test POST endpoint",
|
||||
"operationId": "testPost",
|
||||
"responses": {"200": {"description": "Successful response"}},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
operation_names = {op.name for op in parsed_operations}
|
||||
|
||||
assert len(parsed_operations) == 2
|
||||
assert "test_get" in operation_names
|
||||
assert "test_post" in operation_names
|
||||
|
||||
|
||||
def test_parse_spec_with_parameters(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["get"]["parameters"] = [
|
||||
{"name": "param1", "in": "query", "schema": {"type": "string"}},
|
||||
{"name": "param2", "in": "header", "schema": {"type": "integer"}},
|
||||
]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations[0].parameters) == 2
|
||||
assert parsed_operations[0].parameters[0].original_name == "param1"
|
||||
assert parsed_operations[0].parameters[0].param_location == "query"
|
||||
assert parsed_operations[0].parameters[1].original_name == "param2"
|
||||
assert parsed_operations[0].parameters[1].param_location == "header"
|
||||
|
||||
|
||||
def test_parse_spec_with_request_body(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["post"] = {
|
||||
"summary": "Endpoint with request body",
|
||||
"operationId": "testPostWithBody",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
post_operations = [
|
||||
op for op in parsed_operations if op.endpoint.method == "post"
|
||||
]
|
||||
op = post_operations[0]
|
||||
|
||||
assert len(post_operations) == 1
|
||||
assert op.name == "test_post_with_body"
|
||||
assert len(op.parameters) == 1
|
||||
assert op.parameters[0].original_name == "name"
|
||||
assert op.parameters[0].type_value == str
|
||||
|
||||
|
||||
def test_parse_spec_with_reference(openapi_spec_generator):
|
||||
"""Test parsing a specification with $ref."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "API with Refs", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/test_ref": {
|
||||
"get": {
|
||||
"summary": "Endpoint with ref",
|
||||
"operationId": "testGetRef",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/MySchema"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"MySchema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert op.return_value.type_value.__origin__ is dict
|
||||
|
||||
|
||||
def test_parse_spec_with_circular_reference(openapi_spec_generator):
|
||||
"""Test correct handling of circular $ref (important!)."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Circular Ref API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/circular": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {"$ref": "#/components/schemas/A"}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"A": {
|
||||
"type": "object",
|
||||
"properties": {"b": {"$ref": "#/components/schemas/B"}},
|
||||
},
|
||||
"B": {
|
||||
"type": "object",
|
||||
"properties": {"a": {"$ref": "#/components/schemas/A"}},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 1
|
||||
|
||||
op = parsed_operations[0]
|
||||
assert op.return_value.type_value.__origin__ is dict
|
||||
assert op.return_value.type_hint == "Dict[str, Any]"
|
||||
|
||||
|
||||
def test_parse_no_paths(openapi_spec_generator):
|
||||
"""Test with a spec that has no paths defined."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "No Paths API", "version": "1.0.0"},
|
||||
}
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 0 # Should be empty
|
||||
|
||||
|
||||
def test_parse_empty_path_item(openapi_spec_generator):
|
||||
"""Test a path item that is present but empty."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Empty Path Item API", "version": "1.0.0"},
|
||||
"paths": {"/empty": None},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 0
|
||||
|
||||
|
||||
def test_parse_spec_with_global_auth_scheme(openapi_spec_generator):
|
||||
"""Test parsing with a global security scheme."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["security"] = [{"api_key": []}]
|
||||
openapi_spec["components"] = {
|
||||
"securitySchemes": {
|
||||
"api_key": {"type": "apiKey", "in": "header", "name": "X-API-Key"}
|
||||
}
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert op.auth_scheme is not None
|
||||
assert op.auth_scheme.type_.value == "apiKey"
|
||||
|
||||
|
||||
def test_parse_spec_with_local_auth_scheme(openapi_spec_generator):
|
||||
"""Test parsing with a local (operation-level) security scheme."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["get"]["security"] = [{"local_auth": []}]
|
||||
openapi_spec["components"] = {
|
||||
"securitySchemes": {"local_auth": {"type": "http", "scheme": "bearer"}}
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert op.auth_scheme is not None
|
||||
assert op.auth_scheme.type_.value == "http"
|
||||
assert op.auth_scheme.scheme == "bearer"
|
||||
|
||||
|
||||
def test_parse_spec_with_servers(openapi_spec_generator):
|
||||
"""Test parsing with server URLs."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["servers"] = [
|
||||
{"url": "https://api.example.com"},
|
||||
{"url": "http://localhost:8000"},
|
||||
]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].endpoint.base_url == "https://api.example.com"
|
||||
|
||||
|
||||
def test_parse_spec_with_no_servers(openapi_spec_generator):
|
||||
"""Test with no servers defined (should default to empty string)."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
if "servers" in openapi_spec:
|
||||
del openapi_spec["servers"]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].endpoint.base_url == ""
|
||||
|
||||
|
||||
def test_parse_spec_with_description(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
expected_description = "This is a test description."
|
||||
openapi_spec["paths"]["/test"]["get"]["description"] = expected_description
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].description == expected_description
|
||||
|
||||
|
||||
def test_parse_spec_with_empty_description(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["get"]["description"] = ""
|
||||
openapi_spec["paths"]["/test"]["get"]["summary"] = ""
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].description == ""
|
||||
|
||||
|
||||
def test_parse_spec_with_no_description(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
|
||||
# delete description
|
||||
if "description" in openapi_spec["paths"]["/test"]["get"]:
|
||||
del openapi_spec["paths"]["/test"]["get"]["description"]
|
||||
if "summary" in openapi_spec["paths"]["/test"]["get"]:
|
||||
del openapi_spec["paths"]["/test"]["get"]["summary"]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert (
|
||||
parsed_operations[0].description == ""
|
||||
) # it should be initialized with empty string
|
||||
|
||||
|
||||
def test_parse_invalid_openapi_spec_type(openapi_spec_generator):
|
||||
"""Test that passing a non-dict object to parse raises TypeError"""
|
||||
with pytest.raises(AttributeError):
|
||||
openapi_spec_generator.parse(123) # type: ignore
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
openapi_spec_generator.parse("openapi_spec") # type: ignore
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
openapi_spec_generator.parse([]) # type: ignore
|
||||
|
||||
|
||||
def test_parse_external_ref_raises_error(openapi_spec_generator):
|
||||
"""Check that external references (not starting with #) raise ValueError."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "External Ref API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/external": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
"external_file.json#/components/schemas/ExternalSchema"
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
|
||||
def test_parse_spec_with_multiple_paths_deep_refs(openapi_spec_generator):
|
||||
"""Test specs with multiple paths, request/response bodies using deep refs."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Multiple Paths Deep Refs API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/path1": {
|
||||
"post": {
|
||||
"operationId": "postPath1",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Request1"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Response1"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"/path2": {
|
||||
"put": {
|
||||
"operationId": "putPath2",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Request2"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Response2"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
"get": {
|
||||
"operationId": "getPath2",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Response2"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"Request1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"req1_prop1": {"$ref": "#/components/schemas/Level1_1"}
|
||||
},
|
||||
},
|
||||
"Response1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"res1_prop1": {"$ref": "#/components/schemas/Level1_2"}
|
||||
},
|
||||
},
|
||||
"Request2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"req2_prop1": {"$ref": "#/components/schemas/Level1_1"}
|
||||
},
|
||||
},
|
||||
"Response2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"res2_prop1": {"$ref": "#/components/schemas/Level1_2"}
|
||||
},
|
||||
},
|
||||
"Level1_1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level1_1_prop1": {
|
||||
"$ref": "#/components/schemas/Level2_1"
|
||||
}
|
||||
},
|
||||
},
|
||||
"Level1_2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level1_2_prop1": {
|
||||
"$ref": "#/components/schemas/Level2_2"
|
||||
}
|
||||
},
|
||||
},
|
||||
"Level2_1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level2_1_prop1": {"$ref": "#/components/schemas/Level3"}
|
||||
},
|
||||
},
|
||||
"Level2_2": {
|
||||
"type": "object",
|
||||
"properties": {"level2_2_prop1": {"type": "string"}},
|
||||
},
|
||||
"Level3": {"type": "integer"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 3
|
||||
|
||||
# Verify Path 1
|
||||
path1_ops = [op for op in parsed_operations if op.endpoint.path == "/path1"]
|
||||
assert len(path1_ops) == 1
|
||||
path1_op = path1_ops[0]
|
||||
assert path1_op.name == "post_path1"
|
||||
|
||||
assert len(path1_op.parameters) == 1
|
||||
assert path1_op.parameters[0].original_name == "req1_prop1"
|
||||
assert (
|
||||
path1_op.parameters[0]
|
||||
.param_schema.properties["level1_1_prop1"]
|
||||
.properties["level2_1_prop1"]
|
||||
.type
|
||||
== "integer"
|
||||
)
|
||||
assert (
|
||||
path1_op.return_value.param_schema.properties["res1_prop1"]
|
||||
.properties["level1_2_prop1"]
|
||||
.properties["level2_2_prop1"]
|
||||
.type
|
||||
== "string"
|
||||
)
|
||||
|
||||
# Verify Path 2
|
||||
path2_ops = [
|
||||
op
|
||||
for op in parsed_operations
|
||||
if op.endpoint.path == "/path2" and op.name == "put_path2"
|
||||
]
|
||||
path2_op = path2_ops[0]
|
||||
assert path2_op is not None
|
||||
assert len(path2_op.parameters) == 1
|
||||
assert path2_op.parameters[0].original_name == "req2_prop1"
|
||||
assert (
|
||||
path2_op.parameters[0]
|
||||
.param_schema.properties["level1_1_prop1"]
|
||||
.properties["level2_1_prop1"]
|
||||
.type
|
||||
== "integer"
|
||||
)
|
||||
assert (
|
||||
path2_op.return_value.param_schema.properties["res2_prop1"]
|
||||
.properties["level1_2_prop1"]
|
||||
.properties["level2_2_prop1"]
|
||||
.type
|
||||
== "string"
|
||||
)
|
||||
|
||||
|
||||
def test_parse_spec_with_duplicate_parameter_names(openapi_spec_generator):
|
||||
"""Test handling of duplicate parameter names (one in query, one in body).
|
||||
|
||||
The expected behavior is that both parameters should be captured but with
|
||||
different suffix, and
|
||||
their `original_name` attributes should reflect their origin (query or body).
|
||||
"""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Duplicate Parameter Names API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/duplicate": {
|
||||
"post": {
|
||||
"operationId": "createWithDuplicate",
|
||||
"parameters": [{
|
||||
"name": "name",
|
||||
"in": "query",
|
||||
"schema": {"type": "string"},
|
||||
}],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "integer"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 1
|
||||
op = parsed_operations[0]
|
||||
assert op.name == "create_with_duplicate"
|
||||
assert len(op.parameters) == 2
|
||||
|
||||
query_param = None
|
||||
body_param = None
|
||||
for param in op.parameters:
|
||||
if param.param_location == "query" and param.original_name == "name":
|
||||
query_param = param
|
||||
elif param.param_location == "body" and param.original_name == "name":
|
||||
body_param = param
|
||||
|
||||
assert query_param is not None
|
||||
assert query_param.original_name == "name"
|
||||
assert query_param.py_name == "name"
|
||||
|
||||
assert body_param is not None
|
||||
assert body_param.original_name == "name"
|
||||
assert body_param.py_name == "name_0"
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from fastapi.openapi.models import MediaType
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import ParameterInType
|
||||
from fastapi.openapi.models import SecuritySchemeType
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
|
||||
def load_spec(file_path: str) -> Dict:
|
||||
"""Loads the OpenAPI specification from a YAML file."""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_spec() -> Dict:
|
||||
"""Fixture to load the OpenAPI specification."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# Join the directory path with the filename
|
||||
yaml_path = os.path.join(current_dir, "test.yaml")
|
||||
return load_spec(yaml_path)
|
||||
|
||||
|
||||
def test_openapi_toolset_initialization_from_dict(openapi_spec: Dict):
|
||||
"""Test initialization of OpenAPIToolset with a dictionary."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
assert isinstance(toolset.tools, list)
|
||||
assert len(toolset.tools) == 5
|
||||
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
|
||||
|
||||
|
||||
def test_openapi_toolset_initialization_from_yaml_string(openapi_spec: Dict):
|
||||
"""Test initialization of OpenAPIToolset with a YAML string."""
|
||||
spec_str = yaml.dump(openapi_spec)
|
||||
toolset = OpenAPIToolset(spec_str=spec_str, spec_str_type="yaml")
|
||||
assert isinstance(toolset.tools, list)
|
||||
assert len(toolset.tools) == 5
|
||||
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
|
||||
|
||||
|
||||
def test_openapi_toolset_tool_existing(openapi_spec: Dict):
|
||||
"""Test the tool() method for an existing tool."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
tool_name = "calendar_calendars_insert" # Example operationId from the spec
|
||||
tool = toolset.get_tool(tool_name)
|
||||
assert isinstance(tool, RestApiTool)
|
||||
assert tool.name == tool_name
|
||||
assert tool.description == "Creates a secondary calendar."
|
||||
assert tool.endpoint.method == "post"
|
||||
assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
|
||||
assert tool.endpoint.path == "/calendars"
|
||||
assert tool.is_long_running is False
|
||||
assert tool.operation.operationId == "calendar.calendars.insert"
|
||||
assert tool.operation.description == "Creates a secondary calendar."
|
||||
assert isinstance(
|
||||
tool.operation.requestBody.content["application/json"], MediaType
|
||||
)
|
||||
assert len(tool.operation.responses) == 1
|
||||
response = tool.operation.responses["200"]
|
||||
assert response.description == "Successful response"
|
||||
assert isinstance(response.content["application/json"], MediaType)
|
||||
assert isinstance(tool.auth_scheme, OAuth2)
|
||||
|
||||
tool_name = "calendar_calendars_get"
|
||||
tool = toolset.get_tool(tool_name)
|
||||
assert isinstance(tool, RestApiTool)
|
||||
assert tool.name == tool_name
|
||||
assert tool.description == "Returns metadata for a calendar."
|
||||
assert tool.endpoint.method == "get"
|
||||
assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
|
||||
assert tool.endpoint.path == "/calendars/{calendarId}"
|
||||
assert tool.is_long_running is False
|
||||
assert tool.operation.operationId == "calendar.calendars.get"
|
||||
assert tool.operation.description == "Returns metadata for a calendar."
|
||||
assert len(tool.operation.parameters) == 1
|
||||
assert tool.operation.parameters[0].name == "calendarId"
|
||||
assert tool.operation.parameters[0].in_ == ParameterInType.path
|
||||
assert tool.operation.parameters[0].required is True
|
||||
assert tool.operation.parameters[0].schema_.type == "string"
|
||||
assert (
|
||||
tool.operation.parameters[0].description
|
||||
== "Calendar identifier. To retrieve calendar IDs call the"
|
||||
" calendarList.list method. If you want to access the primary calendar"
|
||||
' of the currently logged in user, use the "primary" keyword.'
|
||||
)
|
||||
assert isinstance(tool.auth_scheme, OAuth2)
|
||||
|
||||
assert isinstance(toolset.get_tool("calendar_calendars_update"), RestApiTool)
|
||||
assert isinstance(toolset.get_tool("calendar_calendars_delete"), RestApiTool)
|
||||
assert isinstance(toolset.get_tool("calendar_calendars_patch"), RestApiTool)
|
||||
|
||||
|
||||
def test_openapi_toolset_tool_non_existing(openapi_spec: Dict):
|
||||
"""Test the tool() method for a non-existing tool."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
tool = toolset.get_tool("non_existent_tool")
|
||||
assert tool is None
|
||||
|
||||
|
||||
def test_openapi_toolset_configure_auth_on_init(openapi_spec: Dict):
|
||||
"""Test configuring auth during initialization."""
|
||||
|
||||
auth_scheme = APIKey(**{
|
||||
"in": APIKeyIn.header, # Use alias name in dict
|
||||
"name": "api_key",
|
||||
"type": SecuritySchemeType.http,
|
||||
})
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
|
||||
toolset = OpenAPIToolset(
|
||||
spec_dict=openapi_spec,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
for tool in toolset.tools:
|
||||
assert tool.auth_scheme == auth_scheme
|
||||
assert tool.auth_credential == auth_credential
|
||||
@@ -0,0 +1,406 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fastapi.openapi.models import MediaType
|
||||
from fastapi.openapi.models import Operation
|
||||
from fastapi.openapi.models import Parameter
|
||||
from fastapi.openapi.models import RequestBody
|
||||
from fastapi.openapi.models import Response
|
||||
from fastapi.openapi.models import Schema
|
||||
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_operation() -> Operation:
|
||||
"""Fixture to provide a sample OpenAPI Operation object."""
|
||||
return Operation(
|
||||
operationId='test_operation',
|
||||
summary='Test Summary',
|
||||
description='Test Description',
|
||||
parameters=[
|
||||
Parameter(**{
|
||||
'name': 'param1',
|
||||
'in': 'query',
|
||||
'schema': Schema(type='string'),
|
||||
'description': 'Parameter 1',
|
||||
}),
|
||||
Parameter(**{
|
||||
'name': 'param2',
|
||||
'in': 'header',
|
||||
'schema': Schema(type='string'),
|
||||
'description': 'Parameter 2',
|
||||
}),
|
||||
],
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
'application/json': MediaType(
|
||||
schema=Schema(
|
||||
type='object',
|
||||
properties={
|
||||
'prop1': Schema(
|
||||
type='string', description='Property 1'
|
||||
),
|
||||
'prop2': Schema(
|
||||
type='integer', description='Property 2'
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
},
|
||||
description='Request body description',
|
||||
),
|
||||
responses={
|
||||
'200': Response(
|
||||
description='Success',
|
||||
content={
|
||||
'application/json': MediaType(schema=Schema(type='string'))
|
||||
},
|
||||
),
|
||||
'400': Response(description='Client Error'),
|
||||
},
|
||||
security=[{'oauth2': ['resource: read', 'resource: write']}],
|
||||
)
|
||||
|
||||
|
||||
def test_operation_parser_initialization(sample_operation):
|
||||
"""Test initialization of OperationParser."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.operation == sample_operation
|
||||
assert len(parser.params) == 4 # 2 params + 2 request body props
|
||||
assert parser.return_value is not None
|
||||
|
||||
|
||||
def test_process_operation_parameters(sample_operation):
|
||||
"""Test _process_operation_parameters method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
parser._process_operation_parameters()
|
||||
assert len(parser.params) == 2
|
||||
assert parser.params[0].original_name == 'param1'
|
||||
assert parser.params[0].param_location == 'query'
|
||||
assert parser.params[1].original_name == 'param2'
|
||||
assert parser.params[1].param_location == 'header'
|
||||
|
||||
|
||||
def test_process_request_body(sample_operation):
|
||||
"""Test _process_request_body method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
parser._process_request_body()
|
||||
assert len(parser.params) == 2 # 2 properties in request body
|
||||
assert parser.params[0].original_name == 'prop1'
|
||||
assert parser.params[0].param_location == 'body'
|
||||
assert parser.params[1].original_name == 'prop2'
|
||||
assert parser.params[1].param_location == 'body'
|
||||
|
||||
|
||||
def test_process_request_body_array():
|
||||
"""Test _process_request_body method with array schema."""
|
||||
operation = Operation(
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
'application/json': MediaType(
|
||||
schema=Schema(
|
||||
type='array',
|
||||
items=Schema(
|
||||
type='object',
|
||||
properties={
|
||||
'item_prop1': Schema(
|
||||
type='string', description='Item Property 1'
|
||||
),
|
||||
'item_prop2': Schema(
|
||||
type='integer', description='Item Property 2'
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
parser = OperationParser(operation, should_parse=False)
|
||||
parser._process_request_body()
|
||||
assert len(parser.params) == 1
|
||||
assert parser.params[0].original_name == 'array'
|
||||
assert parser.params[0].param_location == 'body'
|
||||
# Check that schema is correctly propagated and is a dictionary
|
||||
assert parser.params[0].param_schema.type == 'array'
|
||||
assert parser.params[0].param_schema.items.type == 'object'
|
||||
assert 'item_prop1' in parser.params[0].param_schema.items.properties
|
||||
assert 'item_prop2' in parser.params[0].param_schema.items.properties
|
||||
assert (
|
||||
parser.params[0].param_schema.items.properties['item_prop1'].description
|
||||
== 'Item Property 1'
|
||||
)
|
||||
assert (
|
||||
parser.params[0].param_schema.items.properties['item_prop2'].description
|
||||
== 'Item Property 2'
|
||||
)
|
||||
|
||||
|
||||
def test_process_request_body_no_name():
|
||||
"""Test _process_request_body with a schema that has no properties (unnamed)"""
|
||||
operation = Operation(
|
||||
requestBody=RequestBody(
|
||||
content={'application/json': MediaType(schema=Schema(type='string'))}
|
||||
)
|
||||
)
|
||||
parser = OperationParser(operation, should_parse=False)
|
||||
parser._process_request_body()
|
||||
assert len(parser.params) == 1
|
||||
assert parser.params[0].original_name == '' # No name
|
||||
assert parser.params[0].param_location == 'body'
|
||||
|
||||
|
||||
def test_dedupe_param_names(sample_operation):
|
||||
"""Test _dedupe_param_names method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
# Add duplicate named parameters.
|
||||
parser.params = [
|
||||
ApiParameter(original_name='test', param_location='', param_schema={}),
|
||||
ApiParameter(original_name='test', param_location='', param_schema={}),
|
||||
ApiParameter(original_name='test', param_location='', param_schema={}),
|
||||
]
|
||||
parser._dedupe_param_names()
|
||||
assert parser.params[0].py_name == 'test'
|
||||
assert parser.params[1].py_name == 'test_0'
|
||||
assert parser.params[2].py_name == 'test_1'
|
||||
|
||||
|
||||
def test_process_return_value(sample_operation):
|
||||
"""Test _process_return_value method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value is not None
|
||||
assert parser.return_value.type_hint == 'str'
|
||||
|
||||
|
||||
def test_process_return_value_no_2xx(sample_operation):
|
||||
"""Tests _process_return_value when no 2xx response exists."""
|
||||
operation_no_2xx = Operation(
|
||||
responses={'400': Response(description='Client Error')}
|
||||
)
|
||||
parser = OperationParser(operation_no_2xx, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value is not None
|
||||
assert parser.return_value.type_hint == 'Any'
|
||||
|
||||
|
||||
def test_process_return_value_multiple_2xx(sample_operation):
|
||||
"""Tests _process_return_value when multiple 2xx responses exist."""
|
||||
operation_multi_2xx = Operation(
|
||||
responses={
|
||||
'201': Response(
|
||||
description='Success',
|
||||
content={
|
||||
'application/json': MediaType(schema=Schema(type='integer'))
|
||||
},
|
||||
),
|
||||
'202': Response(
|
||||
description='Success',
|
||||
content={'text/plain': MediaType(schema=Schema(type='string'))},
|
||||
),
|
||||
'200': Response(
|
||||
description='Success',
|
||||
content={
|
||||
'application/pdf': MediaType(schema=Schema(type='boolean'))
|
||||
},
|
||||
),
|
||||
'400': Response(
|
||||
description='Failure',
|
||||
content={
|
||||
'application/xml': MediaType(schema=Schema(type='object'))
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
parser = OperationParser(operation_multi_2xx, should_parse=False)
|
||||
parser._process_return_value()
|
||||
|
||||
assert parser.return_value is not None
|
||||
# Take the content type of the 200 response since it's the smallest response
|
||||
# code
|
||||
assert parser.return_value.param_schema.type == 'boolean'
|
||||
|
||||
|
||||
def test_process_return_value_no_content(sample_operation):
|
||||
"""Test when 2xx response has no content"""
|
||||
operation_no_content = Operation(
|
||||
responses={'200': Response(description='Success', content={})}
|
||||
)
|
||||
parser = OperationParser(operation_no_content, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value.type_hint == 'Any'
|
||||
|
||||
|
||||
def test_process_return_value_no_schema(sample_operation):
|
||||
"""Tests when the 2xx response's content has no schema."""
|
||||
operation_no_schema = Operation(
|
||||
responses={
|
||||
'200': Response(
|
||||
description='Success',
|
||||
content={'application/json': MediaType(schema=None)},
|
||||
)
|
||||
}
|
||||
)
|
||||
parser = OperationParser(operation_no_schema, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value.type_hint == 'Any'
|
||||
|
||||
|
||||
def test_get_function_name(sample_operation):
|
||||
"""Test get_function_name method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_function_name() == 'test_operation'
|
||||
|
||||
|
||||
def test_get_function_name_missing_id():
|
||||
"""Tests get_function_name when operationId is missing"""
|
||||
operation = Operation() # No ID
|
||||
parser = OperationParser(operation)
|
||||
with pytest.raises(ValueError, match='Operation ID is missing'):
|
||||
parser.get_function_name()
|
||||
|
||||
|
||||
def test_get_return_type_hint(sample_operation):
|
||||
"""Test get_return_type_hint method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_return_type_hint() == 'str'
|
||||
|
||||
|
||||
def test_get_return_type_value(sample_operation):
|
||||
"""Test get_return_type_value method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_return_type_value() == str
|
||||
|
||||
|
||||
def test_get_parameters(sample_operation):
|
||||
"""Test get_parameters method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
params = parser.get_parameters()
|
||||
assert len(params) == 4 # Correct count after processing
|
||||
assert all(isinstance(p, ApiParameter) for p in params)
|
||||
|
||||
|
||||
def test_get_return_value(sample_operation):
|
||||
"""Test get_return_value method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
return_value = parser.get_return_value()
|
||||
assert isinstance(return_value, ApiParameter)
|
||||
|
||||
|
||||
def test_get_auth_scheme_name(sample_operation):
|
||||
"""Test get_auth_scheme_name method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_auth_scheme_name() == 'oauth2'
|
||||
|
||||
|
||||
def test_get_auth_scheme_name_no_security():
|
||||
"""Test get_auth_scheme_name when no security is present."""
|
||||
operation = Operation(responses={})
|
||||
parser = OperationParser(operation)
|
||||
assert parser.get_auth_scheme_name() == ''
|
||||
|
||||
|
||||
def test_get_pydoc_string(sample_operation):
|
||||
"""Test get_pydoc_string method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
pydoc_string = parser.get_pydoc_string()
|
||||
assert 'Test Summary' in pydoc_string
|
||||
assert 'Args:' in pydoc_string
|
||||
assert 'param1 (str): Parameter 1' in pydoc_string
|
||||
assert 'prop1 (str): Property 1' in pydoc_string
|
||||
assert 'Returns (str):' in pydoc_string
|
||||
assert 'Success' in pydoc_string
|
||||
|
||||
|
||||
def test_get_json_schema(sample_operation):
|
||||
"""Test get_json_schema method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
json_schema = parser.get_json_schema()
|
||||
assert json_schema['title'] == 'test_operation_Arguments'
|
||||
assert json_schema['type'] == 'object'
|
||||
assert 'param1' in json_schema['properties']
|
||||
assert 'prop1' in json_schema['properties']
|
||||
assert 'param1' in json_schema['required']
|
||||
assert 'prop1' in json_schema['required']
|
||||
|
||||
|
||||
def test_get_signature_parameters(sample_operation):
|
||||
"""Test get_signature_parameters method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
signature_params = parser.get_signature_parameters()
|
||||
assert len(signature_params) == 4
|
||||
assert signature_params[0].name == 'param1'
|
||||
assert signature_params[0].annotation == str
|
||||
assert signature_params[2].name == 'prop1'
|
||||
assert signature_params[2].annotation == str
|
||||
|
||||
|
||||
def test_get_annotations(sample_operation):
|
||||
"""Test get_annotations method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
annotations = parser.get_annotations()
|
||||
assert len(annotations) == 5 # 4 parameters + return
|
||||
assert annotations['param1'] == str
|
||||
assert annotations['prop1'] == str
|
||||
assert annotations['return'] == str
|
||||
|
||||
|
||||
def test_load():
|
||||
"""Test the load classmethod."""
|
||||
operation = Operation(operationId='my_op') # Minimal operation
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name='p1',
|
||||
param_location='',
|
||||
param_schema={'type': 'integer'},
|
||||
)
|
||||
]
|
||||
return_value = ApiParameter(
|
||||
original_name='', param_location='', param_schema={'type': 'string'}
|
||||
)
|
||||
|
||||
parser = OperationParser.load(operation, params, return_value)
|
||||
|
||||
assert isinstance(parser, OperationParser)
|
||||
assert parser.operation == operation
|
||||
assert parser.params == params
|
||||
assert parser.return_value == return_value
|
||||
assert (
|
||||
parser.get_function_name() == 'my_op'
|
||||
) # Check that the operation is loaded
|
||||
|
||||
|
||||
def test_operation_parser_with_dict():
|
||||
"""Test initialization of OperationParser with a dictionary."""
|
||||
operation_dict = {
|
||||
'operationId': 'test_dict_operation',
|
||||
'parameters': [
|
||||
{'name': 'dict_param', 'in': 'query', 'schema': {'type': 'string'}}
|
||||
],
|
||||
'responses': {
|
||||
'200': {
|
||||
'description': 'Dict Success',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
parser = OperationParser(operation_dict)
|
||||
assert parser.operation.operationId == 'test_dict_operation'
|
||||
assert len(parser.params) == 1
|
||||
assert parser.params[0].original_name == 'dict_param'
|
||||
assert parser.return_value.type_hint == 'str'
|
||||
@@ -0,0 +1,966 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.openapi.models import MediaType
|
||||
from fastapi.openapi.models import Operation
|
||||
from fastapi.openapi.models import Parameter as OpenAPIParameter
|
||||
from fastapi.openapi.models import RequestBody
|
||||
from fastapi.openapi.models import Schema as OpenAPISchema
|
||||
from google.adk.sessions.state import State
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import snake_to_lower_camel
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from google.genai.types import Schema
|
||||
from google.genai.types import Type
|
||||
import pytest
|
||||
|
||||
|
||||
class TestRestApiTool:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_context(self):
|
||||
"""Fixture for a mock OperationParser."""
|
||||
mock_context = MagicMock(spec=ToolContext)
|
||||
mock_context.state = State({}, {})
|
||||
mock_context.get_auth_response.return_value = {}
|
||||
mock_context.request_credential.return_value = {}
|
||||
return mock_context
|
||||
|
||||
@pytest.fixture
|
||||
def mock_operation_parser(self):
|
||||
"""Fixture for a mock OperationParser."""
|
||||
mock_parser = MagicMock(spec=OperationParser)
|
||||
mock_parser.get_function_name.return_value = "mock_function_name"
|
||||
mock_parser.get_json_schema.return_value = {}
|
||||
mock_parser.get_parameters.return_value = []
|
||||
mock_parser.get_return_type_hint.return_value = "str"
|
||||
mock_parser.get_pydoc_string.return_value = "Mock docstring"
|
||||
mock_parser.get_signature_parameters.return_value = []
|
||||
mock_parser.get_return_type_value.return_value = str
|
||||
mock_parser.get_annotations.return_value = {}
|
||||
return mock_parser
|
||||
|
||||
@pytest.fixture
|
||||
def sample_endpiont(self):
|
||||
return OperationEndpoint(
|
||||
base_url="https://example.com", path="/test", method="GET"
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_operation(self):
|
||||
return Operation(
|
||||
operationId="testOperation",
|
||||
description="Test operation",
|
||||
parameters=[],
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={
|
||||
"testBodyParam": OpenAPISchema(type="string")
|
||||
},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_api_parameters(self):
|
||||
return [
|
||||
ApiParameter(
|
||||
original_name="test_param",
|
||||
py_name="test_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
is_required=True,
|
||||
),
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="test_body_param",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
is_required=True,
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_return_parameter(self):
|
||||
return ApiParameter(
|
||||
original_name="test_param",
|
||||
py_name="test_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
is_required=True,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_auth_scheme(self):
|
||||
scheme, _ = token_to_scheme_credential(
|
||||
"apikey", "header", "", "sample_auth_credential_internal_test"
|
||||
)
|
||||
return scheme
|
||||
|
||||
@pytest.fixture
|
||||
def sample_auth_credential(self):
|
||||
_, credential = token_to_scheme_credential(
|
||||
"apikey", "header", "", "sample_auth_credential_internal_test"
|
||||
)
|
||||
return credential
|
||||
|
||||
def test_init(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "Test Tool"
|
||||
assert tool.endpoint == sample_endpiont
|
||||
assert tool.operation == sample_operation
|
||||
assert tool.auth_credential == sample_auth_credential
|
||||
assert tool.auth_scheme == sample_auth_scheme
|
||||
assert tool.credential_exchanger is not None
|
||||
|
||||
def test_from_parsed_operation_str(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_api_parameters,
|
||||
sample_return_parameter,
|
||||
sample_operation,
|
||||
):
|
||||
parsed_operation_str = json.dumps({
|
||||
"name": "test_operation",
|
||||
"description": "Test Description",
|
||||
"endpoint": sample_endpiont.model_dump(),
|
||||
"operation": sample_operation.model_dump(),
|
||||
"auth_scheme": None,
|
||||
"auth_credential": None,
|
||||
"parameters": [p.model_dump() for p in sample_api_parameters],
|
||||
"return_value": sample_return_parameter.model_dump(),
|
||||
})
|
||||
|
||||
tool = RestApiTool.from_parsed_operation_str(parsed_operation_str)
|
||||
assert tool.name == "test_operation"
|
||||
|
||||
def test_get_declaration(
|
||||
self, sample_endpiont, sample_operation, mock_operation_parser
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test description",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
should_parse_operation=False,
|
||||
)
|
||||
tool._operation_parser = mock_operation_parser
|
||||
|
||||
declaration = tool._get_declaration()
|
||||
assert isinstance(declaration, FunctionDeclaration)
|
||||
assert declaration.name == "test_tool"
|
||||
assert declaration.description == "Test description"
|
||||
assert isinstance(declaration.parameters, Schema)
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
|
||||
)
|
||||
def test_call_success(
|
||||
self,
|
||||
mock_request,
|
||||
mock_tool_context,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
|
||||
# Call the method
|
||||
result = tool.call(args={}, tool_context=mock_tool_context)
|
||||
|
||||
# Check the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
|
||||
)
|
||||
def test_call_auth_pending(
|
||||
self,
|
||||
mock_request,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
with patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
|
||||
) as mock_from_tool_context:
|
||||
mock_tool_auth_handler_instance = MagicMock()
|
||||
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
|
||||
"pending"
|
||||
)
|
||||
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
|
||||
|
||||
response = tool.call(args={}, tool_context=None)
|
||||
assert response == {
|
||||
"pending": True,
|
||||
"message": "Needs your authorization to access your data.",
|
||||
}
|
||||
|
||||
def test_prepare_request_params_query_body(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
# Create a mock Operation object
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
parameters=[
|
||||
OpenAPIParameter(**{
|
||||
"name": "testQueryParam",
|
||||
"in": "query",
|
||||
"schema": OpenAPISchema(type="string"),
|
||||
})
|
||||
],
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={
|
||||
"param1": OpenAPISchema(type="string"),
|
||||
"param2": OpenAPISchema(type="integer"),
|
||||
},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="param1",
|
||||
py_name="param1",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
),
|
||||
ApiParameter(
|
||||
original_name="param2",
|
||||
py_name="param2",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="integer"),
|
||||
),
|
||||
ApiParameter(
|
||||
original_name="testQueryParam",
|
||||
py_name="test_query_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
),
|
||||
]
|
||||
kwargs = {
|
||||
"param1": "value1",
|
||||
"param2": 123,
|
||||
"test_query_param": "query_value",
|
||||
}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
assert request_params["method"] == "get"
|
||||
assert request_params["url"] == "https://example.com/test"
|
||||
assert request_params["json"] == {"param1": "value1", "param2": 123}
|
||||
assert request_params["params"] == {"testQueryParam": "query_value"}
|
||||
|
||||
def test_prepare_request_params_array(
|
||||
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="array", items=OpenAPISchema(type="string")
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="array", # Match the parameter name
|
||||
py_name="array",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(
|
||||
type="array", items=OpenAPISchema(type="string")
|
||||
),
|
||||
)
|
||||
]
|
||||
kwargs = {"array": ["item1", "item2"]}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["json"] == ["item1", "item2"]
|
||||
|
||||
def test_prepare_request_params_string(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"text/plain": MediaType(schema=OpenAPISchema(type="string"))
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="input_string",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"input_string": "test_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["data"] == "test_value"
|
||||
assert request_params["headers"]["Content-Type"] == "text/plain"
|
||||
|
||||
def test_prepare_request_params_form_data(
|
||||
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/x-www-form-urlencoded": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={"key1": OpenAPISchema(type="string")},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="key1",
|
||||
py_name="key1",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"key1": "value1"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["data"] == {"key1": "value1"}
|
||||
assert (
|
||||
request_params["headers"]["Content-Type"]
|
||||
== "application/x-www-form-urlencoded"
|
||||
)
|
||||
|
||||
def test_prepare_request_params_multipart(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"multipart/form-data": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={
|
||||
"file1": OpenAPISchema(
|
||||
type="string", format="binary"
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="file1",
|
||||
py_name="file1",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string", format="binary"),
|
||||
)
|
||||
]
|
||||
kwargs = {"file1": b"file_content"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["files"] == {"file1": b"file_content"}
|
||||
assert request_params["headers"]["Content-Type"] == "multipart/form-data"
|
||||
|
||||
def test_prepare_request_params_octet_stream(
|
||||
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/octet-stream": MediaType(
|
||||
schema=OpenAPISchema(type="string", format="binary")
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="data",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string", format="binary"),
|
||||
)
|
||||
]
|
||||
kwargs = {"data": b"binary_data"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["data"] == b"binary_data"
|
||||
assert (
|
||||
request_params["headers"]["Content-Type"] == "application/octet-stream"
|
||||
)
|
||||
|
||||
def test_prepare_request_params_path_param(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
mock_operation = Operation(operationId="test_op")
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="user_id",
|
||||
py_name="user_id",
|
||||
param_location="path",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"user_id": "123"}
|
||||
endpoint_with_path = OperationEndpoint(
|
||||
base_url="https://example.com", path="/test/{user_id}", method="get"
|
||||
)
|
||||
tool.endpoint = endpoint_with_path
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert (
|
||||
request_params["url"] == "https://example.com/test/123"
|
||||
) # Path param replaced
|
||||
|
||||
def test_prepare_request_params_header_param(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="X-Custom-Header",
|
||||
py_name="x_custom_header",
|
||||
param_location="header",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"x_custom_header": "header_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["headers"]["X-Custom-Header"] == "header_value"
|
||||
|
||||
def test_prepare_request_params_cookie_param(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="session_id",
|
||||
py_name="session_id",
|
||||
param_location="cookie",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"session_id": "cookie_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["cookies"]["session_id"] == "cookie_value"
|
||||
|
||||
def test_prepare_request_params_multiple_mime_types(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
# Test what happens when multiple mime types are specified. It should take
|
||||
# the first one.
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(type="string")
|
||||
),
|
||||
"text/plain": MediaType(schema=OpenAPISchema(type="string")),
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="input",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"input": "some_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["headers"]["Content-Type"] == "application/json"
|
||||
|
||||
def test_prepare_request_params_unknown_parameter(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="known_param",
|
||||
py_name="known_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"known_param": "value", "unknown_param": "unknown"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
# Make sure unknown parameters are ignored and do not raise errors.
|
||||
assert "unknown_param" not in request_params["params"]
|
||||
|
||||
def test_prepare_request_params_base_url_handling(
|
||||
self, sample_auth_credential, sample_auth_scheme, sample_operation
|
||||
):
|
||||
# No base_url provided, should use path as is
|
||||
tool_no_base = RestApiTool(
|
||||
name="test_tool_no_base",
|
||||
description="Test Tool",
|
||||
endpoint=OperationEndpoint(base_url="", path="/no_base", method="get"),
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = []
|
||||
kwargs = {}
|
||||
|
||||
request_params_no_base = tool_no_base._prepare_request_params(
|
||||
params, kwargs
|
||||
)
|
||||
assert request_params_no_base["url"] == "/no_base"
|
||||
|
||||
tool_trailing_slash = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=OperationEndpoint(
|
||||
base_url="https://example.com/", path="/trailing", method="get"
|
||||
),
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
|
||||
request_params_trailing = tool_trailing_slash._prepare_request_params(
|
||||
params, kwargs
|
||||
)
|
||||
assert request_params_trailing["url"] == "https://example.com/trailing"
|
||||
|
||||
def test_prepare_request_params_no_unrecognized_query_parameter(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="unrecognized_param",
|
||||
py_name="unrecognized_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"unrecognized_param": None} # Explicitly passing None
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
# Query param not in sample_operation. It should be ignored.
|
||||
assert "unrecognized_param" not in request_params["params"]
|
||||
|
||||
def test_prepare_request_params_no_credential(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=None,
|
||||
auth_scheme=None,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="param_name",
|
||||
py_name="param_name",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"param_name": "aaa", "empty_param": ""}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert "param_name" in request_params["params"]
|
||||
assert "empty_param" not in request_params["params"]
|
||||
|
||||
|
||||
class TestToGeminiSchema:
|
||||
|
||||
def test_to_gemini_schema_none(self):
|
||||
assert to_gemini_schema(None) is None
|
||||
|
||||
def test_to_gemini_schema_not_dict(self):
|
||||
with pytest.raises(TypeError, match="openapi_schema must be a dictionary"):
|
||||
to_gemini_schema("not a dict")
|
||||
|
||||
def test_to_gemini_schema_empty_dict(self):
|
||||
result = to_gemini_schema({})
|
||||
assert isinstance(result, Schema)
|
||||
assert result.type == Type.OBJECT
|
||||
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
|
||||
|
||||
def test_to_gemini_schema_dict_with_only_object_type(self):
|
||||
result = to_gemini_schema({"type": "object"})
|
||||
assert isinstance(result, Schema)
|
||||
assert result.type == Type.OBJECT
|
||||
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
|
||||
|
||||
def test_to_gemini_schema_basic_types(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"is_active": {"type": "boolean"},
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert isinstance(gemini_schema, Schema)
|
||||
assert gemini_schema.type == Type.OBJECT
|
||||
assert gemini_schema.properties["name"].type == Type.STRING
|
||||
assert gemini_schema.properties["age"].type == Type.INTEGER
|
||||
assert gemini_schema.properties["is_active"].type == Type.BOOLEAN
|
||||
|
||||
def test_to_gemini_schema_nested_objects(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"street": {"type": "string"},
|
||||
"city": {"type": "string"},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.properties["address"].type == Type.OBJECT
|
||||
assert (
|
||||
gemini_schema.properties["address"].properties["street"].type
|
||||
== Type.STRING
|
||||
)
|
||||
assert (
|
||||
gemini_schema.properties["address"].properties["city"].type
|
||||
== Type.STRING
|
||||
)
|
||||
|
||||
def test_to_gemini_schema_array(self):
|
||||
openapi_schema = {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.type == Type.ARRAY
|
||||
assert gemini_schema.items.type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_nested_array(self):
|
||||
openapi_schema = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.items.properties["name"].type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_any_of(self):
|
||||
openapi_schema = {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}],
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert len(gemini_schema.any_of) == 2
|
||||
assert gemini_schema.any_of[0].type == Type.STRING
|
||||
assert gemini_schema.any_of[1].type == Type.INTEGER
|
||||
|
||||
def test_to_gemini_schema_general_list(self):
|
||||
openapi_schema = {
|
||||
"type": "array",
|
||||
"properties": {
|
||||
"list_field": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.properties["list_field"].type == Type.ARRAY
|
||||
assert gemini_schema.properties["list_field"].items.type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_enum(self):
|
||||
openapi_schema = {"type": "string", "enum": ["a", "b", "c"]}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.enum == ["a", "b", "c"]
|
||||
|
||||
def test_to_gemini_schema_required(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"required": ["name"],
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.required == ["name"]
|
||||
|
||||
def test_to_gemini_schema_nested_dict(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"properties": {"metadata": {"key1": "value1", "key2": 123}},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
# Since metadata is not properties nor item, it will call to_gemini_schema recursively.
|
||||
assert isinstance(gemini_schema.properties["metadata"], Schema)
|
||||
assert (
|
||||
gemini_schema.properties["metadata"].type == Type.OBJECT
|
||||
) # add object type by default
|
||||
assert gemini_schema.properties["metadata"].properties == {
|
||||
"dummy_DO_NOT_GENERATE": Schema(type="string")
|
||||
}
|
||||
|
||||
def test_to_gemini_schema_ignore_title_default_format(self):
|
||||
openapi_schema = {
|
||||
"type": "string",
|
||||
"title": "Test Title",
|
||||
"default": "default_value",
|
||||
"format": "date",
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
|
||||
assert gemini_schema.title is None
|
||||
assert gemini_schema.default is None
|
||||
assert gemini_schema.format is None
|
||||
|
||||
def test_to_gemini_schema_property_ordering(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"propertyOrdering": ["name", "age"],
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.property_ordering == ["name", "age"]
|
||||
|
||||
def test_to_gemini_schema_converts_property_dict(self):
|
||||
openapi_schema = {
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "The property key"},
|
||||
"value": {"type": "string", "description": "The property value"},
|
||||
},
|
||||
"type": "object",
|
||||
"description": "A single property entry in the Properties message.",
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.type == Type.OBJECT
|
||||
assert gemini_schema.properties["name"].type == Type.STRING
|
||||
assert gemini_schema.properties["value"].type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_remove_unrecognized_fields(self):
|
||||
openapi_schema = {
|
||||
"type": "string",
|
||||
"description": "A single date string.",
|
||||
"format": "date",
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.type == Type.STRING
|
||||
assert not gemini_schema.format
|
||||
|
||||
|
||||
def test_snake_to_lower_camel():
|
||||
assert snake_to_lower_camel("single") == "single"
|
||||
assert snake_to_lower_camel("two_words") == "twoWords"
|
||||
assert snake_to_lower_camel("three_word_example") == "threeWordExample"
|
||||
assert not snake_to_lower_camel("")
|
||||
assert snake_to_lower_camel("alreadyCamelCase") == "alreadyCamelCase"
|
||||
@@ -0,0 +1,201 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import OAuth2CredentialExchanger
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolContextCredentialStore
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
import pytest
|
||||
|
||||
|
||||
# Helper function to create a mock ToolContext
|
||||
def create_mock_tool_context():
|
||||
return ToolContext(
|
||||
function_call_id='test-fc-id',
|
||||
invocation_context=InvocationContext(
|
||||
agent=LlmAgent(name='test'),
|
||||
session=Session(app_name='test', user_id='123', id='123'),
|
||||
invocation_id='123',
|
||||
session_service=InMemorySessionService(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Test cases for OpenID Connect
|
||||
class MockOpenIdConnectCredentialExchanger(OAuth2CredentialExchanger):
|
||||
|
||||
def __init__(
|
||||
self, expected_scheme, expected_credential, expected_access_token
|
||||
):
|
||||
self.expected_scheme = expected_scheme
|
||||
self.expected_credential = expected_credential
|
||||
self.expected_access_token = expected_access_token
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
if auth_credential.oauth2 and (
|
||||
auth_credential.oauth2.auth_response_uri
|
||||
or auth_credential.oauth2.auth_code
|
||||
):
|
||||
auth_code = (
|
||||
auth_credential.oauth2.auth_response_uri
|
||||
if auth_credential.oauth2.auth_response_uri
|
||||
else auth_credential.oauth2.auth_code
|
||||
)
|
||||
# Simulate the token exchange
|
||||
updated_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
|
||||
http=HttpAuth(
|
||||
scheme='bearer',
|
||||
credentials=HttpCredentials(
|
||||
token=auth_code + self.expected_access_token
|
||||
),
|
||||
),
|
||||
)
|
||||
return updated_credential
|
||||
|
||||
# simulate the case of getting auth_uri
|
||||
return None
|
||||
|
||||
|
||||
def get_mock_openid_scheme_credential():
|
||||
config_dict = {
|
||||
'authorization_endpoint': 'test.com',
|
||||
'token_endpoint': 'test.com',
|
||||
}
|
||||
scopes = ['test_scope']
|
||||
credential_dict = {
|
||||
'client_id': '123',
|
||||
'client_secret': '456',
|
||||
'redirect_uri': 'test.com',
|
||||
}
|
||||
return openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
|
||||
|
||||
|
||||
# Fixture for the OpenID Connect security scheme
|
||||
@pytest.fixture
|
||||
def openid_connect_scheme():
|
||||
scheme, _ = get_mock_openid_scheme_credential()
|
||||
return scheme
|
||||
|
||||
|
||||
# Fixture for a base OpenID Connect credential
|
||||
@pytest.fixture
|
||||
def openid_connect_credential():
|
||||
_, credential = get_mock_openid_scheme_credential()
|
||||
return credential
|
||||
|
||||
|
||||
def test_openid_connect_no_auth_response(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
):
|
||||
# Setup Mock exchanger
|
||||
mock_exchanger = MockOpenIdConnectCredentialExchanger(
|
||||
openid_connect_scheme, openid_connect_credential, None
|
||||
)
|
||||
tool_context = create_mock_tool_context()
|
||||
credential_store = ToolContextCredentialStore(tool_context=tool_context)
|
||||
handler = ToolAuthHandler(
|
||||
tool_context,
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
credential_exchanger=mock_exchanger,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
result = handler.prepare_auth_credentials()
|
||||
assert result.state == 'pending'
|
||||
assert result.auth_credential == openid_connect_credential
|
||||
|
||||
|
||||
def test_openid_connect_with_auth_response(
|
||||
openid_connect_scheme, openid_connect_credential, monkeypatch
|
||||
):
|
||||
mock_exchanger = MockOpenIdConnectCredentialExchanger(
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
'test_access_token',
|
||||
)
|
||||
tool_context = create_mock_tool_context()
|
||||
|
||||
mock_auth_handler = MagicMock()
|
||||
mock_auth_handler.get_auth_response.return_value = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
|
||||
oauth2=OAuth2Auth(auth_response_uri='test_auth_response_uri'),
|
||||
)
|
||||
mock_auth_handler_path = 'google.adk.tools.tool_context.AuthHandler'
|
||||
monkeypatch.setattr(
|
||||
mock_auth_handler_path, lambda *args, **kwargs: mock_auth_handler
|
||||
)
|
||||
|
||||
credential_store = ToolContextCredentialStore(tool_context=tool_context)
|
||||
handler = ToolAuthHandler(
|
||||
tool_context,
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
credential_exchanger=mock_exchanger,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
result = handler.prepare_auth_credentials()
|
||||
assert result.state == 'done'
|
||||
assert result.auth_credential.auth_type == AuthCredentialTypes.HTTP
|
||||
assert 'test_access_token' in result.auth_credential.http.credentials.token
|
||||
# Verify that the credential was stored:
|
||||
stored_credential = credential_store.get_credential(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
)
|
||||
assert stored_credential == result.auth_credential
|
||||
mock_auth_handler.get_auth_response.assert_called_once()
|
||||
|
||||
|
||||
def test_openid_connect_existing_token(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
):
|
||||
_, existing_credential = token_to_scheme_credential(
|
||||
'oauth2Token', 'header', 'bearer', '123123123'
|
||||
)
|
||||
tool_context = create_mock_tool_context()
|
||||
# Store the credential to simulate existing credential
|
||||
credential_store = ToolContextCredentialStore(tool_context=tool_context)
|
||||
key = credential_store.get_credential_key(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
)
|
||||
credential_store.store_credential(key, existing_credential)
|
||||
|
||||
handler = ToolAuthHandler(
|
||||
tool_context,
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
result = handler.prepare_auth_credentials()
|
||||
assert result.state == 'done'
|
||||
assert result.auth_credential == existing_credential
|
||||
14
tests/unittests/tools/retrieval/__init__.py
Normal file
14
tests/unittests/tools/retrieval/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
147
tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py
Normal file
147
tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.adk.tools.retrieval.vertex_ai_rag_retrieval import VertexAiRagRetrieval
|
||||
from google.genai import types
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def noop_tool(x: str) -> str:
|
||||
return x
|
||||
|
||||
|
||||
def test_vertex_rag_retrieval_for_gemini_1_x():
|
||||
responses = [
|
||||
'response1',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=responses)
|
||||
mockModel.model = 'gemini-1.5-pro'
|
||||
|
||||
# Calls the first time.
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
tools=[
|
||||
VertexAiRagRetrieval(
|
||||
name='rag_retrieval',
|
||||
description='rag_retrieval',
|
||||
rag_corpora=[
|
||||
'projects/123456789/locations/us-central1/ragCorpora/1234567890'
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test1')
|
||||
|
||||
# Asserts the requests.
|
||||
assert len(mockModel.requests) == 1
|
||||
assert utils.simplify_contents(mockModel.requests[0].contents) == [
|
||||
('user', 'test1'),
|
||||
]
|
||||
assert len(mockModel.requests[0].config.tools) == 1
|
||||
assert (
|
||||
mockModel.requests[0].config.tools[0].function_declarations[0].name
|
||||
== 'rag_retrieval'
|
||||
)
|
||||
assert mockModel.requests[0].tools_dict['rag_retrieval'] is not None
|
||||
|
||||
|
||||
def test_vertex_rag_retrieval_for_gemini_1_x_with_another_function_tool():
|
||||
responses = [
|
||||
'response1',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=responses)
|
||||
mockModel.model = 'gemini-1.5-pro'
|
||||
|
||||
# Calls the first time.
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
tools=[
|
||||
VertexAiRagRetrieval(
|
||||
name='rag_retrieval',
|
||||
description='rag_retrieval',
|
||||
rag_corpora=[
|
||||
'projects/123456789/locations/us-central1/ragCorpora/1234567890'
|
||||
],
|
||||
),
|
||||
FunctionTool(func=noop_tool),
|
||||
],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test1')
|
||||
|
||||
# Asserts the requests.
|
||||
assert len(mockModel.requests) == 1
|
||||
assert utils.simplify_contents(mockModel.requests[0].contents) == [
|
||||
('user', 'test1'),
|
||||
]
|
||||
assert len(mockModel.requests[0].config.tools[0].function_declarations) == 2
|
||||
assert (
|
||||
mockModel.requests[0].config.tools[0].function_declarations[0].name
|
||||
== 'rag_retrieval'
|
||||
)
|
||||
assert (
|
||||
mockModel.requests[0].config.tools[0].function_declarations[1].name
|
||||
== 'noop_tool'
|
||||
)
|
||||
assert mockModel.requests[0].tools_dict['rag_retrieval'] is not None
|
||||
|
||||
|
||||
def test_vertex_rag_retrieval_for_gemini_2_x():
|
||||
responses = [
|
||||
'response1',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=responses)
|
||||
mockModel.model = 'gemini-2.0-flash'
|
||||
|
||||
# Calls the first time.
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
tools=[
|
||||
VertexAiRagRetrieval(
|
||||
name='rag_retrieval',
|
||||
description='rag_retrieval',
|
||||
rag_corpora=[
|
||||
'projects/123456789/locations/us-central1/ragCorpora/1234567890'
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test1')
|
||||
|
||||
# Asserts the requests.
|
||||
assert len(mockModel.requests) == 1
|
||||
assert utils.simplify_contents(mockModel.requests[0].contents) == [
|
||||
('user', 'test1'),
|
||||
]
|
||||
assert len(mockModel.requests[0].config.tools) == 1
|
||||
assert mockModel.requests[0].config.tools == [
|
||||
types.Tool(
|
||||
retrieval=types.Retrieval(
|
||||
vertex_rag_store=types.VertexRagStore(
|
||||
rag_corpora=[
|
||||
'projects/123456789/locations/us-central1/ragCorpora/1234567890'
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
assert 'rag_retrieval' not in mockModel.requests[0].tools_dict
|
||||
167
tests/unittests/tools/test_agent_tool.py
Normal file
167
tests/unittests/tools/test_agent_tool.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.tools.agent_tool import AgentTool
|
||||
from google.genai.types import Part
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
from .. import utils
|
||||
|
||||
pytestmark = pytest.mark.skip(
|
||||
reason='Skipping until tool.func evaluations are fixed (async)'
|
||||
)
|
||||
|
||||
|
||||
function_call_custom = Part.from_function_call(
|
||||
name='tool_agent', args={'custom_input': 'test1'}
|
||||
)
|
||||
|
||||
function_call_no_schema = Part.from_function_call(
|
||||
name='tool_agent', args={'request': 'test1'}
|
||||
)
|
||||
|
||||
function_response_custom = Part.from_function_response(
|
||||
name='tool_agent', response={'custom_output': 'response1'}
|
||||
)
|
||||
|
||||
function_response_no_schema = Part.from_function_response(
|
||||
name='tool_agent', response={'result': 'response1'}
|
||||
)
|
||||
|
||||
|
||||
def change_state_callback(callback_context: CallbackContext):
|
||||
callback_context.state['state_1'] = 'changed_value'
|
||||
print('change_state_callback: ', callback_context.state)
|
||||
|
||||
|
||||
def test_no_schema():
|
||||
mock_model = utils.MockModel.create(
|
||||
responses=[
|
||||
function_call_no_schema,
|
||||
'response1',
|
||||
'response2',
|
||||
]
|
||||
)
|
||||
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
model=mock_model,
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[AgentTool(agent=tool_agent)],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', function_call_no_schema),
|
||||
('root_agent', function_response_no_schema),
|
||||
('root_agent', 'response2'),
|
||||
]
|
||||
|
||||
|
||||
def test_update_state():
|
||||
"""The agent tool can read and change parent state."""
|
||||
|
||||
mock_model = utils.MockModel.create(
|
||||
responses=[
|
||||
function_call_no_schema,
|
||||
'{"custom_output": "response1"}',
|
||||
'response2',
|
||||
]
|
||||
)
|
||||
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
model=mock_model,
|
||||
instruction='input: {state_1}',
|
||||
before_agent_callback=change_state_callback,
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[AgentTool(agent=tool_agent)],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
runner.session.state['state_1'] = 'state1_value'
|
||||
|
||||
runner.run('test1')
|
||||
assert (
|
||||
'input: changed_value' in mock_model.requests[1].config.system_instruction
|
||||
)
|
||||
assert runner.session.state['state_1'] == 'changed_value'
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
'env_variables',
|
||||
[
|
||||
'GOOGLE_AI',
|
||||
# TODO(wanyif): re-enable after fix.
|
||||
# 'VERTEX',
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_custom_schema():
|
||||
class CustomInput(BaseModel):
|
||||
custom_input: str
|
||||
|
||||
class CustomOutput(BaseModel):
|
||||
custom_output: str
|
||||
|
||||
mock_model = utils.MockModel.create(
|
||||
responses=[
|
||||
function_call_custom,
|
||||
'{"custom_output": "response1"}',
|
||||
'response2',
|
||||
]
|
||||
)
|
||||
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
model=mock_model,
|
||||
input_schema=CustomInput,
|
||||
output_schema=CustomOutput,
|
||||
output_key='tool_output',
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[AgentTool(agent=tool_agent)],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
runner.session.state['state_1'] = 'state1_value'
|
||||
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', function_call_custom),
|
||||
('root_agent', function_response_custom),
|
||||
('root_agent', 'response2'),
|
||||
]
|
||||
|
||||
assert runner.session.state['tool_output'] == {'custom_output': 'response1'}
|
||||
|
||||
assert len(mock_model.requests) == 3
|
||||
# The second request is the tool agent request.
|
||||
assert mock_model.requests[1].config.response_schema == CustomOutput
|
||||
assert mock_model.requests[1].config.response_mime_type == 'application/json'
|
||||
141
tests/unittests/tools/test_base_tool.py
Normal file
141
tests/unittests/tools/test_base_tool.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.sequential_agent import SequentialAgent
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
|
||||
class _TestingTool(BaseTool):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
declaration: Optional[types.FunctionDeclaration] = None,
|
||||
):
|
||||
super().__init__(name='test_tool', description='test_description')
|
||||
self.declaration = declaration
|
||||
|
||||
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
|
||||
return self.declaration
|
||||
|
||||
|
||||
def _create_tool_context() -> ToolContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
agent = SequentialAgent(name='test_agent')
|
||||
invocation_context = InvocationContext(
|
||||
invocation_id='invocation_id',
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
)
|
||||
return ToolContext(invocation_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_llm_request_no_declaration():
|
||||
tool = _TestingTool()
|
||||
tool_context = _create_tool_context()
|
||||
llm_request = LlmRequest()
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
assert llm_request.config is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_llm_request_with_declaration():
|
||||
declaration = types.FunctionDeclaration(
|
||||
name='test_tool',
|
||||
description='test_description',
|
||||
parameters=types.Schema(
|
||||
type=types.Type.STRING,
|
||||
title='param_1',
|
||||
),
|
||||
)
|
||||
tool = _TestingTool(declaration)
|
||||
llm_request = LlmRequest()
|
||||
tool_context = _create_tool_context()
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
assert llm_request.config.tools[0].function_declarations == [declaration]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_llm_request_with_builtin_tool():
|
||||
declaration = types.FunctionDeclaration(
|
||||
name='test_tool',
|
||||
description='test_description',
|
||||
parameters=types.Schema(
|
||||
type=types.Type.STRING,
|
||||
title='param_1',
|
||||
),
|
||||
)
|
||||
tool = _TestingTool(declaration)
|
||||
llm_request = LlmRequest(
|
||||
config=types.GenerateContentConfig(
|
||||
tools=[types.Tool(google_search=types.GoogleSearch())]
|
||||
)
|
||||
)
|
||||
tool_context = _create_tool_context()
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
# function_declaration is added to another types.Tool without builtin tool.
|
||||
assert llm_request.config.tools[1].function_declarations == [declaration]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_llm_request_with_builtin_tool_and_another_declaration():
|
||||
declaration = types.FunctionDeclaration(
|
||||
name='test_tool',
|
||||
description='test_description',
|
||||
parameters=types.Schema(
|
||||
type=types.Type.STRING,
|
||||
title='param_1',
|
||||
),
|
||||
)
|
||||
tool = _TestingTool(declaration)
|
||||
llm_request = LlmRequest(
|
||||
config=types.GenerateContentConfig(
|
||||
tools=[
|
||||
types.Tool(google_search=types.GoogleSearch()),
|
||||
types.Tool(function_declarations=[types.FunctionDeclaration()]),
|
||||
]
|
||||
)
|
||||
)
|
||||
tool_context = _create_tool_context()
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
# function_declaration is added to existing types.Tool with function_declaration.
|
||||
assert llm_request.config.tools[1].function_declarations[1] == declaration
|
||||
277
tests/unittests/tools/test_build_function_declaration.py
Normal file
277
tests/unittests/tools/test_build_function_declaration.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
from google.adk.tools import _automatic_function_calling_util
|
||||
from google.adk.tools.agent_tool import ToolContext
|
||||
from google.adk.tools.langchain_tool import LangchainTool
|
||||
# TODO: crewai requires python 3.10 as minimum
|
||||
# from crewai_tools import FileReadTool
|
||||
from langchain_community.tools import ShellTool
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
|
||||
def test_unsupported_variant():
|
||||
def simple_function(input_str: str) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function, variant='Unsupported'
|
||||
)
|
||||
|
||||
|
||||
def test_string_input():
|
||||
def simple_function(input_str: str) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'STRING'
|
||||
|
||||
|
||||
def test_int_input():
|
||||
def simple_function(input_str: int) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'INTEGER'
|
||||
|
||||
|
||||
def test_float_input():
|
||||
def simple_function(input_str: float) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'NUMBER'
|
||||
|
||||
|
||||
def test_bool_input():
|
||||
def simple_function(input_str: bool) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'BOOLEAN'
|
||||
|
||||
|
||||
def test_array_input():
|
||||
def simple_function(input_str: List[str]) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'ARRAY'
|
||||
|
||||
|
||||
def test_dict_input():
|
||||
def simple_function(input_str: Dict[str, str]) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'OBJECT'
|
||||
|
||||
|
||||
def test_basemodel_input():
|
||||
class CustomInput(BaseModel):
|
||||
input_str: str
|
||||
|
||||
def simple_function(input: CustomInput) -> str:
|
||||
return {'result': input}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input'].type == 'OBJECT'
|
||||
assert (
|
||||
function_decl.parameters.properties['input'].properties['input_str'].type
|
||||
== 'STRING'
|
||||
)
|
||||
|
||||
|
||||
def test_toolcontext_ignored():
|
||||
def simple_function(input_str: str, tool_context: ToolContext) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function, ignore_params=['tool_context']
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'STRING'
|
||||
assert 'tool_context' not in function_decl.parameters.properties
|
||||
|
||||
|
||||
def test_basemodel():
|
||||
class SimpleFunction(BaseModel):
|
||||
input_str: str
|
||||
custom_input: int
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=SimpleFunction, ignore_params=['custom_input']
|
||||
)
|
||||
|
||||
assert function_decl.name == 'SimpleFunction'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'STRING'
|
||||
assert 'custom_input' not in function_decl.parameters.properties
|
||||
|
||||
|
||||
def test_nested_basemodel_input():
|
||||
class ChildInput(BaseModel):
|
||||
input_str: str
|
||||
|
||||
class CustomInput(BaseModel):
|
||||
child: ChildInput
|
||||
|
||||
def simple_function(input: CustomInput) -> str:
|
||||
return {'result': input}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input'].type == 'OBJECT'
|
||||
assert (
|
||||
function_decl.parameters.properties['input'].properties['child'].type
|
||||
== 'OBJECT'
|
||||
)
|
||||
assert (
|
||||
function_decl.parameters.properties['input']
|
||||
.properties['child']
|
||||
.properties['input_str']
|
||||
.type
|
||||
== 'STRING'
|
||||
)
|
||||
|
||||
|
||||
def test_basemodel_with_nested_basemodel():
|
||||
class ChildInput(BaseModel):
|
||||
input_str: str
|
||||
|
||||
class CustomInput(BaseModel):
|
||||
child: ChildInput
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=CustomInput, ignore_params=['custom_input']
|
||||
)
|
||||
|
||||
assert function_decl.name == 'CustomInput'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['child'].type == 'OBJECT'
|
||||
assert (
|
||||
function_decl.parameters.properties['child'].properties['input_str'].type
|
||||
== 'STRING'
|
||||
)
|
||||
assert 'custom_input' not in function_decl.parameters.properties
|
||||
|
||||
|
||||
def test_list():
|
||||
def simple_function(
|
||||
input_str: List[str], input_dir: List[Dict[str, str]]
|
||||
) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'ARRAY'
|
||||
assert function_decl.parameters.properties['input_str'].items.type == 'STRING'
|
||||
assert function_decl.parameters.properties['input_dir'].type == 'ARRAY'
|
||||
assert function_decl.parameters.properties['input_dir'].items.type == 'OBJECT'
|
||||
|
||||
|
||||
def test_basemodel_list():
|
||||
class ChildInput(BaseModel):
|
||||
input_str: str
|
||||
|
||||
class CustomInput(BaseModel):
|
||||
child: ChildInput
|
||||
|
||||
def simple_function(input_str: List[CustomInput]) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'ARRAY'
|
||||
assert function_decl.parameters.properties['input_str'].items.type == 'OBJECT'
|
||||
assert (
|
||||
function_decl.parameters.properties['input_str']
|
||||
.items.properties['child']
|
||||
.type
|
||||
== 'OBJECT'
|
||||
)
|
||||
assert (
|
||||
function_decl.parameters.properties['input_str']
|
||||
.items.properties['child']
|
||||
.properties['input_str']
|
||||
.type
|
||||
== 'STRING'
|
||||
)
|
||||
|
||||
|
||||
# TODO: comment out this test for now as crewai requires python 3.10 as minimum
|
||||
# def test_crewai_tool():
|
||||
# docs_tool = CrewaiTool(
|
||||
# name='direcotry_read_tool',
|
||||
# description='use this to find files for you.',
|
||||
# tool=FileReadTool(),
|
||||
# )
|
||||
# function_decl = docs_tool.get_declaration()
|
||||
# assert function_decl.name == 'direcotry_read_tool'
|
||||
# assert function_decl.parameters.type == 'OBJECT'
|
||||
# assert function_decl.parameters.properties['file_path'].type == 'STRING'
|
||||
Reference in New Issue
Block a user