Added support for dynamic auth in integration connector tool

PiperOrigin-RevId: 759676602
This commit is contained in:
Google Team Member 2025-05-16 10:53:23 -07:00 committed by Copybara-Service
parent 2f006264ce
commit 6e0ea01fcb
5 changed files with 370 additions and 10 deletions

View File

@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Optional
from typing import Union
import logging
from typing import List, Optional, Union
from fastapi.openapi.models import HTTPBearer
from typing_extensions import override
@ -24,6 +23,7 @@ from ...auth.auth_credential import AuthCredential
from ...auth.auth_credential import AuthCredentialTypes
from ...auth.auth_credential import ServiceAccount
from ...auth.auth_credential import ServiceAccountCredential
from ...auth.auth_schemes import AuthScheme
from ..base_toolset import BaseToolset
from ..base_toolset import ToolPredicate
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
@ -35,6 +35,9 @@ from .clients.integration_client import IntegrationClient
from .integration_connector_tool import IntegrationConnectorTool
logger = logging.getLogger(__name__)
# TODO(cheliu): Apply a common toolset interface
class ApplicationIntegrationToolset(BaseToolset):
"""ApplicationIntegrationToolset generates tools from a given Application
@ -93,6 +96,8 @@ class ApplicationIntegrationToolset(BaseToolset):
# tool/python function description.
tool_instructions: Optional[str] = "",
service_account_json: Optional[str] = None,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
):
"""Args:
@ -132,6 +137,8 @@ class ApplicationIntegrationToolset(BaseToolset):
self._tool_name_prefix = tool_name_prefix
self._tool_instructions = tool_instructions
self._service_account_json = service_account_json
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
self.tool_filter = tool_filter
integration_client = IntegrationClient(
@ -212,6 +219,27 @@ class ApplicationIntegrationToolset(BaseToolset):
rest_api_tool.configure_auth_scheme(auth_scheme)
if auth_credential:
rest_api_tool.configure_auth_credential(auth_credential)
auth_override_enabled = connection_details.get(
"authOverrideEnabled", False
)
if (
self._auth_scheme
and self._auth_credential
and not auth_override_enabled
):
# Case: Auth provided, but override is OFF. Don't use provided auth.
logger.warning(
"Authentication schema and credentials are not used because"
" authOverrideEnabled is not enabled in the connection."
)
connector_auth_scheme = None
connector_auth_credential = None
else:
connector_auth_scheme = self._auth_scheme
connector_auth_credential = self._auth_credential
self._tools.append(
IntegrationConnectorTool(
name=rest_api_tool.name,
@ -223,6 +251,8 @@ class ApplicationIntegrationToolset(BaseToolset):
action=action,
operation=operation,
rest_api_tool=rest_api_tool,
auth_scheme=connector_auth_scheme,
auth_credential=connector_auth_credential,
)
)

View File

@ -554,6 +554,9 @@ class ConnectionsClient:
"serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"},
"dynamicAuthConfig": {
"$ref": "#/components/schemas/dynamicAuthConfig"
},
},
}
@ -580,6 +583,9 @@ class ConnectionsClient:
"serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"},
"dynamicAuthConfig": {
"$ref": "#/components/schemas/dynamicAuthConfig"
},
"filterClause": {"$ref": "#/components/schemas/filterClause"},
},
}
@ -603,6 +609,9 @@ class ConnectionsClient:
"serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"},
"dynamicAuthConfig": {
"$ref": "#/components/schemas/dynamicAuthConfig"
},
},
}
@ -625,6 +634,9 @@ class ConnectionsClient:
"serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"},
"dynamicAuthConfig": {
"$ref": "#/components/schemas/dynamicAuthConfig"
},
"filterClause": {"$ref": "#/components/schemas/filterClause"},
},
}
@ -649,6 +661,9 @@ class ConnectionsClient:
"serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"},
"dynamicAuthConfig": {
"$ref": "#/components/schemas/dynamicAuthConfig"
},
},
}
@ -673,6 +688,9 @@ class ConnectionsClient:
"connectorInputPayload": {
"$ref": f"#/components/schemas/connectorInputPayload_{action}"
},
"dynamicAuthConfig": {
"$ref": "#/components/schemas/dynamicAuthConfig"
},
},
}

View File

