diff --git a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py index 6f71ba2..d904de4 100644 --- a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py +++ b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py @@ -76,7 +76,7 @@ class ApplicationIntegrationToolset: project: str, location: str, integration: Optional[str] = None, - triggers: Optional[List[str]] = None, + trigger: Optional[str] = None, connection: Optional[str] = None, entity_operations: Optional[str] = None, actions: Optional[str] = None, @@ -98,7 +98,7 @@ class ApplicationIntegrationToolset: project="test-project", location="us-central1" integration="test-integration", - triggers=["api_trigger/test_trigger"], + trigger="api_trigger/test_trigger", service_account_credentials={...}, ) @@ -130,7 +130,7 @@ class ApplicationIntegrationToolset: project: The GCP project ID. location: The GCP location. integration: The integration name. - triggers: The list of trigger names in the integration. + trigger: The trigger name. connection: The connection name. entity_operations: The entity operations supported by the connection. actions: The actions supported by the connection. @@ -149,7 +149,7 @@ class ApplicationIntegrationToolset: self.project = project self.location = location self.integration = integration - self.triggers = triggers + self.trigger = trigger self.connection = connection self.entity_operations = entity_operations self.actions = actions @@ -162,14 +162,14 @@ class ApplicationIntegrationToolset: project, location, integration, - triggers, + trigger, connection, entity_operations, actions, service_account_json, ) connection_details = {} - if integration: + if integration and trigger: spec = integration_client.get_openapi_spec_for_integration() elif connection and (entity_operations or actions): connections_client = ConnectionsClient( @@ -210,7 +210,7 @@ class ApplicationIntegrationToolset: ) auth_scheme = HTTPBearer(bearerFormat="JWT") - if self.integration: + if self.integration and self.trigger: tools = OpenAPIToolset( spec_dict=spec_dict, auth_credential=auth_credential, diff --git a/src/google/adk/tools/application_integration_tool/clients/integration_client.py b/src/google/adk/tools/application_integration_tool/clients/integration_client.py index 088950e..8030ffa 100644 --- a/src/google/adk/tools/application_integration_tool/clients/integration_client.py +++ b/src/google/adk/tools/application_integration_tool/clients/integration_client.py @@ -13,7 +13,7 @@ # limitations under the License. import json -from typing import List, Optional +from typing import Optional from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient import google.auth from google.auth import default as default_service_credential @@ -35,7 +35,7 @@ class IntegrationClient: project: str, location: str, integration: Optional[str] = None, - triggers: List[str] = None, + trigger: Optional[str] = None, connection: Optional[str] = None, entity_operations: Optional[dict[str, list[str]]] = None, actions: Optional[list[str]] = None, @@ -47,7 +47,7 @@ class IntegrationClient: project: The Google Cloud project ID. location: The Google Cloud location (e.g., us-central1). integration: The integration name. - triggers: The list of trigger IDs for the integration. + trigger: The trigger ID for the integration. connection: The connection name. entity_operations: A dictionary mapping entity names to a list of operations (e.g., LIST, CREATE, UPDATE, DELETE, GET). @@ -59,7 +59,7 @@ class IntegrationClient: self.project = project self.location = location self.integration = integration - self.triggers = triggers + self.trigger = trigger self.connection = connection self.entity_operations = ( entity_operations if entity_operations is not None else {} @@ -88,7 +88,7 @@ class IntegrationClient: "apiTriggerResources": [ { "integrationResource": self.integration, - "triggerId": self.triggers, + "triggerId": [self.trigger], }, ], "fileFormat": "JSON", @@ -109,7 +109,7 @@ class IntegrationClient: raise ValueError( "Invalid request. Please check the provided values of" f" project({self.project}), location({self.location})," - f" integration({self.integration}) and trigger({self.triggers})." + f" integration({self.integration}) and trigger({self.trigger})." ) from e raise ValueError(f"Request error: {e}") from e except Exception as e: diff --git a/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py b/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py index fc4cc51..28dbb9d 100644 --- a/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py +++ b/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py @@ -52,24 +52,6 @@ def mock_openapi_toolset(): yield mock_toolset -@pytest.fixture -def mock_openapi_toolset_with_multiple_tools_and_no_tools(): - with mock.patch( - "google.adk.tools.application_integration_tool.application_integration_toolset.OpenAPIToolset" - ) as mock_toolset: - mock_toolset_instance = mock.MagicMock() - mock_rest_api_tool = mock.MagicMock(spec=rest_api_tool.RestApiTool) - mock_rest_api_tool.name = "Test Tool" - mock_rest_api_tool_2 = mock.MagicMock(spec=rest_api_tool.RestApiTool) - mock_rest_api_tool_2.name = "Test Tool 2" - mock_toolset_instance.get_tools.return_value = [ - mock_rest_api_tool, - mock_rest_api_tool_2, - ] - mock_toolset.return_value = mock_toolset_instance - yield mock_toolset - - def get_mocked_parsed_operation(operation_id, attributes): mock_openapi_spec_parser_instance = mock.MagicMock() mock_parsed_operation = mock.MagicMock(spec=ParsedOperation) @@ -162,17 +144,10 @@ def test_initialization_with_integration_and_trigger( integration_name = "test-integration" trigger_name = "test-trigger" toolset = ApplicationIntegrationToolset( - project, location, integration=integration_name, triggers=[trigger_name] + project, location, integration=integration_name, trigger=trigger_name ) mock_integration_client.assert_called_once_with( - project, - location, - integration_name, - [trigger_name], - None, - None, - None, - None, + project, location, integration_name, trigger_name, None, None, None, None ) mock_integration_client.return_value.get_openapi_spec_for_integration.assert_called_once() mock_connections_client.assert_not_called() @@ -181,58 +156,6 @@ def test_initialization_with_integration_and_trigger( assert toolset.get_tools()[0].name == "Test Tool" -def test_initialization_with_integration_and_list_of_triggers( - project, - location, - mock_integration_client, - mock_connections_client, - mock_openapi_toolset_with_multiple_tools_and_no_tools, -): - integration_name = "test-integration" - trigger_name = ["test-trigger1", "test-trigger2"] - toolset = ApplicationIntegrationToolset( - project, location, integration=integration_name, triggers=trigger_name - ) - mock_integration_client.assert_called_once_with( - project, - location, - integration_name, - trigger_name, - None, - None, - None, - None, - ) - mock_integration_client.return_value.get_openapi_spec_for_integration.assert_called_once() - mock_connections_client.assert_not_called() - mock_openapi_toolset_with_multiple_tools_and_no_tools.assert_called_once() - assert len(toolset.get_tools()) == 2 - assert toolset.get_tools()[0].name == "Test Tool" - assert toolset.get_tools()[1].name == "Test Tool 2" - - -def test_initialization_with_integration_and_empty_trigger_list( - project, - location, - mock_integration_client, - mock_connections_client, - mock_openapi_toolset_with_multiple_tools_and_no_tools, -): - integration_name = "test-integration" - toolset = ApplicationIntegrationToolset( - project, location, integration=integration_name - ) - mock_integration_client.assert_called_once_with( - project, location, integration_name, None, None, None, None, None - ) - mock_integration_client.return_value.get_openapi_spec_for_integration.assert_called_once() - mock_connections_client.assert_not_called() - mock_openapi_toolset_with_multiple_tools_and_no_tools.assert_called_once() - assert len(toolset.get_tools()) == 2 - assert toolset.get_tools()[0].name == "Test Tool" - assert toolset.get_tools()[1].name == "Test Tool 2" - - def test_initialization_with_connection_and_entity_operations( project, location, @@ -340,7 +263,16 @@ def test_initialization_without_required_params(project, location): " \\(entity_operations or actions\\)\\) should be provided." ), ): - ApplicationIntegrationToolset(project, location, triggers=["test"]) + ApplicationIntegrationToolset(project, location, integration="test") + + with pytest.raises( + ValueError, + match=( + "Either \\(integration and trigger\\) or \\(connection and" + " \\(entity_operations or actions\\)\\) should be provided." + ), + ): + ApplicationIntegrationToolset(project, location, trigger="test") with pytest.raises( ValueError, @@ -378,14 +310,14 @@ def test_initialization_with_service_account_credentials( project, location, integration=integration_name, - triggers=[trigger_name], + trigger=trigger_name, service_account_json=service_account_json, ) mock_integration_client.assert_called_once_with( project, location, integration_name, - [trigger_name], + trigger_name, None, None, None, @@ -408,17 +340,10 @@ def test_initialization_without_explicit_service_account_credentials( integration_name = "test-integration" trigger_name = "test-trigger" toolset = ApplicationIntegrationToolset( - project, location, integration=integration_name, triggers=[trigger_name] + project, location, integration=integration_name, trigger=trigger_name ) mock_integration_client.assert_called_once_with( - project, - location, - integration_name, - [trigger_name], - None, - None, - None, - None, + project, location, integration_name, trigger_name, None, None, None, None ) mock_openapi_toolset.assert_called_once() _, kwargs = mock_openapi_toolset.call_args @@ -432,7 +357,7 @@ def test_get_tools( integration_name = "test-integration" trigger_name = "test-trigger" toolset = ApplicationIntegrationToolset( - project, location, integration=integration_name, triggers=[trigger_name] + project, location, integration=integration_name, trigger=trigger_name ) tools = toolset.get_tools() assert len(tools) == 1