Added support for dynamic auth in integration connector tool

PiperOrigin-RevId: 759676602
This commit is contained in:
Google Team Member 2025-05-16 10:53:23 -07:00 committed by Copybara-Service
parent 2f006264ce
commit 6e0ea01fcb
5 changed files with 370 additions and 10 deletions

View File

@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List import logging
from typing import Optional from typing import List, Optional, Union
from typing import Union
from fastapi.openapi.models import HTTPBearer from fastapi.openapi.models import HTTPBearer
from typing_extensions import override from typing_extensions import override
@ -24,6 +23,7 @@ from ...auth.auth_credential import AuthCredential
from ...auth.auth_credential import AuthCredentialTypes from ...auth.auth_credential import AuthCredentialTypes
from ...auth.auth_credential import ServiceAccount from ...auth.auth_credential import ServiceAccount
from ...auth.auth_credential import ServiceAccountCredential from ...auth.auth_credential import ServiceAccountCredential
from ...auth.auth_schemes import AuthScheme
from ..base_toolset import BaseToolset from ..base_toolset import BaseToolset
from ..base_toolset import ToolPredicate from ..base_toolset import ToolPredicate
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
@ -35,6 +35,9 @@ from .clients.integration_client import IntegrationClient
from .integration_connector_tool import IntegrationConnectorTool from .integration_connector_tool import IntegrationConnectorTool
logger = logging.getLogger(__name__)
# TODO(cheliu): Apply a common toolset interface # TODO(cheliu): Apply a common toolset interface
class ApplicationIntegrationToolset(BaseToolset): class ApplicationIntegrationToolset(BaseToolset):
"""ApplicationIntegrationToolset generates tools from a given Application """ApplicationIntegrationToolset generates tools from a given Application
@ -93,6 +96,8 @@ class ApplicationIntegrationToolset(BaseToolset):
# tool/python function description. # tool/python function description.
tool_instructions: Optional[str] = "", tool_instructions: Optional[str] = "",
service_account_json: Optional[str] = None, service_account_json: Optional[str] = None,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
): ):
"""Args: """Args:
@ -132,6 +137,8 @@ class ApplicationIntegrationToolset(BaseToolset):
self._tool_name_prefix = tool_name_prefix self._tool_name_prefix = tool_name_prefix
self._tool_instructions = tool_instructions self._tool_instructions = tool_instructions
self._service_account_json = service_account_json self._service_account_json = service_account_json
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
self.tool_filter = tool_filter self.tool_filter = tool_filter
integration_client = IntegrationClient( integration_client = IntegrationClient(
@ -212,6 +219,27 @@ class ApplicationIntegrationToolset(BaseToolset):
rest_api_tool.configure_auth_scheme(auth_scheme) rest_api_tool.configure_auth_scheme(auth_scheme)
if auth_credential: if auth_credential:
rest_api_tool.configure_auth_credential(auth_credential) rest_api_tool.configure_auth_credential(auth_credential)
auth_override_enabled = connection_details.get(
"authOverrideEnabled", False
)
if (
self._auth_scheme
and self._auth_credential
and not auth_override_enabled
):
# Case: Auth provided, but override is OFF. Don't use provided auth.
logger.warning(
"Authentication schema and credentials are not used because"
" authOverrideEnabled is not enabled in the connection."
)
connector_auth_scheme = None
connector_auth_credential = None
else:
connector_auth_scheme = self._auth_scheme
connector_auth_credential = self._auth_credential
self._tools.append( self._tools.append(
IntegrationConnectorTool( IntegrationConnectorTool(
name=rest_api_tool.name, name=rest_api_tool.name,
@ -223,6 +251,8 @@ class ApplicationIntegrationToolset(BaseToolset):
action=action, action=action,
operation=operation, operation=operation,
rest_api_tool=rest_api_tool, rest_api_tool=rest_api_tool,
auth_scheme=connector_auth_scheme,
auth_credential=connector_auth_credential,
) )
) )

View File

@ -554,6 +554,9 @@ class ConnectionsClient:
"serviceName": {"$ref": "#/components/schemas/serviceName"}, "serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"}, "host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"}, "entity": {"$ref": "#/components/schemas/entity"},
"dynamicAuthConfig": {
"$ref": "#/components/schemas/dynamicAuthConfig"
},
}, },
} }
@ -580,6 +583,9 @@ class ConnectionsClient:
"serviceName": {"$ref": "#/components/schemas/serviceName"}, "serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"}, "host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"}, "entity": {"$ref": "#/components/schemas/entity"},
"dynamicAuthConfig": {
"$ref": "#/components/schemas/dynamicAuthConfig"
},
"filterClause": {"$ref": "#/components/schemas/filterClause"}, "filterClause": {"$ref": "#/components/schemas/filterClause"},
}, },
} }
@ -603,6 +609,9 @@ class ConnectionsClient:
"serviceName": {"$ref": "#/components/schemas/serviceName"}, "serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"}, "host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"}, "entity": {"$ref": "#/components/schemas/entity"},
"dynamicAuthConfig": {
"$ref": "#/components/schemas/dynamicAuthConfig"
},
}, },
} }
@ -625,6 +634,9 @@ class ConnectionsClient:
"serviceName": {"$ref": "#/components/schemas/serviceName"}, "serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"}, "host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"}, "entity": {"$ref": "#/components/schemas/entity"},
"dynamicAuthConfig": {
"$ref": "#/components/schemas/dynamicAuthConfig"
},
"filterClause": {"$ref": "#/components/schemas/filterClause"}, "filterClause": {"$ref": "#/components/schemas/filterClause"},
}, },
} }
@ -649,6 +661,9 @@ class ConnectionsClient:
"serviceName": {"$ref": "#/components/schemas/serviceName"}, "serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"}, "host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"}, "entity": {"$ref": "#/components/schemas/entity"},
"dynamicAuthConfig": {
"$ref": "#/components/schemas/dynamicAuthConfig"
},
}, },
} }
@ -673,6 +688,9 @@ class ConnectionsClient:
"connectorInputPayload": { "connectorInputPayload": {
"$ref": f"#/components/schemas/connectorInputPayload_{action}" "$ref": f"#/components/schemas/connectorInputPayload_{action}"
}, },
"dynamicAuthConfig": {
"$ref": "#/components/schemas/dynamicAuthConfig"
},
}, },
} }

