diff --git a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py index 6855b18..c25f5ba 100644 --- a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py +++ b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py @@ -46,7 +46,7 @@ class ApplicationIntegrationToolset: project="test-project", location="us-central1" integration="test-integration", - trigger="api_trigger/test_trigger", + triggers=["api_trigger/test_trigger"], service_account_credentials={...}, ) @@ -80,7 +80,7 @@ class ApplicationIntegrationToolset: project: str, location: str, integration: Optional[str] = None, - trigger: Optional[str] = None, + triggers: Optional[List[str]] = None, connection: Optional[str] = None, entity_operations: Optional[str] = None, actions: Optional[str] = None, @@ -95,10 +95,11 @@ class ApplicationIntegrationToolset: ): """Args: + Args: project: The GCP project ID. location: The GCP location. integration: The integration name. - trigger: The trigger name. + triggers: The list of trigger names in the integration. connection: The connection name. entity_operations: The entity operations supported by the connection. actions: The actions supported by the connection. @@ -112,15 +113,17 @@ class ApplicationIntegrationToolset: expose. Raises: - ValueError: If neither integration and trigger nor connection and - (entity_operations or actions) is provided. + ValueError: If none of the following conditions are met: + - `integration` is provided. + - `connection` is provided and at least one of `entity_operations` + or `actions` is provided. Exception: If there is an error during the initialization of the integration or connection client. """ self.project = project self.location = location self.integration = integration - self.trigger = trigger + self.triggers = triggers self.connection = connection self.entity_operations = entity_operations self.actions = actions @@ -133,14 +136,14 @@ class ApplicationIntegrationToolset: project, location, integration, - trigger, + triggers, connection, entity_operations, actions, service_account_json, ) connection_details = {} - if integration and trigger: + if integration: spec = integration_client.get_openapi_spec_for_integration() elif connection and (entity_operations or actions): connections_client = ConnectionsClient( @@ -153,7 +156,7 @@ class ApplicationIntegrationToolset: ) else: raise ValueError( - "Either (integration and trigger) or (connection and" + "Invalid request, Either integration or (connection and" " (entity_operations or actions)) should be provided." ) self.openapi_toolset = None @@ -183,7 +186,7 @@ class ApplicationIntegrationToolset: ) auth_scheme = HTTPBearer(bearerFormat="JWT") - if self.integration and self.trigger: + if self.integration: self.openapi_toolset = OpenAPIToolset( spec_dict=spec_dict, auth_credential=auth_credential, diff --git a/src/google/adk/tools/application_integration_tool/clients/integration_client.py b/src/google/adk/tools/application_integration_tool/clients/integration_client.py index 8030ffa..d74dccf 100644 --- a/src/google/adk/tools/application_integration_tool/clients/integration_client.py +++ b/src/google/adk/tools/application_integration_tool/clients/integration_client.py @@ -13,7 +13,7 @@ # limitations under the License. import json -from typing import Optional +from typing import List, Optional from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient import google.auth from google.auth import default as default_service_credential @@ -35,7 +35,7 @@ class IntegrationClient: project: str, location: str, integration: Optional[str] = None, - trigger: Optional[str] = None, + triggers: Optional[List[str]] = None, connection: Optional[str] = None, entity_operations: Optional[dict[str, list[str]]] = None, actions: Optional[list[str]] = None, @@ -47,7 +47,7 @@ class IntegrationClient: project: The Google Cloud project ID. location: The Google Cloud location (e.g., us-central1). integration: The integration name. - trigger: The trigger ID for the integration. + triggers: The list of trigger IDs for the integration. connection: The connection name. entity_operations: A dictionary mapping entity names to a list of operations (e.g., LIST, CREATE, UPDATE, DELETE, GET). @@ -59,7 +59,7 @@ class IntegrationClient: self.project = project self.location = location self.integration = integration - self.trigger = trigger + self.triggers = triggers self.connection = connection self.entity_operations = ( entity_operations if entity_operations is not None else {} @@ -88,7 +88,7 @@ class IntegrationClient: "apiTriggerResources": [ { "integrationResource": self.integration, - "triggerId": [self.trigger], + "triggerId": self.triggers, }, ], "fileFormat": "JSON", @@ -109,7 +109,7 @@ class IntegrationClient: raise ValueError( "Invalid request. Please check the provided values of" f" project({self.project}), location({self.location})," - f" integration({self.integration}) and trigger({self.trigger})." + f" integration({self.integration})." ) from e raise ValueError(f"Request error: {e}") from e except Exception as e: diff --git a/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py b/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py index e58377e..e672925 100644 --- a/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py +++ b/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import re from unittest import mock from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient @@ -42,8 +43,8 @@ def integration_name(): @pytest.fixture -def trigger_name(): - return "test-trigger" +def triggers(): + return ["test-trigger", "test-trigger2"] @pytest.fixture @@ -76,13 +77,13 @@ def mock_connections_client(): class TestIntegrationClient: def test_initialization( - self, project, location, integration_name, trigger_name, connection_name + self, project, location, integration_name, triggers, connection_name ): client = IntegrationClient( project=project, location=location, integration=integration_name, - trigger=trigger_name, + triggers=triggers, connection=connection_name, entity_operations={"entity": ["LIST"]}, actions=["action1"], @@ -91,7 +92,7 @@ class TestIntegrationClient: assert client.project == project assert client.location == location assert client.integration == integration_name - assert client.trigger == trigger_name + assert client.triggers == triggers assert client.connection == connection_name assert client.entity_operations == {"entity": ["LIST"]} assert client.actions == ["action1"] @@ -105,7 +106,7 @@ class TestIntegrationClient: project, location, integration_name, - trigger_name, + triggers, mock_credentials, mock_connections_client, ): @@ -126,7 +127,7 @@ class TestIntegrationClient: project=project, location=location, integration=integration_name, - trigger=trigger_name, + triggers=triggers, connection=None, entity_operations=None, actions=None, @@ -143,7 +144,7 @@ class TestIntegrationClient: json={ "apiTriggerResources": [{ "integrationResource": integration_name, - "triggerId": [trigger_name], + "triggerId": triggers, }], "fileFormat": "JSON", }, @@ -154,7 +155,7 @@ class TestIntegrationClient: project, location, integration_name, - trigger_name, + triggers, mock_connections_client, ): with mock.patch.object( @@ -169,7 +170,7 @@ class TestIntegrationClient: project=project, location=location, integration=integration_name, - trigger=trigger_name, + triggers=triggers, connection=None, entity_operations=None, actions=None, @@ -193,7 +194,7 @@ class TestIntegrationClient: project, location, integration_name, - trigger_name, + triggers, mock_credentials, status_code, response_text, @@ -217,7 +218,7 @@ class TestIntegrationClient: project=project, location=location, integration=integration_name, - trigger=trigger_name, + triggers=triggers, connection=None, entity_operations=None, actions=None, @@ -226,10 +227,9 @@ class TestIntegrationClient: with pytest.raises( ValueError, match=( - "Invalid request. Please check the provided values of" - f" project\\({project}\\), location\\({location}\\)," - f" integration\\({integration_name}\\) and" - f" trigger\\({trigger_name}\\)." + r"Invalid request\. Please check the provided values of" + rf" project\({project}\), location\({location}\)," + rf" integration\({integration_name}\)." ), ): client.get_openapi_spec_for_integration() @@ -239,7 +239,7 @@ class TestIntegrationClient: project, location, integration_name, - trigger_name, + triggers, mock_credentials, mock_connections_client, ): @@ -261,7 +261,7 @@ class TestIntegrationClient: project=project, location=location, integration=integration_name, - trigger=trigger_name, + triggers=triggers, connection=None, entity_operations=None, actions=None, @@ -275,7 +275,7 @@ class TestIntegrationClient: project, location, integration_name, - trigger_name, + triggers, mock_credentials, mock_connections_client, ): @@ -293,7 +293,7 @@ class TestIntegrationClient: project=project, location=location, integration=integration_name, - trigger=trigger_name, + triggers=triggers, connection=None, entity_operations=None, actions=None, @@ -311,7 +311,7 @@ class TestIntegrationClient: project=project, location=location, integration=None, - trigger=None, + triggers=None, connection=connection_name, entity_operations=None, actions=None, @@ -356,7 +356,7 @@ class TestIntegrationClient: project=project, location=location, integration=None, - trigger=None, + triggers=None, connection=connection_name, entity_operations=entity_operations, actions=None, @@ -425,7 +425,7 @@ class TestIntegrationClient: project=project, location=location, integration=None, - trigger=None, + triggers=None, connection=connection_name, entity_operations=None, actions=actions, @@ -476,7 +476,7 @@ class TestIntegrationClient: project=project, location=location, integration=None, - trigger=None, + triggers=None, connection=connection_name, entity_operations=entity_operations, actions=None, @@ -488,7 +488,7 @@ class TestIntegrationClient: client.get_openapi_spec_for_connection() def test_get_access_token_with_service_account_json( - self, project, location, integration_name, trigger_name, connection_name + self, project, location, integration_name, triggers, connection_name ): service_account_json = json.dumps({ "client_email": "test@example.com", @@ -509,7 +509,7 @@ class TestIntegrationClient: project=project, location=location, integration=integration_name, - trigger=trigger_name, + triggers=triggers, connection=connection_name, entity_operations=None, actions=None, @@ -528,7 +528,7 @@ class TestIntegrationClient: project, location, integration_name, - trigger_name, + triggers, connection_name, mock_credentials, ): @@ -544,7 +544,7 @@ class TestIntegrationClient: project=project, location=location, integration=integration_name, - trigger=trigger_name, + triggers=triggers, connection=connection_name, entity_operations=None, actions=None, @@ -554,7 +554,7 @@ class TestIntegrationClient: assert token == "test_token" def test_get_access_token_no_valid_credentials( - self, project, location, integration_name, trigger_name, connection_name + self, project, location, integration_name, triggers, connection_name ): with ( mock.patch( @@ -570,7 +570,7 @@ class TestIntegrationClient: project=project, location=location, integration=integration_name, - trigger=trigger_name, + triggers=triggers, connection=connection_name, entity_operations=None, actions=None, @@ -591,7 +591,7 @@ class TestIntegrationClient: project, location, integration_name, - trigger_name, + triggers, connection_name, mock_credentials, ): @@ -601,7 +601,7 @@ class TestIntegrationClient: project=project, location=location, integration=integration_name, - trigger=trigger_name, + triggers=triggers, connection=connection_name, entity_operations=None, actions=None, @@ -624,7 +624,7 @@ class TestIntegrationClient: project, location, integration_name, - trigger_name, + triggers, connection_name, mock_credentials, ): @@ -642,7 +642,7 @@ class TestIntegrationClient: project=project, location=location, integration=integration_name, - trigger=trigger_name, + triggers=triggers, connection=connection_name, entity_operations=None, actions=None, diff --git a/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py b/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py index 28dbb9d..0a707e6 100644 --- a/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py +++ b/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py @@ -52,6 +52,24 @@ def mock_openapi_toolset(): yield mock_toolset +@pytest.fixture +def mock_openapi_toolset_with_multiple_tools_and_no_tools(): + with mock.patch( + "google.adk.tools.application_integration_tool.application_integration_toolset.OpenAPIToolset" + ) as mock_toolset: + mock_toolset_instance = mock.MagicMock() + mock_rest_api_tool = mock.MagicMock(spec=rest_api_tool.RestApiTool) + mock_rest_api_tool.name = "Test Tool" + mock_rest_api_tool_2 = mock.MagicMock(spec=rest_api_tool.RestApiTool) + mock_rest_api_tool_2.name = "Test Tool 2" + mock_toolset_instance.get_tools.return_value = [ + mock_rest_api_tool, + mock_rest_api_tool_2, + ] + mock_toolset.return_value = mock_toolset_instance + yield mock_toolset + + def get_mocked_parsed_operation(operation_id, attributes): mock_openapi_spec_parser_instance = mock.MagicMock() mock_parsed_operation = mock.MagicMock(spec=ParsedOperation) @@ -142,12 +160,12 @@ def test_initialization_with_integration_and_trigger( mock_openapi_toolset, ): integration_name = "test-integration" - trigger_name = "test-trigger" + triggers = ["test-trigger"] toolset = ApplicationIntegrationToolset( - project, location, integration=integration_name, trigger=trigger_name + project, location, integration=integration_name, triggers=triggers ) mock_integration_client.assert_called_once_with( - project, location, integration_name, trigger_name, None, None, None, None + project, location, integration_name, triggers, None, None, None, None ) mock_integration_client.return_value.get_openapi_spec_for_integration.assert_called_once() mock_connections_client.assert_not_called() @@ -156,6 +174,58 @@ def test_initialization_with_integration_and_trigger( assert toolset.get_tools()[0].name == "Test Tool" +def test_initialization_with_integration_and_list_of_triggers( + project, + location, + mock_integration_client, + mock_connections_client, + mock_openapi_toolset_with_multiple_tools_and_no_tools, +): + integration_name = "test-integration" + triggers = ["test-trigger1", "test-trigger2"] + toolset = ApplicationIntegrationToolset( + project, location, integration=integration_name, triggers=triggers + ) + mock_integration_client.assert_called_once_with( + project, + location, + integration_name, + triggers, + None, + None, + None, + None, + ) + mock_integration_client.return_value.get_openapi_spec_for_integration.assert_called_once() + mock_connections_client.assert_not_called() + mock_openapi_toolset_with_multiple_tools_and_no_tools.assert_called_once() + assert len(toolset.get_tools()) == 2 + assert toolset.get_tools()[0].name == "Test Tool" + assert toolset.get_tools()[1].name == "Test Tool 2" + + +def test_initialization_with_integration_and_empty_trigger_list( + project, + location, + mock_integration_client, + mock_connections_client, + mock_openapi_toolset_with_multiple_tools_and_no_tools, +): + integration_name = "test-integration" + toolset = ApplicationIntegrationToolset( + project, location, integration=integration_name + ) + mock_integration_client.assert_called_once_with( + project, location, integration_name, None, None, None, None, None + ) + mock_integration_client.return_value.get_openapi_spec_for_integration.assert_called_once() + mock_connections_client.assert_not_called() + mock_openapi_toolset_with_multiple_tools_and_no_tools.assert_called_once() + assert len(toolset.get_tools()) == 2 + assert toolset.get_tools()[0].name == "Test Tool" + assert toolset.get_tools()[1].name == "Test Tool 2" + + def test_initialization_with_connection_and_entity_operations( project, location, @@ -250,7 +320,7 @@ def test_initialization_without_required_params(project, location): with pytest.raises( ValueError, match=( - "Either \\(integration and trigger\\) or \\(connection and" + "Invalid request, Either integration or \\(connection and" " \\(entity_operations or actions\\)\\) should be provided." ), ): @@ -259,25 +329,16 @@ def test_initialization_without_required_params(project, location): with pytest.raises( ValueError, match=( - "Either \\(integration and trigger\\) or \\(connection and" + "Invalid request, Either integration or \\(connection and" " \\(entity_operations or actions\\)\\) should be provided." ), ): - ApplicationIntegrationToolset(project, location, integration="test") + ApplicationIntegrationToolset(project, location, triggers=["test"]) with pytest.raises( ValueError, match=( - "Either \\(integration and trigger\\) or \\(connection and" - " \\(entity_operations or actions\\)\\) should be provided." - ), - ): - ApplicationIntegrationToolset(project, location, trigger="test") - - with pytest.raises( - ValueError, - match=( - "Either \\(integration and trigger\\) or \\(connection and" + "Invalid request, Either integration or \\(connection and" " \\(entity_operations or actions\\)\\) should be provided." ), ): @@ -305,19 +366,19 @@ def test_initialization_with_service_account_credentials( "universe_domain": "googleapis.com", }) integration_name = "test-integration" - trigger_name = "test-trigger" + triggers = ["test-trigger"] toolset = ApplicationIntegrationToolset( project, location, integration=integration_name, - trigger=trigger_name, + triggers=triggers, service_account_json=service_account_json, ) mock_integration_client.assert_called_once_with( project, location, integration_name, - trigger_name, + triggers, None, None, None, @@ -338,12 +399,12 @@ def test_initialization_without_explicit_service_account_credentials( project, location, mock_integration_client, mock_openapi_toolset ): integration_name = "test-integration" - trigger_name = "test-trigger" + triggers = "test-trigger" toolset = ApplicationIntegrationToolset( - project, location, integration=integration_name, trigger=trigger_name + project, location, integration=integration_name, triggers=triggers ) mock_integration_client.assert_called_once_with( - project, location, integration_name, trigger_name, None, None, None, None + project, location, integration_name, triggers, None, None, None, None ) mock_openapi_toolset.assert_called_once() _, kwargs = mock_openapi_toolset.call_args @@ -355,9 +416,9 @@ def test_get_tools( project, location, mock_integration_client, mock_openapi_toolset ): integration_name = "test-integration" - trigger_name = "test-trigger" + triggers = ["test-trigger"] toolset = ApplicationIntegrationToolset( - project, location, integration=integration_name, trigger=trigger_name + project, location, integration=integration_name, triggers=triggers ) tools = toolset.get_tools() assert len(tools) == 1