Added support for dynamic auth in integration connector tool

PiperOrigin-RevId: 759676602
This commit is contained in:
Google Team Member
2025-05-16 10:53:23 -07:00
committed by Copybara-Service
parent 2f006264ce
commit 6e0ea01fcb
5 changed files with 370 additions and 10 deletions

View File

@@ -16,9 +16,12 @@ import json
from unittest import mock
from fastapi.openapi.models import Operation
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.auth import AuthCredentialTypes
from google.adk.auth import OAuth2Auth
from google.adk.auth.auth_credential import AuthCredential
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.auth.auth_helpers import dict_to_auth_scheme
from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
import pytest
@@ -162,6 +165,16 @@ def connection_details():
}
@pytest.fixture
def connection_details_auth_override_enabled():
return {
"serviceName": "test-service",
"host": "test.host",
"name": "test-connection",
"authOverrideEnabled": True,
}
@pytest.mark.asyncio
async def test_initialization_with_integration_and_trigger(
project,
@@ -474,3 +487,139 @@ def test_initialization_with_connection_details(
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_instructions
)
def test_init_with_connection_and_custom_auth(
mock_integration_client,
mock_connections_client,
mock_openapi_action_spec_parser,
connection_details_auth_override_enabled,
):
connection_name = "test-connection"
actions_list = ["create", "delete"]
tool_name = "My Actions Tool"
tool_instructions = "Perform actions using this tool."
mock_connections_client.return_value.get_connection_details.return_value = (
connection_details_auth_override_enabled
)
oauth2_data_google_cloud = {
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "https://test-url/o/oauth2/auth",
"tokenUrl": "https://test-url/token",
"scopes": {
"https://test-url/auth/test-scope": "test scope",
"https://www.test-url.com/auth/test-scope2": "test scope 2",
},
}
},
}
oauth2_scheme = dict_to_auth_scheme(oauth2_data_google_cloud)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test-client-id",
client_secret="test-client-secret",
),
)
toolset = ApplicationIntegrationToolset(
project,
location,
connection=connection_name,
actions=actions_list,
tool_name=tool_name,
tool_instructions=tool_instructions,
auth_scheme=oauth2_scheme,
auth_credential=auth_credential,
)
mock_integration_client.assert_called_once_with(
project, location, None, None, connection_name, None, actions_list, None
)
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
)
mock_connections_client.return_value.get_connection_details.assert_called_once()
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_instructions
)
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
assert len(toolset.get_tools()) == 1
assert toolset.get_tools()[0].name == "list_issues_operation"
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
assert toolset.get_tools()[0].action == "CustomAction"
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
assert toolset.get_tools()[0].auth_scheme == oauth2_scheme
assert toolset.get_tools()[0].auth_credential == auth_credential
def test_init_with_connection_with_auth_override_disabled_and_custom_auth(
mock_integration_client,
mock_connections_client,
mock_openapi_action_spec_parser,
connection_details,
):
connection_name = "test-connection"
actions_list = ["create", "delete"]
tool_name = "My Actions Tool"
tool_instructions = "Perform actions using this tool."
mock_connections_client.return_value.get_connection_details.return_value = (
connection_details
)
oauth2_data_google_cloud = {
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "https://test-url/o/oauth2/auth",
"tokenUrl": "https://test-url/token",
"scopes": {
"https://test-url/auth/test-scope": "test scope",
"https://www.test-url.com/auth/test-scope2": "test scope 2",
},
}
},
}
oauth2_scheme = dict_to_auth_scheme(oauth2_data_google_cloud)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test-client-id",
client_secret="test-client-secret",
),
)
toolset = ApplicationIntegrationToolset(
project,
location,
connection=connection_name,
actions=actions_list,
tool_name=tool_name,
tool_instructions=tool_instructions,
auth_scheme=oauth2_scheme,
auth_credential=auth_credential,
)
mock_integration_client.assert_called_once_with(
project, location, None, None, connection_name, None, actions_list, None
)
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
)
mock_connections_client.return_value.get_connection_details.assert_called_once()
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_instructions
)
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
assert len(toolset.get_tools()) == 1
assert toolset.get_tools()[0].name == "list_issues_operation"
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
assert toolset.get_tools()[0].action == "CustomAction"
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
assert not toolset.get_tools()[0].auth_scheme
assert not toolset.get_tools()[0].auth_credential

View File

@@ -14,11 +14,14 @@
from unittest import mock
from google.adk.auth import AuthCredential
from google.adk.auth import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
from google.genai.types import Tool
from google.genai.types import Type
import pytest
@@ -67,6 +70,30 @@ def integration_tool(mock_rest_api_tool):
)
@pytest.fixture
def integration_tool_with_auth(mock_rest_api_tool):
"""Fixture for an IntegrationConnectorTool instance."""
return IntegrationConnectorTool(
name="test_integration_tool",
description="Test integration tool description.",
connection_name="test-conn",
connection_host="test.example.com",
connection_service_name="test-service",
entity="TestEntity",
operation="LIST",
action="TestAction",
rest_api_tool=mock_rest_api_tool,
auth_scheme=None,
auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token="mocked_token"),
),
),
)
def test_get_declaration(integration_tool):
"""Tests the generation of the function declaration."""
declaration = integration_tool._get_declaration()
@@ -123,3 +150,98 @@ async def test_run_async(integration_tool, mock_rest_api_tool):
# Assert the result is what the mocked call returned
assert result == {"status": "success", "data": "mock_data"}
@pytest.mark.asyncio
async def test_run_with_auth_async_none_token(
integration_tool_with_auth, mock_rest_api_tool
):
"""Tests run_async when auth credential token is None."""
input_args = {"user_id": "user456", "filter": "some_filter"}
expected_call_args = {
"user_id": "user456",
"filter": "some_filter",
"dynamic_auth_config": {"oauth2_auth_code_flow.access_token": {}},
"connection_name": "test-conn",
"service_name": "test-service",
"host": "test.example.com",
"entity": "TestEntity",
"operation": "LIST",
"action": "TestAction",
}
with mock.patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
) as mock_from_tool_context:
mock_tool_auth_handler_instance = mock.MagicMock()
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
"done"
)
# Simulate an AuthCredential that would cause _prepare_dynamic_euc to return None
mock_auth_credential_without_token = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token=None), # Token is None
),
)
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = (
mock_auth_credential_without_token
)
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
result = await integration_tool_with_auth.run_async(
args=input_args, tool_context={}
)
mock_rest_api_tool.call.assert_called_once_with(
args=expected_call_args, tool_context={}
)
assert result == {"status": "success", "data": "mock_data"}
@pytest.mark.asyncio
async def test_run_with_auth_async(
integration_tool_with_auth, mock_rest_api_tool
):
"""Tests the async execution with auth delegates correctly to the RestApiTool."""
input_args = {"user_id": "user123", "page_size": 10}
expected_call_args = {
"user_id": "user123",
"page_size": 10,
"dynamic_auth_config": {
"oauth2_auth_code_flow.access_token": "mocked_token"
},
"connection_name": "test-conn",
"service_name": "test-service",
"host": "test.example.com",
"entity": "TestEntity",
"operation": "LIST",
"action": "TestAction",
}
with mock.patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
) as mock_from_tool_context:
mock_tool_auth_handler_instance = mock.MagicMock()
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
"done"
)
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
"done"
)
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token="mocked_token"),
),
)
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
result = await integration_tool_with_auth.run_async(
args=input_args, tool_context={}
)
mock_rest_api_tool.call.assert_called_once_with(
args=expected_call_args, tool_context={}
)
assert result == {"status": "success", "data": "mock_data"}