View File

@ -12,20 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any from typing import Any, Dict, Optional, Union
from typing import Dict
from typing import Optional
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from google.genai.types import FunctionDeclaration from google.genai.types import FunctionDeclaration
from typing_extensions import override from typing_extensions import override
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
from .. import BaseTool from .. import BaseTool
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from ..openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler
from ..tool_context import ToolContext from ..tool_context import ToolContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,6 +57,7 @@ class IntegrationConnectorTool(BaseTool):
'entity', 'entity',
'operation', 'operation',
'action', 'action',
'dynamic_auth_config',
] ]
OPTIONAL_FIELDS = [ OPTIONAL_FIELDS = [
@ -75,6 +77,8 @@ class IntegrationConnectorTool(BaseTool):
operation: str, operation: str,
action: str, action: str,
rest_api_tool: RestApiTool, rest_api_tool: RestApiTool,
auth_scheme: Optional[Union[AuthScheme, str]] = None,
auth_credential: Optional[Union[AuthCredential, str]] = None,
): ):
"""Initializes the ApplicationIntegrationTool. """Initializes the ApplicationIntegrationTool.
@ -108,6 +112,8 @@ class IntegrationConnectorTool(BaseTool):
self._operation = operation self._operation = operation
self._action = action self._action = action
self._rest_api_tool = rest_api_tool self._rest_api_tool = rest_api_tool
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
@override @override
def _get_declaration(self) -> FunctionDeclaration: def _get_declaration(self) -> FunctionDeclaration:
@ -126,10 +132,45 @@ class IntegrationConnectorTool(BaseTool):
) )
return function_decl return function_decl
def _prepare_dynamic_euc(self, auth_credential: AuthCredential) -> str:
if (
auth_credential
and auth_credential.http
and auth_credential.http.credentials
and auth_credential.http.credentials.token
):
return auth_credential.http.credentials.token
return None
@override @override
async def run_async( async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext] self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
tool_auth_handler = ToolAuthHandler.from_tool_context(
tool_context, self._auth_scheme, self._auth_credential
)
auth_result = tool_auth_handler.prepare_auth_credentials()
if auth_result.state == 'pending':
return {
'pending': True,
'message': 'Needs your authorization to access your data.',
}
# Attach parameters from auth into main parameters list
if auth_result.auth_credential:
# Attach parameters from auth into main parameters list
auth_credential_token = self._prepare_dynamic_euc(
auth_result.auth_credential
)
if auth_credential_token:
args['dynamic_auth_config'] = {
'oauth2_auth_code_flow.access_token': auth_credential_token
}
else:
args['dynamic_auth_config'] = {'oauth2_auth_code_flow.access_token': {}}
args['connection_name'] = self._connection_name args['connection_name'] = self._connection_name
args['service_name'] = self._connection_service_name args['service_name'] = self._connection_service_name
args['host'] = self._connection_host args['host'] = self._connection_host

