mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 11:22:22 -06:00
ADK changes
PiperOrigin-RevId: 750763037
This commit is contained in:
committed by
Copybara-Service
parent
a49d339251
commit
ca993277de
@@ -14,10 +14,12 @@
|
||||
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
from fastapi.openapi.models import Operation
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser import rest_api_tool
|
||||
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -50,6 +52,59 @@ def mock_openapi_toolset():
|
||||
yield mock_toolset
|
||||
|
||||
|
||||
def get_mocked_parsed_operation(operation_id, attributes):
|
||||
mock_openapi_spec_parser_instance = mock.MagicMock()
|
||||
mock_parsed_operation = mock.MagicMock(spec=ParsedOperation)
|
||||
mock_parsed_operation.name = "list_issues"
|
||||
mock_parsed_operation.description = "list_issues_description"
|
||||
mock_parsed_operation.endpoint = OperationEndpoint(
|
||||
base_url="http://localhost:8080",
|
||||
path="/v1/issues",
|
||||
method="GET",
|
||||
)
|
||||
mock_parsed_operation.auth_scheme = None
|
||||
mock_parsed_operation.auth_credential = None
|
||||
mock_parsed_operation.additional_context = {}
|
||||
mock_parsed_operation.parameters = []
|
||||
mock_operation = mock.MagicMock(spec=Operation)
|
||||
mock_operation.operationId = operation_id
|
||||
mock_operation.description = "list_issues_description"
|
||||
mock_operation.parameters = []
|
||||
mock_operation.requestBody = None
|
||||
mock_operation.responses = {}
|
||||
mock_operation.callbacks = {}
|
||||
for key, value in attributes.items():
|
||||
setattr(mock_operation, key, value)
|
||||
mock_parsed_operation.operation = mock_operation
|
||||
mock_openapi_spec_parser_instance.parse.return_value = [mock_parsed_operation]
|
||||
return mock_openapi_spec_parser_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openapi_entity_spec_parser():
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser"
|
||||
) as mock_spec_parser:
|
||||
mock_openapi_spec_parser_instance = get_mocked_parsed_operation(
|
||||
"list_issues", {"x-entity": "Issues", "x-operation": "LIST_ENTITIES"}
|
||||
)
|
||||
mock_spec_parser.return_value = mock_openapi_spec_parser_instance
|
||||
yield mock_spec_parser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openapi_action_spec_parser():
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser"
|
||||
) as mock_spec_parser:
|
||||
mock_openapi_action_spec_parser_instance = get_mocked_parsed_operation(
|
||||
"list_issues_operation",
|
||||
{"x-action": "CustomAction", "x-operation": "EXECUTE_ACTION"},
|
||||
)
|
||||
mock_spec_parser.return_value = mock_openapi_action_spec_parser_instance
|
||||
yield mock_spec_parser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project():
|
||||
return "test-project"
|
||||
@@ -72,7 +127,11 @@ def connection_spec():
|
||||
|
||||
@pytest.fixture
|
||||
def connection_details():
|
||||
return {"serviceName": "test-service", "host": "test.host"}
|
||||
return {
|
||||
"serviceName": "test-service",
|
||||
"host": "test.host",
|
||||
"name": "test-connection",
|
||||
}
|
||||
|
||||
|
||||
def test_initialization_with_integration_and_trigger(
|
||||
@@ -102,7 +161,7 @@ def test_initialization_with_connection_and_entity_operations(
|
||||
location,
|
||||
mock_integration_client,
|
||||
mock_connections_client,
|
||||
mock_openapi_toolset,
|
||||
mock_openapi_entity_spec_parser,
|
||||
connection_details,
|
||||
):
|
||||
connection_name = "test-connection"
|
||||
@@ -133,19 +192,17 @@ def test_initialization_with_connection_and_entity_operations(
|
||||
mock_connections_client.assert_called_once_with(
|
||||
project, location, connection_name, None
|
||||
)
|
||||
mock_openapi_entity_spec_parser.return_value.parse.assert_called_once()
|
||||
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||
tool_name,
|
||||
tool_instructions
|
||||
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
|
||||
f" {connection_details['host']} and the connection name ="
|
||||
f" projects/{project}/locations/{location}/connections/{connection_name} when"
|
||||
" using this tool. DONOT ask the user for these values as you already"
|
||||
" have those.",
|
||||
tool_instructions,
|
||||
)
|
||||
mock_openapi_toolset.assert_called_once()
|
||||
assert len(toolset.get_tools()) == 1
|
||||
assert toolset.get_tools()[0].name == "Test Tool"
|
||||
assert toolset.get_tools()[0].name == "list_issues"
|
||||
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
|
||||
assert toolset.get_tools()[0].entity == "Issues"
|
||||
assert toolset.get_tools()[0].operation == "LIST_ENTITIES"
|
||||
|
||||
|
||||
def test_initialization_with_connection_and_actions(
|
||||
@@ -153,7 +210,7 @@ def test_initialization_with_connection_and_actions(
|
||||
location,
|
||||
mock_integration_client,
|
||||
mock_connections_client,
|
||||
mock_openapi_toolset,
|
||||
mock_openapi_action_spec_parser,
|
||||
connection_details,
|
||||
):
|
||||
connection_name = "test-connection"
|
||||
@@ -181,15 +238,13 @@ def test_initialization_with_connection_and_actions(
|
||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||
tool_name,
|
||||
tool_instructions
|
||||
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
|
||||
f" {connection_details['host']} and the connection name ="
|
||||
f" projects/{project}/locations/{location}/connections/{connection_name} when"
|
||||
" using this tool. DONOT ask the user for these values as you already"
|
||||
" have those.",
|
||||
)
|
||||
mock_openapi_toolset.assert_called_once()
|
||||
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
|
||||
assert len(toolset.get_tools()) == 1
|
||||
assert toolset.get_tools()[0].name == "Test Tool"
|
||||
assert toolset.get_tools()[0].name == "list_issues_operation"
|
||||
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
|
||||
assert toolset.get_tools()[0].action == "CustomAction"
|
||||
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
|
||||
|
||||
|
||||
def test_initialization_without_required_params(project, location):
|
||||
@@ -337,9 +392,4 @@ def test_initialization_with_connection_details(
|
||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||
tool_name,
|
||||
tool_instructions
|
||||
+ "ALWAYS use serviceName = custom-service, host = custom.host and the"
|
||||
" connection name ="
|
||||
" projects/test-project/locations/us-central1/connections/test-connection"
|
||||
" when using this tool. DONOT ask the user for these values as you"
|
||||
" already have those.",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from google.genai.types import Schema
|
||||
from google.genai.types import Tool
|
||||
from google.genai.types import Type
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rest_api_tool():
|
||||
"""Fixture for a mocked RestApiTool."""
|
||||
mock_tool = mock.MagicMock(spec=RestApiTool)
|
||||
mock_tool.name = "mock_rest_tool"
|
||||
mock_tool.description = "Mock REST tool description."
|
||||
# Mock the internal parser needed for _get_declaration
|
||||
mock_parser = mock.MagicMock()
|
||||
mock_parser.get_json_schema.return_value = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user_id": {"type": "string", "description": "User ID"},
|
||||
"connection_name": {"type": "string"},
|
||||
"host": {"type": "string"},
|
||||
"service_name": {"type": "string"},
|
||||
"entity": {"type": "string"},
|
||||
"operation": {"type": "string"},
|
||||
"action": {"type": "string"},
|
||||
"page_size": {"type": "integer"},
|
||||
"filter": {"type": "string"},
|
||||
},
|
||||
"required": ["user_id", "page_size", "filter", "connection_name"],
|
||||
}
|
||||
mock_tool._operation_parser = mock_parser
|
||||
mock_tool.call.return_value = {"status": "success", "data": "mock_data"}
|
||||
return mock_tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def integration_tool(mock_rest_api_tool):
|
||||
"""Fixture for an IntegrationConnectorTool instance."""
|
||||
return IntegrationConnectorTool(
|
||||
name="test_integration_tool",
|
||||
description="Test integration tool description.",
|
||||
connection_name="test-conn",
|
||||
connection_host="test.example.com",
|
||||
connection_service_name="test-service",
|
||||
entity="TestEntity",
|
||||
operation="LIST",
|
||||
action="TestAction",
|
||||
rest_api_tool=mock_rest_api_tool,
|
||||
)
|
||||
|
||||
|
||||
def test_get_declaration(integration_tool):
|
||||
"""Tests the generation of the function declaration."""
|
||||
declaration = integration_tool._get_declaration()
|
||||
|
||||
assert isinstance(declaration, FunctionDeclaration)
|
||||
assert declaration.name == "test_integration_tool"
|
||||
assert declaration.description == "Test integration tool description."
|
||||
|
||||
# Check parameters schema
|
||||
params = declaration.parameters
|
||||
assert isinstance(params, Schema)
|
||||
print(f"params: {params}")
|
||||
assert params.type == Type.OBJECT
|
||||
|
||||
# Check properties (excluded fields should not be present)
|
||||
assert "user_id" in params.properties
|
||||
assert "connection_name" not in params.properties
|
||||
assert "host" not in params.properties
|
||||
assert "service_name" not in params.properties
|
||||
assert "entity" not in params.properties
|
||||
assert "operation" not in params.properties
|
||||
assert "action" not in params.properties
|
||||
assert "page_size" in params.properties
|
||||
assert "filter" in params.properties
|
||||
|
||||
# Check required fields (optional and excluded fields should not be required)
|
||||
assert "user_id" in params.required
|
||||
assert "page_size" not in params.required
|
||||
assert "filter" not in params.required
|
||||
assert "connection_name" not in params.required
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async(integration_tool, mock_rest_api_tool):
|
||||
"""Tests the async execution delegates correctly to the RestApiTool."""
|
||||
input_args = {"user_id": "user123", "page_size": 10}
|
||||
expected_call_args = {
|
||||
"user_id": "user123",
|
||||
"page_size": 10,
|
||||
"connection_name": "test-conn",
|
||||
"host": "test.example.com",
|
||||
"service_name": "test-service",
|
||||
"entity": "TestEntity",
|
||||
"operation": "LIST",
|
||||
"action": "TestAction",
|
||||
}
|
||||
|
||||
result = await integration_tool.run_async(args=input_args, tool_context=None)
|
||||
|
||||
# Assert the underlying rest_api_tool.call was called correctly
|
||||
mock_rest_api_tool.call.assert_called_once_with(
|
||||
args=expected_call_args, tool_context=None
|
||||
)
|
||||
|
||||
# Assert the result is what the mocked call returned
|
||||
assert result == {"status": "success", "data": "mock_data"}
|
||||
Reference in New Issue
Block a user