ADK changes

PiperOrigin-RevId: 750763037
This commit is contained in:
Google Team Member
2025-04-23 16:28:22 -07:00
committed by Copybara-Service
parent a49d339251
commit ca993277de
6 changed files with 435 additions and 105 deletions

View File

@@ -14,10 +14,12 @@
import json
from unittest import mock
from fastapi.openapi.models import Operation
from google.adk.auth.auth_credential import AuthCredential
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
from google.adk.tools.openapi_tool.openapi_spec_parser import rest_api_tool
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
import pytest
@@ -50,6 +52,59 @@ def mock_openapi_toolset():
yield mock_toolset
def get_mocked_parsed_operation(operation_id, attributes):
mock_openapi_spec_parser_instance = mock.MagicMock()
mock_parsed_operation = mock.MagicMock(spec=ParsedOperation)
mock_parsed_operation.name = "list_issues"
mock_parsed_operation.description = "list_issues_description"
mock_parsed_operation.endpoint = OperationEndpoint(
base_url="http://localhost:8080",
path="/v1/issues",
method="GET",
)
mock_parsed_operation.auth_scheme = None
mock_parsed_operation.auth_credential = None
mock_parsed_operation.additional_context = {}
mock_parsed_operation.parameters = []
mock_operation = mock.MagicMock(spec=Operation)
mock_operation.operationId = operation_id
mock_operation.description = "list_issues_description"
mock_operation.parameters = []
mock_operation.requestBody = None
mock_operation.responses = {}
mock_operation.callbacks = {}
for key, value in attributes.items():
setattr(mock_operation, key, value)
mock_parsed_operation.operation = mock_operation
mock_openapi_spec_parser_instance.parse.return_value = [mock_parsed_operation]
return mock_openapi_spec_parser_instance
@pytest.fixture
def mock_openapi_entity_spec_parser():
with mock.patch(
"google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser"
) as mock_spec_parser:
mock_openapi_spec_parser_instance = get_mocked_parsed_operation(
"list_issues", {"x-entity": "Issues", "x-operation": "LIST_ENTITIES"}
)
mock_spec_parser.return_value = mock_openapi_spec_parser_instance
yield mock_spec_parser
@pytest.fixture
def mock_openapi_action_spec_parser():
with mock.patch(
"google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser"
) as mock_spec_parser:
mock_openapi_action_spec_parser_instance = get_mocked_parsed_operation(
"list_issues_operation",
{"x-action": "CustomAction", "x-operation": "EXECUTE_ACTION"},
)
mock_spec_parser.return_value = mock_openapi_action_spec_parser_instance
yield mock_spec_parser
@pytest.fixture
def project():
return "test-project"
@@ -72,7 +127,11 @@ def connection_spec():
@pytest.fixture
def connection_details():
return {"serviceName": "test-service", "host": "test.host"}
return {
"serviceName": "test-service",
"host": "test.host",
"name": "test-connection",
}
def test_initialization_with_integration_and_trigger(
@@ -102,7 +161,7 @@ def test_initialization_with_connection_and_entity_operations(
location,
mock_integration_client,
mock_connections_client,
mock_openapi_toolset,
mock_openapi_entity_spec_parser,
connection_details,
):
connection_name = "test-connection"
@@ -133,19 +192,17 @@ def test_initialization_with_connection_and_entity_operations(
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
)
mock_openapi_entity_spec_parser.return_value.parse.assert_called_once()
mock_connections_client.return_value.get_connection_details.assert_called_once()
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name,
tool_instructions
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
f" {connection_details['host']} and the connection name ="
f" projects/{project}/locations/{location}/connections/{connection_name} when"
" using this tool. DONOT ask the user for these values as you already"
" have those.",
tool_instructions,
)
mock_openapi_toolset.assert_called_once()
assert len(toolset.get_tools()) == 1
assert toolset.get_tools()[0].name == "Test Tool"
assert toolset.get_tools()[0].name == "list_issues"
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
assert toolset.get_tools()[0].entity == "Issues"
assert toolset.get_tools()[0].operation == "LIST_ENTITIES"
def test_initialization_with_connection_and_actions(
@@ -153,7 +210,7 @@ def test_initialization_with_connection_and_actions(
location,
mock_integration_client,
mock_connections_client,
mock_openapi_toolset,
mock_openapi_action_spec_parser,
connection_details,
):
connection_name = "test-connection"
@@ -181,15 +238,13 @@ def test_initialization_with_connection_and_actions(
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name,
tool_instructions
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
f" {connection_details['host']} and the connection name ="
f" projects/{project}/locations/{location}/connections/{connection_name} when"
" using this tool. DONOT ask the user for these values as you already"
" have those.",
)
mock_openapi_toolset.assert_called_once()
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
assert len(toolset.get_tools()) == 1
assert toolset.get_tools()[0].name == "Test Tool"
assert toolset.get_tools()[0].name == "list_issues_operation"
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
assert toolset.get_tools()[0].action == "CustomAction"
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
def test_initialization_without_required_params(project, location):
@@ -337,9 +392,4 @@ def test_initialization_with_connection_details(
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name,
tool_instructions
+ "ALWAYS use serviceName = custom-service, host = custom.host and the"
" connection name ="
" projects/test-project/locations/us-central1/connections/test-connection"
" when using this tool. DONOT ask the user for these values as you"
" already have those.",
)

View File

@@ -0,0 +1,125 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
from google.genai.types import Tool
from google.genai.types import Type
import pytest
@pytest.fixture
def mock_rest_api_tool():
"""Fixture for a mocked RestApiTool."""
mock_tool = mock.MagicMock(spec=RestApiTool)
mock_tool.name = "mock_rest_tool"
mock_tool.description = "Mock REST tool description."
# Mock the internal parser needed for _get_declaration
mock_parser = mock.MagicMock()
mock_parser.get_json_schema.return_value = {
"type": "object",
"properties": {
"user_id": {"type": "string", "description": "User ID"},
"connection_name": {"type": "string"},
"host": {"type": "string"},
"service_name": {"type": "string"},
"entity": {"type": "string"},
"operation": {"type": "string"},
"action": {"type": "string"},
"page_size": {"type": "integer"},
"filter": {"type": "string"},
},
"required": ["user_id", "page_size", "filter", "connection_name"],
}
mock_tool._operation_parser = mock_parser
mock_tool.call.return_value = {"status": "success", "data": "mock_data"}
return mock_tool
@pytest.fixture
def integration_tool(mock_rest_api_tool):
"""Fixture for an IntegrationConnectorTool instance."""
return IntegrationConnectorTool(
name="test_integration_tool",
description="Test integration tool description.",
connection_name="test-conn",
connection_host="test.example.com",
connection_service_name="test-service",
entity="TestEntity",
operation="LIST",
action="TestAction",
rest_api_tool=mock_rest_api_tool,
)
def test_get_declaration(integration_tool):
"""Tests the generation of the function declaration."""
declaration = integration_tool._get_declaration()
assert isinstance(declaration, FunctionDeclaration)
assert declaration.name == "test_integration_tool"
assert declaration.description == "Test integration tool description."
# Check parameters schema
params = declaration.parameters
assert isinstance(params, Schema)
print(f"params: {params}")
assert params.type == Type.OBJECT
# Check properties (excluded fields should not be present)
assert "user_id" in params.properties
assert "connection_name" not in params.properties
assert "host" not in params.properties
assert "service_name" not in params.properties
assert "entity" not in params.properties
assert "operation" not in params.properties
assert "action" not in params.properties
assert "page_size" in params.properties
assert "filter" in params.properties
# Check required fields (optional and excluded fields should not be required)
assert "user_id" in params.required
assert "page_size" not in params.required
assert "filter" not in params.required
assert "connection_name" not in params.required
@pytest.mark.asyncio
async def test_run_async(integration_tool, mock_rest_api_tool):
"""Tests the async execution delegates correctly to the RestApiTool."""
input_args = {"user_id": "user123", "page_size": 10}
expected_call_args = {
"user_id": "user123",
"page_size": 10,
"connection_name": "test-conn",
"host": "test.example.com",
"service_name": "test-service",
"entity": "TestEntity",
"operation": "LIST",
"action": "TestAction",
}
result = await integration_tool.run_async(args=input_args, tool_context=None)
# Assert the underlying rest_api_tool.call was called correctly
mock_rest_api_tool.call.assert_called_once_with(
args=expected_call_args, tool_context=None
)
# Assert the result is what the mocked call returned
assert result == {"status": "success", "data": "mock_data"}