Moves unittests to root folder and adds github action to run unit tests. (#72)

* Move unit tests to root package.

* Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py

* Adds github workflow

* minor fix in lite_llm.py for python 3.9.

* format pyproject.toml
This commit is contained in:
Jack Sun
2025-04-11 08:25:59 -07:00
committed by GitHub
parent 59117b9b96
commit 05142a07cc
66 changed files with 50 additions and 2 deletions

View File

@@ -0,0 +1,600 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from unittest import mock
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
import google.auth
import pytest
import requests
from requests import exceptions
@pytest.fixture
def project():
return "test-project"
@pytest.fixture
def location():
return "us-central1"
@pytest.fixture
def connection_name():
return "test-connection"
@pytest.fixture
def mock_credentials():
creds = mock.create_autospec(google.auth.credentials.Credentials)
creds.token = "test_token"
creds.expired = False
return creds
@pytest.fixture
def mock_auth_request():
return mock.create_autospec(google.auth.transport.requests.Request)
class TestConnectionsClient:
def test_initialization(self, project, location, connection_name):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(
project, location, connection_name, json.dumps(credentials)
)
assert client.project == project
assert client.location == location
assert client.connection == connection_name
assert client.connector_url == "https://connectors.googleapis.com"
assert client.service_account_json == json.dumps(credentials)
assert client.credential_cache is None
def test_execute_api_call_success(
self, project, location, connection_name, mock_credentials
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
mock_response = mock.MagicMock()
mock_response.status_code = 200
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {"data": "test"}
with mock.patch.object(
client, "_get_access_token", return_value=mock_credentials.token
), mock.patch("requests.get", return_value=mock_response):
response = client._execute_api_call("https://test.url")
assert response.json() == {"data": "test"}
requests.get.assert_called_once_with(
"https://test.url",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {mock_credentials.token}",
},
)
def test_execute_api_call_credential_error(
self, project, location, connection_name
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
with mock.patch.object(
client,
"_get_access_token",
side_effect=google.auth.exceptions.DefaultCredentialsError("Test"),
):
with pytest.raises(PermissionError, match="Credentials error: Test"):
client._execute_api_call("https://test.url")
@pytest.mark.parametrize(
"status_code, response_text",
[(404, "Not Found"), (400, "Bad Request")],
)
def test_execute_api_call_request_error_not_found_or_bad_request(
self,
project,
location,
connection_name,
mock_credentials,
status_code,
response_text,
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
mock_response = mock.MagicMock()
mock_response.status_code = status_code
mock_response.raise_for_status.side_effect = exceptions.HTTPError(
f"HTTP error {status_code}: {response_text}"
)
with mock.patch.object(
client, "_get_access_token", return_value=mock_credentials.token
), mock.patch("requests.get", return_value=mock_response):
with pytest.raises(
ValueError, match="Invalid request. Please check the provided"
):
client._execute_api_call("https://test.url")
def test_execute_api_call_other_request_error(
self, project, location, connection_name, mock_credentials
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
mock_response = mock.MagicMock()
mock_response.status_code = 500
mock_response.raise_for_status.side_effect = exceptions.HTTPError(
"Internal Server Error"
)
with mock.patch.object(
client, "_get_access_token", return_value=mock_credentials.token
), mock.patch("requests.get", return_value=mock_response):
with pytest.raises(ValueError, match="Request error: "):
client._execute_api_call("https://test.url")
def test_execute_api_call_unexpected_error(
self, project, location, connection_name, mock_credentials
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
with mock.patch.object(
client, "_get_access_token", return_value=mock_credentials.token
), mock.patch(
"requests.get", side_effect=Exception("Something went wrong")
):
with pytest.raises(
Exception, match="An unexpected error occurred: Something went wrong"
):
client._execute_api_call("https://test.url")
def test_get_connection_details_success_with_host(
self, project, location, connection_name, mock_credentials
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
mock_response = mock.MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"serviceDirectory": "test_service",
"host": "test.host",
"tlsServiceDirectory": "tls_test_service",
"authOverrideEnabled": True,
}
with mock.patch.object(
client, "_execute_api_call", return_value=mock_response
):
details = client.get_connection_details()
assert details == {
"serviceName": "tls_test_service",
"host": "test.host",
"authOverrideEnabled": True,
}
def test_get_connection_details_success_without_host(
self, project, location, connection_name, mock_credentials
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
mock_response = mock.MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"serviceDirectory": "test_service",
"authOverrideEnabled": False,
}
with mock.patch.object(
client, "_execute_api_call", return_value=mock_response
):
details = client.get_connection_details()
assert details == {
"serviceName": "test_service",
"host": "",
"authOverrideEnabled": False,
}
def test_get_connection_details_error(
self, project, location, connection_name
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
with mock.patch.object(
client, "_execute_api_call", side_effect=ValueError("Request error")
):
with pytest.raises(ValueError, match="Request error"):
client.get_connection_details()
def test_get_entity_schema_and_operations_success(
self, project, location, connection_name, mock_credentials
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
mock_execute_response_initial = mock.MagicMock()
mock_execute_response_initial.status_code = 200
mock_execute_response_initial.json.return_value = {
"name": "operations/test_op"
}
mock_execute_response_poll_done = mock.MagicMock()
mock_execute_response_poll_done.status_code = 200
mock_execute_response_poll_done.json.return_value = {
"done": True,
"response": {
"jsonSchema": {"type": "object"},
"operations": ["LIST", "GET"],
},
}
with mock.patch.object(
client,
"_execute_api_call",
side_effect=[
mock_execute_response_initial,
mock_execute_response_poll_done,
],
):
schema, operations = client.get_entity_schema_and_operations("entity1")
assert schema == {"type": "object"}
assert operations == ["LIST", "GET"]
assert (
mock.call(
f"https://connectors.googleapis.com/v1/projects/{project}/locations/{location}/connections/{connection_name}/connectionSchemaMetadata:getEntityType?entityId=entity1"
)
in client._execute_api_call.mock_calls
)
assert (
mock.call(f"https://connectors.googleapis.com/v1/operations/test_op")
in client._execute_api_call.mock_calls
)
def test_get_entity_schema_and_operations_no_operation_id(
self, project, location, connection_name, mock_credentials
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
mock_execute_response = mock.MagicMock()
mock_execute_response.status_code = 200
mock_execute_response.json.return_value = {}
with mock.patch.object(
client, "_execute_api_call", return_value=mock_execute_response
):
with pytest.raises(
ValueError,
match=(
"Failed to get entity schema and operations for entity: entity1"
),
):
client.get_entity_schema_and_operations("entity1")
def test_get_entity_schema_and_operations_execute_api_call_error(
self, project, location, connection_name
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
with mock.patch.object(
client, "_execute_api_call", side_effect=ValueError("Request error")
):
with pytest.raises(ValueError, match="Request error"):
client.get_entity_schema_and_operations("entity1")
def test_get_action_schema_success(
self, project, location, connection_name, mock_credentials
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
mock_execute_response_initial = mock.MagicMock()
mock_execute_response_initial.status_code = 200
mock_execute_response_initial.json.return_value = {
"name": "operations/test_op"
}
mock_execute_response_poll_done = mock.MagicMock()
mock_execute_response_poll_done.status_code = 200
mock_execute_response_poll_done.json.return_value = {
"done": True,
"response": {
"inputJsonSchema": {
"type": "object",
"properties": {"input": {"type": "string"}},
},
"outputJsonSchema": {
"type": "object",
"properties": {"output": {"type": "string"}},
},
"description": "Test Action Description",
"displayName": "TestAction",
},
}
with mock.patch.object(
client,
"_execute_api_call",
side_effect=[
mock_execute_response_initial,
mock_execute_response_poll_done,
],
):
schema = client.get_action_schema("action1")
assert schema == {
"inputSchema": {
"type": "object",
"properties": {"input": {"type": "string"}},
},
"outputSchema": {
"type": "object",
"properties": {"output": {"type": "string"}},
},
"description": "Test Action Description",
"displayName": "TestAction",
}
assert (
mock.call(
f"https://connectors.googleapis.com/v1/projects/{project}/locations/{location}/connections/{connection_name}/connectionSchemaMetadata:getAction?actionId=action1"
)
in client._execute_api_call.mock_calls
)
assert (
mock.call(f"https://connectors.googleapis.com/v1/operations/test_op")
in client._execute_api_call.mock_calls
)
def test_get_action_schema_no_operation_id(
self, project, location, connection_name, mock_credentials
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
mock_execute_response = mock.MagicMock()
mock_execute_response.status_code = 200
mock_execute_response.json.return_value = {}
with mock.patch.object(
client, "_execute_api_call", return_value=mock_execute_response
):
with pytest.raises(
ValueError, match="Failed to get action schema for action: action1"
):
client.get_action_schema("action1")
def test_get_action_schema_execute_api_call_error(
self, project, location, connection_name
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
with mock.patch.object(
client, "_execute_api_call", side_effect=ValueError("Request error")
):
with pytest.raises(ValueError, match="Request error"):
client.get_action_schema("action1")
def test_get_connector_base_spec(self):
spec = ConnectionsClient.get_connector_base_spec()
assert "openapi" in spec
assert spec["info"]["title"] == "ExecuteConnection"
assert "components" in spec
assert "schemas" in spec["components"]
assert "operation" in spec["components"]["schemas"]
def test_get_action_operation(self):
operation = ConnectionsClient.get_action_operation(
"TestAction", "EXECUTE_ACTION", "TestActionDisplayName", "test_tool"
)
assert "post" in operation
assert operation["post"]["summary"] == "TestActionDisplayName"
assert "operationId" in operation["post"]
assert operation["post"]["operationId"] == "test_tool_TestActionDisplayName"
def test_list_operation(self):
operation = ConnectionsClient.list_operation(
"Entity1", '{"type": "object"}', "test_tool"
)
assert "post" in operation
assert operation["post"]["summary"] == "List Entity1"
assert "operationId" in operation["post"]
assert operation["post"]["operationId"] == "test_tool_list_Entity1"
def test_get_operation_static(self):
operation = ConnectionsClient.get_operation(
"Entity1", '{"type": "object"}', "test_tool"
)
assert "post" in operation
assert operation["post"]["summary"] == "Get Entity1"
assert "operationId" in operation["post"]
assert operation["post"]["operationId"] == "test_tool_get_Entity1"
def test_create_operation(self):
operation = ConnectionsClient.create_operation("Entity1", "test_tool")
assert "post" in operation
assert operation["post"]["summary"] == "Create Entity1"
assert "operationId" in operation["post"]
assert operation["post"]["operationId"] == "test_tool_create_Entity1"
def test_update_operation(self):
operation = ConnectionsClient.update_operation("Entity1", "test_tool")
assert "post" in operation
assert operation["post"]["summary"] == "Update Entity1"
assert "operationId" in operation["post"]
assert operation["post"]["operationId"] == "test_tool_update_Entity1"
def test_delete_operation(self):
operation = ConnectionsClient.delete_operation("Entity1", "test_tool")
assert "post" in operation
assert operation["post"]["summary"] == "Delete Entity1"
assert operation["post"]["operationId"] == "test_tool_delete_Entity1"
def test_create_operation_request(self):
schema = ConnectionsClient.create_operation_request("Entity1")
assert "type" in schema
assert schema["type"] == "object"
assert "properties" in schema
assert "connectorInputPayload" in schema["properties"]
def test_update_operation_request(self):
schema = ConnectionsClient.update_operation_request("Entity1")
assert "type" in schema
assert schema["type"] == "object"
assert "properties" in schema
assert "entityId" in schema["properties"]
def test_get_operation_request_static(self):
schema = ConnectionsClient.get_operation_request()
assert "type" in schema
assert schema["type"] == "object"
assert "properties" in schema
assert "entityId" in schema["properties"]
def test_delete_operation_request(self):
schema = ConnectionsClient.delete_operation_request()
assert "type" in schema
assert schema["type"] == "object"
assert "properties" in schema
assert "entityId" in schema["properties"]
def test_list_operation_request(self):
schema = ConnectionsClient.list_operation_request()
assert "type" in schema
assert schema["type"] == "object"
assert "properties" in schema
assert "filterClause" in schema["properties"]
def test_action_request(self):
schema = ConnectionsClient.action_request("TestAction")
assert "type" in schema
assert schema["type"] == "object"
assert "properties" in schema
assert "connectorInputPayload" in schema["properties"]
def test_action_response(self):
schema = ConnectionsClient.action_response("TestAction")
assert "type" in schema
assert schema["type"] == "object"
assert "properties" in schema
assert "connectorOutputPayload" in schema["properties"]
def test_execute_custom_query_request(self):
schema = ConnectionsClient.execute_custom_query_request()
assert "type" in schema
assert schema["type"] == "object"
assert "properties" in schema
assert "query" in schema["properties"]
def test_connector_payload(self):
client = ConnectionsClient("test-project", "us-central1", "test-connection")
schema = client.connector_payload(
json_schema={
"type": "object",
"properties": {
"input": {
"type": ["null", "string"],
"description": "description",
}
},
}
)
assert schema == {
"type": "object",
"properties": {
"input": {
"type": "string",
"nullable": True,
"description": "description",
}
},
}
def test_get_access_token_uses_cached_token(
self, project, location, connection_name, mock_credentials
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
client.credential_cache = mock_credentials
token = client._get_access_token()
assert token == "test_token"
def test_get_access_token_with_service_account_credentials(
self, project, location, connection_name
):
service_account_json = json.dumps({
"client_email": "test@example.com",
"private_key": "test_key",
})
client = ConnectionsClient(
project, location, connection_name, service_account_json
)
mock_creds = mock.create_autospec(google.oauth2.service_account.Credentials)
mock_creds.token = "sa_token"
mock_creds.expired = False
with mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info",
return_value=mock_creds,
), mock.patch.object(mock_creds, "refresh", return_value=None):
token = client._get_access_token()
assert token == "sa_token"
google.oauth2.service_account.Credentials.from_service_account_info.assert_called_once_with(
json.loads(service_account_json),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
mock_creds.refresh.assert_called_once()
def test_get_access_token_with_default_credentials(
self, project, location, connection_name, mock_credentials
):
client = ConnectionsClient(project, location, connection_name, None)
with mock.patch(
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
return_value=(mock_credentials, "test_project_id"),
), mock.patch.object(mock_credentials, "refresh", return_value=None):
token = client._get_access_token()
assert token == "test_token"
def test_get_access_token_no_valid_credentials(
self, project, location, connection_name
):
client = ConnectionsClient(project, location, connection_name, None)
with mock.patch(
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
return_value=(None, None),
):
with pytest.raises(
ValueError,
match=(
"Please provide a service account that has the required"
" permissions"
),
):
client._get_access_token()
def test_get_access_token_refreshes_expired_token(
self, project, location, connection_name, mock_credentials
):
client = ConnectionsClient(project, location, connection_name, None)
mock_credentials.expired = True
mock_credentials.token = "old_token"
mock_credentials.refresh.return_value = None
client.credential_cache = mock_credentials
with mock.patch(
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
return_value=(mock_credentials, "test_project_id"),
):
# Mock the refresh method directly on the instance within the context
with mock.patch.object(mock_credentials, "refresh") as mock_refresh:
mock_credentials.token = "new_token" # Set the expected new token
token = client._get_access_token()
assert token == "new_token"
mock_refresh.assert_called_once()

View File

@@ -0,0 +1,630 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from unittest import mock
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
import google.auth
import google.auth.transport.requests
from google.auth.transport.requests import Request
from google.oauth2 import service_account
import pytest
import requests
from requests import exceptions
@pytest.fixture
def project():
return "test-project"
@pytest.fixture
def location():
return "us-central1"
@pytest.fixture
def integration_name():
return "test-integration"
@pytest.fixture
def trigger_name():
return "test-trigger"
@pytest.fixture
def connection_name():
return "test-connection"
@pytest.fixture
def mock_credentials():
creds = mock.create_autospec(google.auth.credentials.Credentials)
creds.token = "test_token"
return creds
@pytest.fixture
def mock_auth_request():
return mock.create_autospec(Request)
@pytest.fixture
def mock_connections_client():
with mock.patch(
"google.adk.tools.application_integration_tool.clients.integration_client.ConnectionsClient"
) as mock_client:
mock_instance = mock.create_autospec(ConnectionsClient)
mock_client.return_value = mock_instance
yield mock_client
class TestIntegrationClient:
def test_initialization(
self, project, location, integration_name, trigger_name, connection_name
):
client = IntegrationClient(
project=project,
location=location,
integration=integration_name,
trigger=trigger_name,
connection=connection_name,
entity_operations={"entity": ["LIST"]},
actions=["action1"],
service_account_json=json.dumps({"email": "test@example.com"}),
)
assert client.project == project
assert client.location == location
assert client.integration == integration_name
assert client.trigger == trigger_name
assert client.connection == connection_name
assert client.entity_operations == {"entity": ["LIST"]}
assert client.actions == ["action1"]
assert client.service_account_json == json.dumps(
{"email": "test@example.com"}
)
assert client.credential_cache is None
def test_get_openapi_spec_for_integration_success(
self,
project,
location,
integration_name,
trigger_name,
mock_credentials,
mock_connections_client,
):
expected_spec = {"openapi": "3.0.0", "info": {"title": "Test Integration"}}
mock_response = mock.MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"openApiSpec": json.dumps(expected_spec)}
with mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
), mock.patch("requests.post", return_value=mock_response):
client = IntegrationClient(
project=project,
location=location,
integration=integration_name,
trigger=trigger_name,
connection=None,
entity_operations=None,
actions=None,
service_account_json=None,
)
spec = client.get_openapi_spec_for_integration()
assert spec == expected_spec
requests.post.assert_called_once_with(
f"https://{location}-integrations.googleapis.com/v1/projects/{project}/locations/{location}:generateOpenApiSpec",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {mock_credentials.token}",
},
json={
"apiTriggerResources": [{
"integrationResource": integration_name,
"triggerId": [trigger_name],
}],
"fileFormat": "JSON",
},
)
def test_get_openapi_spec_for_integration_credential_error(
self,
project,
location,
integration_name,
trigger_name,
mock_connections_client,
):
with mock.patch.object(
IntegrationClient,
"_get_access_token",
side_effect=ValueError(
"Please provide a service account that has the required permissions"
" to access the connection."
),
):
client = IntegrationClient(
project=project,
location=location,
integration=integration_name,
trigger=trigger_name,
connection=None,
entity_operations=None,
actions=None,
service_account_json=None,
)
with pytest.raises(
Exception,
match=(
"An unexpected error occurred: Please provide a service account"
" that has the required permissions to access the connection."
),
):
client.get_openapi_spec_for_integration()
@pytest.mark.parametrize(
"status_code, response_text",
[(404, "Not Found"), (400, "Bad Request"), (404, ""), (400, "")],
)
def test_get_openapi_spec_for_integration_request_error_not_found_or_bad_request(
self,
project,
location,
integration_name,
trigger_name,
mock_credentials,
status_code,
response_text,
mock_connections_client,
):
mock_response = mock.MagicMock()
mock_response.status_code = status_code
mock_response.raise_for_status.side_effect = exceptions.HTTPError(
f"HTTP error {status_code}: {response_text}"
)
with mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
), mock.patch("requests.post", return_value=mock_response):
client = IntegrationClient(
project=project,
location=location,
integration=integration_name,
trigger=trigger_name,
connection=None,
entity_operations=None,
actions=None,
service_account_json=None,
)
with pytest.raises(
ValueError,
match=(
"Invalid request. Please check the provided values of"
f" project\\({project}\\), location\\({location}\\),"
f" integration\\({integration_name}\\) and"
f" trigger\\({trigger_name}\\)."
),
):
client.get_openapi_spec_for_integration()
def test_get_openapi_spec_for_integration_other_request_error(
self,
project,
location,
integration_name,
trigger_name,
mock_credentials,
mock_connections_client,
):
mock_response = mock.MagicMock()
mock_response.status_code = 500
mock_response.raise_for_status.side_effect = exceptions.HTTPError(
"Internal Server Error"
)
with mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
), mock.patch("requests.post", return_value=mock_response):
client = IntegrationClient(
project=project,
location=location,
integration=integration_name,
trigger=trigger_name,
connection=None,
entity_operations=None,
actions=None,
service_account_json=None,
)
with pytest.raises(ValueError, match="Request error: "):
client.get_openapi_spec_for_integration()
def test_get_openapi_spec_for_integration_unexpected_error(
self,
project,
location,
integration_name,
trigger_name,
mock_credentials,
mock_connections_client,
):
with mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
), mock.patch(
"requests.post", side_effect=Exception("Something went wrong")
):
client = IntegrationClient(
project=project,
location=location,
integration=integration_name,
trigger=trigger_name,
connection=None,
entity_operations=None,
actions=None,
service_account_json=None,
)
with pytest.raises(
Exception, match="An unexpected error occurred: Something went wrong"
):
client.get_openapi_spec_for_integration()
def test_get_openapi_spec_for_connection_no_entity_operations_or_actions(
self, project, location, connection_name, mock_connections_client
):
client = IntegrationClient(
project=project,
location=location,
integration=None,
trigger=None,
connection=connection_name,
entity_operations=None,
actions=None,
service_account_json=None,
)
with pytest.raises(
ValueError,
match=(
"No entity operations or actions provided. Please provide at least"
" one of them."
),
):
client.get_openapi_spec_for_connection()
def test_get_openapi_spec_for_connection_with_entity_operations(
self, project, location, connection_name, mock_connections_client
):
entity_operations = {"entity1": ["LIST", "GET"]}
mock_connections_client_instance = mock_connections_client.return_value
mock_connections_client_instance.get_connector_base_spec.return_value = {
"components": {"schemas": {}},
"paths": {},
}
mock_connections_client_instance.get_entity_schema_and_operations.return_value = (
{"type": "object", "properties": {"id": {"type": "string"}}},
["LIST", "GET"],
)
mock_connections_client_instance.connector_payload.return_value = {
"type": "object"
}
mock_connections_client_instance.list_operation.return_value = {"get": {}}
mock_connections_client_instance.list_operation_request.return_value = {
"type": "object"
}
mock_connections_client_instance.get_operation.return_value = {"get": {}}
mock_connections_client_instance.get_operation_request.return_value = {
"type": "object"
}
client = IntegrationClient(
project=project,
location=location,
integration=None,
trigger=None,
connection=connection_name,
entity_operations=entity_operations,
actions=None,
service_account_json=None,
)
spec = client.get_openapi_spec_for_connection()
assert "paths" in spec
assert (
f"/v2/projects/{project}/locations/{location}/integrations/ExecuteConnection:execute?triggerId=api_trigger/ExecuteConnection#list_entity1"
in spec["paths"]
)
assert (
f"/v2/projects/{project}/locations/{location}/integrations/ExecuteConnection:execute?triggerId=api_trigger/ExecuteConnection#get_entity1"
in spec["paths"]
)
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
)
mock_connections_client_instance.get_connector_base_spec.assert_called_once()
mock_connections_client_instance.get_entity_schema_and_operations.assert_any_call(
"entity1"
)
mock_connections_client_instance.connector_payload.assert_any_call(
{"type": "object", "properties": {"id": {"type": "string"}}}
)
mock_connections_client_instance.list_operation.assert_called_once()
mock_connections_client_instance.get_operation.assert_called_once()
def test_get_openapi_spec_for_connection_with_actions(
self, project, location, connection_name, mock_connections_client
):
actions = ["TestAction"]
mock_connections_client_instance = (
mock_connections_client.return_value
) # Corrected line
mock_connections_client_instance.get_connector_base_spec.return_value = {
"components": {"schemas": {}},
"paths": {},
}
mock_connections_client_instance.get_action_schema.return_value = {
"inputSchema": {
"type": "object",
"properties": {"input": {"type": "string"}},
},
"outputSchema": {
"type": "object",
"properties": {"output": {"type": "string"}},
},
"displayName": "TestAction",
}
mock_connections_client_instance.connector_payload.side_effect = [
{"type": "object"},
{"type": "object"},
]
mock_connections_client_instance.action_request.return_value = {
"type": "object"
}
mock_connections_client_instance.action_response.return_value = {
"type": "object"
}
mock_connections_client_instance.get_action_operation.return_value = {
"post": {}
}
client = IntegrationClient(
project=project,
location=location,
integration=None,
trigger=None,
connection=connection_name,
entity_operations=None,
actions=actions,
service_account_json=None,
)
spec = client.get_openapi_spec_for_connection()
assert "paths" in spec
assert (
f"/v2/projects/{project}/locations/{location}/integrations/ExecuteConnection:execute?triggerId=api_trigger/ExecuteConnection#TestAction"
in spec["paths"]
)
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
)
mock_connections_client_instance.get_connector_base_spec.assert_called_once()
mock_connections_client_instance.get_action_schema.assert_called_once_with(
"TestAction"
)
mock_connections_client_instance.connector_payload.assert_any_call(
{"type": "object", "properties": {"input": {"type": "string"}}}
)
mock_connections_client_instance.connector_payload.assert_any_call(
{"type": "object", "properties": {"output": {"type": "string"}}}
)
mock_connections_client_instance.action_request.assert_called_once_with(
"TestAction"
)
mock_connections_client_instance.action_response.assert_called_once_with(
"TestAction"
)
mock_connections_client_instance.get_action_operation.assert_called_once()
def test_get_openapi_spec_for_connection_invalid_operation(
self, project, location, connection_name, mock_connections_client
):
entity_operations = {"entity1": ["INVALID"]}
mock_connections_client_instance = mock_connections_client.return_value
mock_connections_client_instance.get_connector_base_spec.return_value = {
"components": {"schemas": {}},
"paths": {},
}
mock_connections_client_instance.get_entity_schema_and_operations.return_value = (
{"type": "object", "properties": {"id": {"type": "string"}}},
["LIST", "GET"],
)
client = IntegrationClient(
project=project,
location=location,
integration=None,
trigger=None,
connection=connection_name,
entity_operations=entity_operations,
actions=None,
service_account_json=None,
)
with pytest.raises(
ValueError, match="Invalid operation: INVALID for entity: entity1"
):
client.get_openapi_spec_for_connection()
def test_get_access_token_with_service_account_json(
self, project, location, integration_name, trigger_name, connection_name
):
service_account_json = json.dumps({
"client_email": "test@example.com",
"private_key": "test_key",
})
mock_creds = mock.create_autospec(service_account.Credentials)
mock_creds.token = "sa_token"
mock_creds.expired = False
with mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info",
return_value=mock_creds,
), mock.patch.object(mock_creds, "refresh", return_value=None):
client = IntegrationClient(
project=project,
location=location,
integration=integration_name,
trigger=trigger_name,
connection=connection_name,
entity_operations=None,
actions=None,
service_account_json=service_account_json,
)
token = client._get_access_token()
assert token == "sa_token"
service_account.Credentials.from_service_account_info.assert_called_once_with(
json.loads(service_account_json),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
mock_creds.refresh.assert_called_once()
def test_get_access_token_with_default_credentials(
self,
project,
location,
integration_name,
trigger_name,
connection_name,
mock_credentials,
):
mock_credentials.expired = False
with mock.patch(
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
return_value=(mock_credentials, "test_project_id"),
), mock.patch.object(mock_credentials, "refresh", return_value=None):
client = IntegrationClient(
project=project,
location=location,
integration=integration_name,
trigger=trigger_name,
connection=connection_name,
entity_operations=None,
actions=None,
service_account_json=None,
)
token = client._get_access_token()
assert token == "test_token"
def test_get_access_token_no_valid_credentials(
self, project, location, integration_name, trigger_name, connection_name
):
with mock.patch(
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
return_value=(None, None),
), mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info",
return_value=None,
):
client = IntegrationClient(
project=project,
location=location,
integration=integration_name,
trigger=trigger_name,
connection=connection_name,
entity_operations=None,
actions=None,
service_account_json=None,
)
try:
client._get_access_token()
assert False, "ValueError was not raised" # Explicitly fail if no error
except ValueError as e:
assert (
"Please provide a service account that has the required permissions"
" to access the connection."
in str(e)
)
def test_get_access_token_uses_cached_token(
self,
project,
location,
integration_name,
trigger_name,
connection_name,
mock_credentials,
):
mock_credentials.token = "cached_token"
mock_credentials.expired = False
client = IntegrationClient(
project=project,
location=location,
integration=integration_name,
trigger=trigger_name,
connection=connection_name,
entity_operations=None,
actions=None,
service_account_json=None,
)
client.credential_cache = mock_credentials # Simulate a cached credential
with mock.patch("google.auth.default") as mock_default, mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info"
) as mock_sa:
token = client._get_access_token()
assert token == "cached_token"
mock_default.assert_not_called()
mock_sa.assert_not_called()
def test_get_access_token_refreshes_expired_token(
self,
project,
location,
integration_name,
trigger_name,
connection_name,
mock_credentials,
):
mock_credentials = mock.create_autospec(google.auth.credentials.Credentials)
mock_credentials.token = "old_token"
mock_credentials.expired = True
mock_credentials.refresh.return_value = None
mock_credentials.token = "new_token" # Simulate token refresh
with mock.patch(
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
return_value=(mock_credentials, "test_project_id"),
):
client = IntegrationClient(
project=project,
location=location,
integration=integration_name,
trigger=trigger_name,
connection=connection_name,
entity_operations=None,
actions=None,
service_account_json=None,
)
client.credential_cache = mock_credentials
token = client._get_access_token()
assert token == "new_token"
mock_credentials.refresh.assert_called_once()