@ -12,20 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
from typing import Dict
from typing import Optional
from typing import Any, Dict, Optional, Union
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from google.genai.types import FunctionDeclaration
from typing_extensions import override
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
from .. import BaseTool
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from ..openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler
from ..tool_context import ToolContext
logger = logging.getLogger(__name__)
@ -56,6 +57,7 @@ class IntegrationConnectorTool(BaseTool):
'entity',
'operation',
'action',
'dynamic_auth_config',
]
OPTIONAL_FIELDS = [
@ -75,6 +77,8 @@ class IntegrationConnectorTool(BaseTool):
operation: str,
action: str,
rest_api_tool: RestApiTool,
auth_scheme: Optional[Union[AuthScheme, str]] = None,
auth_credential: Optional[Union[AuthCredential, str]] = None,
):
"""Initializes the ApplicationIntegrationTool.
@ -108,6 +112,8 @@ class IntegrationConnectorTool(BaseTool):
self._operation = operation
self._action = action
self._rest_api_tool = rest_api_tool
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
@override
def _get_declaration(self) -> FunctionDeclaration:
@ -126,10 +132,45 @@ class IntegrationConnectorTool(BaseTool):
)
return function_decl
def _prepare_dynamic_euc(self, auth_credential: AuthCredential) -> str:
if (
auth_credential
and auth_credential.http
and auth_credential.http.credentials
and auth_credential.http.credentials.token
):
return auth_credential.http.credentials.token
return None
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
tool_auth_handler = ToolAuthHandler.from_tool_context(
tool_context, self._auth_scheme, self._auth_credential
)
auth_result = tool_auth_handler.prepare_auth_credentials()
if auth_result.state == 'pending':
return {
'pending': True,
'message': 'Needs your authorization to access your data.',
}
# Attach parameters from auth into main parameters list
if auth_result.auth_credential:
# Attach parameters from auth into main parameters list
auth_credential_token = self._prepare_dynamic_euc(
auth_result.auth_credential
)
if auth_credential_token:
args['dynamic_auth_config'] = {
'oauth2_auth_code_flow.access_token': auth_credential_token
}
else:
args['dynamic_auth_config'] = {'oauth2_auth_code_flow.access_token': {}}
args['connection_name'] = self._connection_name
args['service_name'] = self._connection_service_name
args['host'] = self._connection_host

View File

@ -16,9 +16,12 @@ import json
from unittest import mock
from fastapi.openapi.models import Operation
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.auth import AuthCredentialTypes
from google.adk.auth import OAuth2Auth
from google.adk.auth.auth_credential import AuthCredential
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.auth.auth_helpers import dict_to_auth_scheme
from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
import pytest
@ -162,6 +165,16 @@ def connection_details():
}
@pytest.fixture
def connection_details_auth_override_enabled():
return {
"serviceName": "test-service",
"host": "test.host",
"name": "test-connection",
"authOverrideEnabled": True,
}
@pytest.mark.asyncio
async def test_initialization_with_integration_and_trigger(
project,
@ -474,3 +487,139 @@ def test_initialization_with_connection_details(
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_instructions
)
def test_init_with_connection_and_custom_auth(
mock_integration_client,
mock_connections_client,
mock_openapi_action_spec_parser,
connection_details_auth_override_enabled,
):
connection_name = "test-connection"
actions_list = ["create", "delete"]
tool_name = "My Actions Tool"
tool_instructions = "Perform actions using this tool."
mock_connections_client.return_value.get_connection_details.return_value = (
connection_details_auth_override_enabled
)
oauth2_data_google_cloud = {
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "https://test-url/o/oauth2/auth",
"tokenUrl": "https://test-url/token",
"scopes": {
"https://test-url/auth/test-scope": "test scope",
"https://www.test-url.com/auth/test-scope2": "test scope 2",
},
}
},
}
oauth2_scheme = dict_to_auth_scheme(oauth2_data_google_cloud)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test-client-id",
client_secret="test-client-secret",
),
)
toolset = ApplicationIntegrationToolset(
project,
location,
connection=connection_name,
actions=actions_list,
tool_name=tool_name,
tool_instructions=tool_instructions,
auth_scheme=oauth2_scheme,
auth_credential=auth_credential,
)
mock_integration_client.assert_called_once_with(
project, location, None, None, connection_name, None, actions_list, None
)
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
)
mock_connections_client.return_value.get_connection_details.assert_called_once()
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_instructions
)
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
assert len(toolset.get_tools()) == 1
assert toolset.get_tools()[0].name == "list_issues_operation"
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
assert toolset.get_tools()[0].action == "CustomAction"
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
assert toolset.get_tools()[0].auth_scheme == oauth2_scheme
assert toolset.get_tools()[0].auth_credential == auth_credential
def test_init_with_connection_with_auth_override_disabled_and_custom_auth(
mock_integration_client,
mock_connections_client,
mock_openapi_action_spec_parser,
connection_details,
):
connection_name = "test-connection"
actions_list = ["create", "delete"]
tool_name = "My Actions Tool"
tool_instructions = "Perform actions using this tool."
mock_connections_client.return_value.get_connection_details.return_value = (
connection_details
)
oauth2_data_google_cloud = {
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "https://test-url/o/oauth2/auth",
"tokenUrl": "https://test-url/token",
"scopes": {
"https://test-url/auth/test-scope": "test scope",
"https://www.test-url.com/auth/test-scope2": "test scope 2",
},
}
},
}
oauth2_scheme = dict_to_auth_scheme(oauth2_data_google_cloud)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test-client-id",
client_secret="test-client-secret",
),
)
toolset = ApplicationIntegrationToolset(
project,
location,
connection=connection_name,
actions=actions_list,
tool_name=tool_name,
tool_instructions=tool_instructions,
auth_scheme=oauth2_scheme,
auth_credential=auth_credential,
)
mock_integration_client.assert_called_once_with(
project, location, None, None, connection_name, None, actions_list, None
)
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
)
mock_connections_client.return_value.get_connection_details.assert_called_once()
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_instructions
)
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
assert len(toolset.get_tools()) == 1
assert toolset.get_tools()[0].name == "list_issues_operation"
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
assert toolset.get_tools()[0].action == "CustomAction"
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
assert not toolset.get_tools()[0].auth_scheme
assert not toolset.get_tools()[0].auth_credential

View File

@ -14,11 +14,14 @@
from unittest import mock
from google.adk.auth import AuthCredential
from google.adk.auth import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
from google.genai.types import Tool
from google.genai.types import Type
import pytest
@ -67,6 +70,30 @@ def integration_tool(mock_rest_api_tool):
)
@pytest.fixture
def integration_tool_with_auth(mock_rest_api_tool):
"""Fixture for an IntegrationConnectorTool instance."""
return IntegrationConnectorTool(
name="test_integration_tool",
description="Test integration tool description.",
connection_name="test-conn",
connection_host="test.example.com",
connection_service_name="test-service",
entity="TestEntity",
operation="LIST",
action="TestAction",
rest_api_tool=mock_rest_api_tool,
auth_scheme=None,
auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token="mocked_token"),
),
),
)
def test_get_declaration(integration_tool):
"""Tests the generation of the function declaration."""
declaration = integration_tool._get_declaration()
@ -123,3 +150,98 @@ async def test_run_async(integration_tool, mock_rest_api_tool):
# Assert the result is what the mocked call returned
assert result == {"status": "success", "data": "mock_data"}
@pytest.mark.asyncio
async def test_run_with_auth_async_none_token(
integration_tool_with_auth, mock_rest_api_tool
):
"""Tests run_async when auth credential token is None."""
input_args = {"user_id": "user456", "filter": "some_filter"}
expected_call_args = {
"user_id": "user456",
"filter": "some_filter",
"dynamic_auth_config": {"oauth2_auth_code_flow.access_token": {}},
"connection_name": "test-conn",
"service_name": "test-service",
"host": "test.example.com",
"entity": "TestEntity",
"operation": "LIST",
"action": "TestAction",
}
with mock.patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
) as mock_from_tool_context:
mock_tool_auth_handler_instance = mock.MagicMock()
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
"done"
)
# Simulate an AuthCredential that would cause _prepare_dynamic_euc to return None
mock_auth_credential_without_token = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token=None), # Token is None
),
)
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = (
mock_auth_credential_without_token
)
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
result = await integration_tool_with_auth.run_async(
args=input_args, tool_context={}
)
mock_rest_api_tool.call.assert_called_once_with(
args=expected_call_args, tool_context={}
)
assert result == {"status": "success", "data": "mock_data"}
@pytest.mark.asyncio
async def test_run_with_auth_async(
integration_tool_with_auth, mock_rest_api_tool
):
"""Tests the async execution with auth delegates correctly to the RestApiTool."""
input_args = {"user_id": "user123", "page_size": 10}
expected_call_args = {
"user_id": "user123",
"page_size": 10,
"dynamic_auth_config": {
"oauth2_auth_code_flow.access_token": "mocked_token"
},
"connection_name": "test-conn",
"service_name": "test-service",
"host": "test.example.com",
"entity": "TestEntity",
"operation": "LIST",
"action": "TestAction",
}
with mock.patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
) as mock_from_tool_context:
mock_tool_auth_handler_instance = mock.MagicMock()
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
"done"
)
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
"done"
)
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token="mocked_token"),
),
)
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
result = await integration_tool_with_auth.run_async(
args=input_args, tool_context={}
)
mock_rest_api_tool.call.assert_called_once_with(
args=expected_call_args, tool_context={}
)
assert result == {"status": "success", "data": "mock_data"}