View File

@ -16,9 +16,12 @@ import json
from unittest import mock from unittest import mock
from fastapi.openapi.models import Operation from fastapi.openapi.models import Operation
from google.adk.agents.readonly_context import ReadonlyContext from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.auth import AuthCredentialTypes
from google.adk.auth import OAuth2Auth
from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredential
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.auth.auth_helpers import dict_to_auth_scheme
from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
import pytest import pytest
@ -162,6 +165,16 @@ def connection_details():
} }
@pytest.fixture
def connection_details_auth_override_enabled():
return {
"serviceName": "test-service",
"host": "test.host",
"name": "test-connection",
"authOverrideEnabled": True,
}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_initialization_with_integration_and_trigger( async def test_initialization_with_integration_and_trigger(
project, project,
@ -474,3 +487,139 @@ def test_initialization_with_connection_details(
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_instructions tool_name, tool_instructions
) )
def test_init_with_connection_and_custom_auth(
mock_integration_client,
mock_connections_client,
mock_openapi_action_spec_parser,
connection_details_auth_override_enabled,
):
connection_name = "test-connection"
actions_list = ["create", "delete"]
tool_name = "My Actions Tool"
tool_instructions = "Perform actions using this tool."
mock_connections_client.return_value.get_connection_details.return_value = (
connection_details_auth_override_enabled
)
oauth2_data_google_cloud = {
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "https://test-url/o/oauth2/auth",
"tokenUrl": "https://test-url/token",
"scopes": {
"https://test-url/auth/test-scope": "test scope",
"https://www.test-url.com/auth/test-scope2": "test scope 2",
},
}
},
}
oauth2_scheme = dict_to_auth_scheme(oauth2_data_google_cloud)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test-client-id",
client_secret="test-client-secret",
),
)
toolset = ApplicationIntegrationToolset(
project,
location,
connection=connection_name,
actions=actions_list,
tool_name=tool_name,
tool_instructions=tool_instructions,
auth_scheme=oauth2_scheme,
auth_credential=auth_credential,
)
mock_integration_client.assert_called_once_with(
project, location, None, None, connection_name, None, actions_list, None
)
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
)
mock_connections_client.return_value.get_connection_details.assert_called_once()
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_instructions
)
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
assert len(toolset.get_tools()) == 1
assert toolset.get_tools()[0].name == "list_issues_operation"
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
assert toolset.get_tools()[0].action == "CustomAction"
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
assert toolset.get_tools()[0].auth_scheme == oauth2_scheme
assert toolset.get_tools()[0].auth_credential == auth_credential
def test_init_with_connection_with_auth_override_disabled_and_custom_auth(
mock_integration_client,
mock_connections_client,
mock_openapi_action_spec_parser,
connection_details,
):
connection_name = "test-connection"
actions_list = ["create", "delete"]
tool_name = "My Actions Tool"
tool_instructions = "Perform actions using this tool."
mock_connections_client.return_value.get_connection_details.return_value = (
connection_details
)
oauth2_data_google_cloud = {
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "https://test-url/o/oauth2/auth",
"tokenUrl": "https://test-url/token",
"scopes": {
"https://test-url/auth/test-scope": "test scope",
"https://www.test-url.com/auth/test-scope2": "test scope 2",
},
}
},
}
oauth2_scheme = dict_to_auth_scheme(oauth2_data_google_cloud)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test-client-id",
client_secret="test-client-secret",
),
)
toolset = ApplicationIntegrationToolset(
project,
location,
connection=connection_name,
actions=actions_list,
tool_name=tool_name,
tool_instructions=tool_instructions,
auth_scheme=oauth2_scheme,
auth_credential=auth_credential,
)
mock_integration_client.assert_called_once_with(
project, location, None, None, connection_name, None, actions_list, None
)
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
)
mock_connections_client.return_value.get_connection_details.assert_called_once()
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_instructions
)
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
assert len(toolset.get_tools()) == 1
assert toolset.get_tools()[0].name == "list_issues_operation"
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
assert toolset.get_tools()[0].action == "CustomAction"
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
assert not toolset.get_tools()[0].auth_scheme
assert not toolset.get_tools()[0].auth_credential

