diff --git a/src/google/adk/tools/application_integration_tool/__init__.py b/src/google/adk/tools/application_integration_tool/__init__.py index fd9eb51..23c9b56 100644 --- a/src/google/adk/tools/application_integration_tool/__init__.py +++ b/src/google/adk/tools/application_integration_tool/__init__.py @@ -13,7 +13,9 @@ # limitations under the License. from .application_integration_toolset import ApplicationIntegrationToolset +from .integration_connector_tool import IntegrationConnectorTool __all__ = [ 'ApplicationIntegrationToolset', + 'IntegrationConnectorTool', ] diff --git a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py index 5874bb5..d904de4 100644 --- a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py +++ b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict -from typing import List -from typing import Optional +from typing import Dict, List, Optional from fastapi.openapi.models import HTTPBearer -from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient -from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient -from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential -from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset -from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool from ...auth.auth_credential import AuthCredential from ...auth.auth_credential import AuthCredentialTypes from ...auth.auth_credential import ServiceAccount from ...auth.auth_credential import ServiceAccountCredential +from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential +from ..openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser +from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset +from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool +from .clients.connections_client import ConnectionsClient +from .clients.integration_client import IntegrationClient +from .integration_connector_tool import IntegrationConnectorTool # TODO(cheliu): Apply a common toolset interface @@ -168,6 +168,7 @@ class ApplicationIntegrationToolset: actions, service_account_json, ) + connection_details = {} if integration and trigger: spec = integration_client.get_openapi_spec_for_integration() elif connection and (entity_operations or actions): @@ -175,16 +176,6 @@ class ApplicationIntegrationToolset: project, location, connection, service_account_json ) connection_details = connections_client.get_connection_details() - tool_instructions += ( - "ALWAYS use serviceName = " - + connection_details["serviceName"] - + ", host = " - + connection_details["host"] - + " and the connection name = " - + f"projects/{project}/locations/{location}/connections/{connection} when" - " using this tool" - + ". DONOT ask the user for these values as you already have those." - ) spec = integration_client.get_openapi_spec_for_connection( tool_name, tool_instructions, @@ -194,9 +185,9 @@ class ApplicationIntegrationToolset: "Either (integration and trigger) or (connection and" " (entity_operations or actions)) should be provided." ) - self._parse_spec_to_tools(spec) + self._parse_spec_to_tools(spec, connection_details) - def _parse_spec_to_tools(self, spec_dict): + def _parse_spec_to_tools(self, spec_dict, connection_details): """Parses the spec dict to a list of RestApiTool.""" if self.service_account_json: sa_credential = ServiceAccountCredential.model_validate_json( @@ -218,12 +209,43 @@ class ApplicationIntegrationToolset: ), ) auth_scheme = HTTPBearer(bearerFormat="JWT") - tools = OpenAPIToolset( - spec_dict=spec_dict, - auth_credential=auth_credential, - auth_scheme=auth_scheme, - ).get_tools() - for tool in tools: + + if self.integration and self.trigger: + tools = OpenAPIToolset( + spec_dict=spec_dict, + auth_credential=auth_credential, + auth_scheme=auth_scheme, + ).get_tools() + for tool in tools: + self.generated_tools[tool.name] = tool + return + + operations = OpenApiSpecParser().parse(spec_dict) + + for open_api_operation in operations: + operation = getattr(open_api_operation.operation, "x-operation") + entity = None + action = None + if hasattr(open_api_operation.operation, "x-entity"): + entity = getattr(open_api_operation.operation, "x-entity") + elif hasattr(open_api_operation.operation, "x-action"): + action = getattr(open_api_operation.operation, "x-action") + rest_api_tool = RestApiTool.from_parsed_operation(open_api_operation) + if auth_scheme: + rest_api_tool.configure_auth_scheme(auth_scheme) + if auth_credential: + rest_api_tool.configure_auth_credential(auth_credential) + tool = IntegrationConnectorTool( + name=rest_api_tool.name, + description=rest_api_tool.description, + connection_name=connection_details["name"], + connection_host=connection_details["host"], + connection_service_name=connection_details["serviceName"], + entity=entity, + action=action, + operation=operation, + rest_api_tool=rest_api_tool, + ) self.generated_tools[tool.name] = tool def get_tools(self) -> List[RestApiTool]: diff --git a/src/google/adk/tools/application_integration_tool/clients/connections_client.py b/src/google/adk/tools/application_integration_tool/clients/connections_client.py index 06b4acf..2fbe2f6 100644 --- a/src/google/adk/tools/application_integration_tool/clients/connections_client.py +++ b/src/google/adk/tools/application_integration_tool/clients/connections_client.py @@ -68,12 +68,14 @@ class ConnectionsClient: response = self._execute_api_call(url) connection_data = response.json() + connection_name = connection_data.get("name", "") service_name = connection_data.get("serviceDirectory", "") host = connection_data.get("host", "") if host: service_name = connection_data.get("tlsServiceDirectory", "") auth_override_enabled = connection_data.get("authOverrideEnabled", False) return { + "name": connection_name, "serviceName": service_name, "host": host, "authOverrideEnabled": auth_override_enabled, @@ -291,13 +293,9 @@ class ConnectionsClient: tool_name: str = "", tool_instructions: str = "", ) -> Dict[str, Any]: - description = ( - f"Use this tool with" f' action = "{action}" and' - ) + f' operation = "{operation}" only. Dont ask these values from user.' + description = f"Use this tool to execute {action}" if operation == "EXECUTE_QUERY": - description = ( - (f"Use this tool with" f' action = "{action}" and') - + f' operation = "{operation}" only. Dont ask these values from user.' + description += ( " Use pageSize = 50 and timeout = 120 until user specifies a" " different value otherwise. If user provides a query in natural" " language, convert it to SQL query and then execute it using the" @@ -308,6 +306,8 @@ class ConnectionsClient: "summary": f"{action_display_name}", "description": f"{description} {tool_instructions}", "operationId": f"{tool_name}_{action_display_name}", + "x-action": f"{action}", + "x-operation": f"{operation}", "requestBody": { "content": { "application/json": { @@ -347,16 +347,12 @@ class ConnectionsClient: "post": { "summary": f"List {entity}", "description": ( - f"Returns all entities of type {entity}. Use this tool with" - + f' entity = "{entity}" and' - + ' operation = "LIST_ENTITIES" only. Dont ask these values' - " from" - + ' user. Always use ""' - + ' as filter clause and ""' - + " as page token and 50 as page size until user specifies a" - " different value otherwise. Use single quotes for strings in" - f" filter clause. {tool_instructions}" + f"""Returns the list of {entity} data. If the page token was available in the response, let users know there are more records available. Ask if the user wants to fetch the next page of results. When passing filter use the + following format: `field_name1='value1' AND field_name2='value2' + `. {tool_instructions}""" ), + "x-operation": "LIST_ENTITIES", + "x-entity": f"{entity}", "operationId": f"{tool_name}_list_{entity}", "requestBody": { "content": { @@ -401,14 +397,11 @@ class ConnectionsClient: "post": { "summary": f"Get {entity}", "description": ( - ( - f"Returns the details of the {entity}. Use this tool with" - f' entity = "{entity}" and' - ) - + ' operation = "GET_ENTITY" only. Dont ask these values from' - f" user. {tool_instructions}" + f"Returns the details of the {entity}. {tool_instructions}" ), "operationId": f"{tool_name}_get_{entity}", + "x-operation": "GET_ENTITY", + "x-entity": f"{entity}", "requestBody": { "content": { "application/json": { @@ -445,17 +438,10 @@ class ConnectionsClient: ) -> Dict[str, Any]: return { "post": { - "summary": f"Create {entity}", - "description": ( - ( - f"Creates a new entity of type {entity}. Use this tool with" - f' entity = "{entity}" and' - ) - + ' operation = "CREATE_ENTITY" only. Dont ask these values' - " from" - + " user. Follow the schema of the entity provided in the" - f" instructions to create {entity}. {tool_instructions}" - ), + "summary": f"Creates a new {entity}", + "description": f"Creates a new {entity}. {tool_instructions}", + "x-operation": "CREATE_ENTITY", + "x-entity": f"{entity}", "operationId": f"{tool_name}_create_{entity}", "requestBody": { "content": { @@ -491,18 +477,10 @@ class ConnectionsClient: ) -> Dict[str, Any]: return { "post": { - "summary": f"Update {entity}", - "description": ( - ( - f"Updates an entity of type {entity}. Use this tool with" - f' entity = "{entity}" and' - ) - + ' operation = "UPDATE_ENTITY" only. Dont ask these values' - " from" - + " user. Use entityId to uniquely identify the entity to" - " update. Follow the schema of the entity provided in the" - f" instructions to update {entity}. {tool_instructions}" - ), + "summary": f"Updates the {entity}", + "description": f"Updates the {entity}. {tool_instructions}", + "x-operation": "UPDATE_ENTITY", + "x-entity": f"{entity}", "operationId": f"{tool_name}_update_{entity}", "requestBody": { "content": { @@ -538,16 +516,10 @@ class ConnectionsClient: ) -> Dict[str, Any]: return { "post": { - "summary": f"Delete {entity}", - "description": ( - ( - f"Deletes an entity of type {entity}. Use this tool with" - f' entity = "{entity}" and' - ) - + ' operation = "DELETE_ENTITY" only. Dont ask these values' - " from" - f" user. {tool_instructions}" - ), + "summary": f"Delete the {entity}", + "description": f"Deletes the {entity}. {tool_instructions}", + "x-operation": "DELETE_ENTITY", + "x-entity": f"{entity}", "operationId": f"{tool_name}_delete_{entity}", "requestBody": { "content": { diff --git a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py new file mode 100644 index 0000000..2513da5 --- /dev/null +++ b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py @@ -0,0 +1,159 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from typing import Any +from typing import Dict +from typing import Optional + +from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool +from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema +from google.genai.types import FunctionDeclaration +from typing_extensions import override + +from .. import BaseTool +from ..tool_context import ToolContext + +logger = logging.getLogger(__name__) + + +class IntegrationConnectorTool(BaseTool): + """A tool that wraps a RestApiTool to interact with a specific Application Integration endpoint. + + This tool adds Application Integration specific context like connection + details, entity, operation, and action to the underlying REST API call + handled by RestApiTool. It prepares the arguments and then delegates the + actual API call execution to the contained RestApiTool instance. + + * Generates request params and body + * Attaches auth credentials to API call. + + Example: + ``` + # Each API operation in the spec will be turned into its own tool + # Name of the tool is the operationId of that operation, in snake case + operations = OperationGenerator().parse(openapi_spec_dict) + tool = [RestApiTool.from_parsed_operation(o) for o in operations] + ``` + """ + + EXCLUDE_FIELDS = [ + 'connection_name', + 'service_name', + 'host', + 'entity', + 'operation', + 'action', + ] + + OPTIONAL_FIELDS = [ + 'page_size', + 'page_token', + 'filter', + ] + + def __init__( + self, + name: str, + description: str, + connection_name: str, + connection_host: str, + connection_service_name: str, + entity: str, + operation: str, + action: str, + rest_api_tool: RestApiTool, + ): + """Initializes the ApplicationIntegrationTool. + + Args: + name: The name of the tool, typically derived from the API operation. + Should be unique and adhere to Gemini function naming conventions + (e.g., less than 64 characters). + description: A description of what the tool does, usually based on the + API operation's summary or description. + connection_name: The name of the Integration Connector connection. + connection_host: The hostname or IP address for the connection. + connection_service_name: The specific service name within the host. + entity: The Integration Connector entity being targeted. + operation: The specific operation being performed on the entity. + action: The action associated with the operation (e.g., 'execute'). + rest_api_tool: An initialized RestApiTool instance that handles the + underlying REST API communication based on an OpenAPI specification + operation. This tool will be called by ApplicationIntegrationTool with + added connection and context arguments. tool = + [RestApiTool.from_parsed_operation(o) for o in operations] + """ + # Gemini restrict the length of function name to be less than 64 characters + super().__init__( + name=name, + description=description, + ) + self.connection_name = connection_name + self.connection_host = connection_host + self.connection_service_name = connection_service_name + self.entity = entity + self.operation = operation + self.action = action + self.rest_api_tool = rest_api_tool + + @override + def _get_declaration(self) -> FunctionDeclaration: + """Returns the function declaration in the Gemini Schema format.""" + schema_dict = self.rest_api_tool._operation_parser.get_json_schema() + for field in self.EXCLUDE_FIELDS: + if field in schema_dict['properties']: + del schema_dict['properties'][field] + for field in self.OPTIONAL_FIELDS + self.EXCLUDE_FIELDS: + if field in schema_dict['required']: + schema_dict['required'].remove(field) + + parameters = to_gemini_schema(schema_dict) + function_decl = FunctionDeclaration( + name=self.name, description=self.description, parameters=parameters + ) + return function_decl + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: Optional[ToolContext] + ) -> Dict[str, Any]: + args['connection_name'] = self.connection_name + args['service_name'] = self.connection_service_name + args['host'] = self.connection_host + args['entity'] = self.entity + args['operation'] = self.operation + args['action'] = self.action + logger.info('Running tool: %s with args: %s', self.name, args) + return self.rest_api_tool.call(args=args, tool_context=tool_context) + + def __str__(self): + return ( + f'ApplicationIntegrationTool(name="{self.name}",' + f' description="{self.description}",' + f' connection_name="{self.connection_name}", entity="{self.entity}",' + f' operation="{self.operation}", action="{self.action}")' + ) + + def __repr__(self): + return ( + f'ApplicationIntegrationTool(name="{self.name}",' + f' description="{self.description}",' + f' connection_name="{self.connection_name}",' + f' connection_host="{self.connection_host}",' + f' connection_service_name="{self.connection_service_name}",' + f' entity="{self.entity}", operation="{self.operation}",' + f' action="{self.action}", rest_api_tool={repr(self.rest_api_tool)})' + ) diff --git a/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py b/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py index 3a7f6ea..b960dd6 100644 --- a/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py +++ b/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py @@ -14,10 +14,12 @@ import json from unittest import mock - +from fastapi.openapi.models import Operation from google.adk.auth.auth_credential import AuthCredential from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset -from google.adk.tools.openapi_tool.openapi_spec_parser import rest_api_tool +from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool +from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool +from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint import pytest @@ -50,6 +52,59 @@ def mock_openapi_toolset(): yield mock_toolset +def get_mocked_parsed_operation(operation_id, attributes): + mock_openapi_spec_parser_instance = mock.MagicMock() + mock_parsed_operation = mock.MagicMock(spec=ParsedOperation) + mock_parsed_operation.name = "list_issues" + mock_parsed_operation.description = "list_issues_description" + mock_parsed_operation.endpoint = OperationEndpoint( + base_url="http://localhost:8080", + path="/v1/issues", + method="GET", + ) + mock_parsed_operation.auth_scheme = None + mock_parsed_operation.auth_credential = None + mock_parsed_operation.additional_context = {} + mock_parsed_operation.parameters = [] + mock_operation = mock.MagicMock(spec=Operation) + mock_operation.operationId = operation_id + mock_operation.description = "list_issues_description" + mock_operation.parameters = [] + mock_operation.requestBody = None + mock_operation.responses = {} + mock_operation.callbacks = {} + for key, value in attributes.items(): + setattr(mock_operation, key, value) + mock_parsed_operation.operation = mock_operation + mock_openapi_spec_parser_instance.parse.return_value = [mock_parsed_operation] + return mock_openapi_spec_parser_instance + + +@pytest.fixture +def mock_openapi_entity_spec_parser(): + with mock.patch( + "google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser" + ) as mock_spec_parser: + mock_openapi_spec_parser_instance = get_mocked_parsed_operation( + "list_issues", {"x-entity": "Issues", "x-operation": "LIST_ENTITIES"} + ) + mock_spec_parser.return_value = mock_openapi_spec_parser_instance + yield mock_spec_parser + + +@pytest.fixture +def mock_openapi_action_spec_parser(): + with mock.patch( + "google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser" + ) as mock_spec_parser: + mock_openapi_action_spec_parser_instance = get_mocked_parsed_operation( + "list_issues_operation", + {"x-action": "CustomAction", "x-operation": "EXECUTE_ACTION"}, + ) + mock_spec_parser.return_value = mock_openapi_action_spec_parser_instance + yield mock_spec_parser + + @pytest.fixture def project(): return "test-project" @@ -72,7 +127,11 @@ def connection_spec(): @pytest.fixture def connection_details(): - return {"serviceName": "test-service", "host": "test.host"} + return { + "serviceName": "test-service", + "host": "test.host", + "name": "test-connection", + } def test_initialization_with_integration_and_trigger( @@ -102,7 +161,7 @@ def test_initialization_with_connection_and_entity_operations( location, mock_integration_client, mock_connections_client, - mock_openapi_toolset, + mock_openapi_entity_spec_parser, connection_details, ): connection_name = "test-connection" @@ -133,19 +192,17 @@ def test_initialization_with_connection_and_entity_operations( mock_connections_client.assert_called_once_with( project, location, connection_name, None ) + mock_openapi_entity_spec_parser.return_value.parse.assert_called_once() mock_connections_client.return_value.get_connection_details.assert_called_once() mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( tool_name, - tool_instructions - + f"ALWAYS use serviceName = {connection_details['serviceName']}, host =" - f" {connection_details['host']} and the connection name =" - f" projects/{project}/locations/{location}/connections/{connection_name} when" - " using this tool. DONOT ask the user for these values as you already" - " have those.", + tool_instructions, ) - mock_openapi_toolset.assert_called_once() assert len(toolset.get_tools()) == 1 - assert toolset.get_tools()[0].name == "Test Tool" + assert toolset.get_tools()[0].name == "list_issues" + assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool) + assert toolset.get_tools()[0].entity == "Issues" + assert toolset.get_tools()[0].operation == "LIST_ENTITIES" def test_initialization_with_connection_and_actions( @@ -153,7 +210,7 @@ def test_initialization_with_connection_and_actions( location, mock_integration_client, mock_connections_client, - mock_openapi_toolset, + mock_openapi_action_spec_parser, connection_details, ): connection_name = "test-connection" @@ -181,15 +238,13 @@ def test_initialization_with_connection_and_actions( mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( tool_name, tool_instructions - + f"ALWAYS use serviceName = {connection_details['serviceName']}, host =" - f" {connection_details['host']} and the connection name =" - f" projects/{project}/locations/{location}/connections/{connection_name} when" - " using this tool. DONOT ask the user for these values as you already" - " have those.", ) - mock_openapi_toolset.assert_called_once() + mock_openapi_action_spec_parser.return_value.parse.assert_called_once() assert len(toolset.get_tools()) == 1 - assert toolset.get_tools()[0].name == "Test Tool" + assert toolset.get_tools()[0].name == "list_issues_operation" + assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool) + assert toolset.get_tools()[0].action == "CustomAction" + assert toolset.get_tools()[0].operation == "EXECUTE_ACTION" def test_initialization_without_required_params(project, location): @@ -337,9 +392,4 @@ def test_initialization_with_connection_details( mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( tool_name, tool_instructions - + "ALWAYS use serviceName = custom-service, host = custom.host and the" - " connection name =" - " projects/test-project/locations/us-central1/connections/test-connection" - " when using this tool. DONOT ask the user for these values as you" - " already have those.", ) diff --git a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py new file mode 100644 index 0000000..93ed4bc --- /dev/null +++ b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py @@ -0,0 +1,125 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool +from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool +from google.genai.types import FunctionDeclaration +from google.genai.types import Schema +from google.genai.types import Tool +from google.genai.types import Type +import pytest + + +@pytest.fixture +def mock_rest_api_tool(): + """Fixture for a mocked RestApiTool.""" + mock_tool = mock.MagicMock(spec=RestApiTool) + mock_tool.name = "mock_rest_tool" + mock_tool.description = "Mock REST tool description." + # Mock the internal parser needed for _get_declaration + mock_parser = mock.MagicMock() + mock_parser.get_json_schema.return_value = { + "type": "object", + "properties": { + "user_id": {"type": "string", "description": "User ID"}, + "connection_name": {"type": "string"}, + "host": {"type": "string"}, + "service_name": {"type": "string"}, + "entity": {"type": "string"}, + "operation": {"type": "string"}, + "action": {"type": "string"}, + "page_size": {"type": "integer"}, + "filter": {"type": "string"}, + }, + "required": ["user_id", "page_size", "filter", "connection_name"], + } + mock_tool._operation_parser = mock_parser + mock_tool.call.return_value = {"status": "success", "data": "mock_data"} + return mock_tool + + +@pytest.fixture +def integration_tool(mock_rest_api_tool): + """Fixture for an IntegrationConnectorTool instance.""" + return IntegrationConnectorTool( + name="test_integration_tool", + description="Test integration tool description.", + connection_name="test-conn", + connection_host="test.example.com", + connection_service_name="test-service", + entity="TestEntity", + operation="LIST", + action="TestAction", + rest_api_tool=mock_rest_api_tool, + ) + + +def test_get_declaration(integration_tool): + """Tests the generation of the function declaration.""" + declaration = integration_tool._get_declaration() + + assert isinstance(declaration, FunctionDeclaration) + assert declaration.name == "test_integration_tool" + assert declaration.description == "Test integration tool description." + + # Check parameters schema + params = declaration.parameters + assert isinstance(params, Schema) + print(f"params: {params}") + assert params.type == Type.OBJECT + + # Check properties (excluded fields should not be present) + assert "user_id" in params.properties + assert "connection_name" not in params.properties + assert "host" not in params.properties + assert "service_name" not in params.properties + assert "entity" not in params.properties + assert "operation" not in params.properties + assert "action" not in params.properties + assert "page_size" in params.properties + assert "filter" in params.properties + + # Check required fields (optional and excluded fields should not be required) + assert "user_id" in params.required + assert "page_size" not in params.required + assert "filter" not in params.required + assert "connection_name" not in params.required + + +@pytest.mark.asyncio +async def test_run_async(integration_tool, mock_rest_api_tool): + """Tests the async execution delegates correctly to the RestApiTool.""" + input_args = {"user_id": "user123", "page_size": 10} + expected_call_args = { + "user_id": "user123", + "page_size": 10, + "connection_name": "test-conn", + "host": "test.example.com", + "service_name": "test-service", + "entity": "TestEntity", + "operation": "LIST", + "action": "TestAction", + } + + result = await integration_tool.run_async(args=input_args, tool_context=None) + + # Assert the underlying rest_api_tool.call was called correctly + mock_rest_api_tool.call.assert_called_once_with( + args=expected_call_args, tool_context=None + ) + + # Assert the result is what the mocked call returned + assert result == {"status": "success", "data": "mock_data"}