View File

@ -14,11 +14,14 @@
from unittest import mock from unittest import mock
from google.adk.auth import AuthCredential
from google.adk.auth import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.genai.types import FunctionDeclaration from google.genai.types import FunctionDeclaration
from google.genai.types import Schema from google.genai.types import Schema
from google.genai.types import Tool
from google.genai.types import Type from google.genai.types import Type
import pytest import pytest
@ -67,6 +70,30 @@ def integration_tool(mock_rest_api_tool):
) )
@pytest.fixture
def integration_tool_with_auth(mock_rest_api_tool):
"""Fixture for an IntegrationConnectorTool instance."""
return IntegrationConnectorTool(
name="test_integration_tool",
description="Test integration tool description.",
connection_name="test-conn",
connection_host="test.example.com",
connection_service_name="test-service",
entity="TestEntity",
operation="LIST",
action="TestAction",
rest_api_tool=mock_rest_api_tool,
auth_scheme=None,
auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token="mocked_token"),
),
),
)
def test_get_declaration(integration_tool): def test_get_declaration(integration_tool):
"""Tests the generation of the function declaration.""" """Tests the generation of the function declaration."""
declaration = integration_tool._get_declaration() declaration = integration_tool._get_declaration()
@ -123,3 +150,98 @@ async def test_run_async(integration_tool, mock_rest_api_tool):
# Assert the result is what the mocked call returned # Assert the result is what the mocked call returned
assert result == {"status": "success", "data": "mock_data"} assert result == {"status": "success", "data": "mock_data"}
@pytest.mark.asyncio
async def test_run_with_auth_async_none_token(
integration_tool_with_auth, mock_rest_api_tool
):
"""Tests run_async when auth credential token is None."""
input_args = {"user_id": "user456", "filter": "some_filter"}
expected_call_args = {
"user_id": "user456",
"filter": "some_filter",
"dynamic_auth_config": {"oauth2_auth_code_flow.access_token": {}},
"connection_name": "test-conn",
"service_name": "test-service",
"host": "test.example.com",
"entity": "TestEntity",
"operation": "LIST",
"action": "TestAction",
}
with mock.patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
) as mock_from_tool_context:
mock_tool_auth_handler_instance = mock.MagicMock()
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
"done"
)
# Simulate an AuthCredential that would cause _prepare_dynamic_euc to return None
mock_auth_credential_without_token = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token=None), # Token is None
),
)
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = (
mock_auth_credential_without_token
)
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
result = await integration_tool_with_auth.run_async(
args=input_args, tool_context={}
)
mock_rest_api_tool.call.assert_called_once_with(
args=expected_call_args, tool_context={}
)
assert result == {"status": "success", "data": "mock_data"}
@pytest.mark.asyncio
async def test_run_with_auth_async(
integration_tool_with_auth, mock_rest_api_tool
):
"""Tests the async execution with auth delegates correctly to the RestApiTool."""
input_args = {"user_id": "user123", "page_size": 10}
expected_call_args = {
"user_id": "user123",
"page_size": 10,
"dynamic_auth_config": {
"oauth2_auth_code_flow.access_token": "mocked_token"
},
"connection_name": "test-conn",
"service_name": "test-service",
"host": "test.example.com",
"entity": "TestEntity",
"operation": "LIST",
"action": "TestAction",
}
with mock.patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
) as mock_from_tool_context:
mock_tool_auth_handler_instance = mock.MagicMock()
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
"done"
)
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
"done"
)
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token="mocked_token"),
),
)
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
result = await integration_tool_with_auth.run_async(
args=input_args, tool_context={}
)
mock_rest_api_tool.call.assert_called_once_with(
args=expected_call_args, tool_context={}
)
assert result == {"status": "success", "data": "mock_data"}