diff --git a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py index 027dd7c..b6e0b39 100644 --- a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py +++ b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List -from typing import Optional -from typing import Union +import logging +from typing import List, Optional, Union from fastapi.openapi.models import HTTPBearer from typing_extensions import override @@ -24,6 +23,7 @@ from ...auth.auth_credential import AuthCredential from ...auth.auth_credential import AuthCredentialTypes from ...auth.auth_credential import ServiceAccount from ...auth.auth_credential import ServiceAccountCredential +from ...auth.auth_schemes import AuthScheme from ..base_toolset import BaseToolset from ..base_toolset import ToolPredicate from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential @@ -35,6 +35,9 @@ from .clients.integration_client import IntegrationClient from .integration_connector_tool import IntegrationConnectorTool +logger = logging.getLogger(__name__) + + # TODO(cheliu): Apply a common toolset interface class ApplicationIntegrationToolset(BaseToolset): """ApplicationIntegrationToolset generates tools from a given Application @@ -93,6 +96,8 @@ class ApplicationIntegrationToolset(BaseToolset): # tool/python function description. tool_instructions: Optional[str] = "", service_account_json: Optional[str] = None, + auth_scheme: Optional[AuthScheme] = None, + auth_credential: Optional[AuthCredential] = None, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, ): """Args: @@ -132,6 +137,8 @@ class ApplicationIntegrationToolset(BaseToolset): self._tool_name_prefix = tool_name_prefix self._tool_instructions = tool_instructions self._service_account_json = service_account_json + self._auth_scheme = auth_scheme + self._auth_credential = auth_credential self.tool_filter = tool_filter integration_client = IntegrationClient( @@ -212,6 +219,27 @@ class ApplicationIntegrationToolset(BaseToolset): rest_api_tool.configure_auth_scheme(auth_scheme) if auth_credential: rest_api_tool.configure_auth_credential(auth_credential) + + auth_override_enabled = connection_details.get( + "authOverrideEnabled", False + ) + + if ( + self._auth_scheme + and self._auth_credential + and not auth_override_enabled + ): + # Case: Auth provided, but override is OFF. Don't use provided auth. + logger.warning( + "Authentication schema and credentials are not used because" + " authOverrideEnabled is not enabled in the connection." + ) + connector_auth_scheme = None + connector_auth_credential = None + else: + connector_auth_scheme = self._auth_scheme + connector_auth_credential = self._auth_credential + self._tools.append( IntegrationConnectorTool( name=rest_api_tool.name, @@ -223,6 +251,8 @@ class ApplicationIntegrationToolset(BaseToolset): action=action, operation=operation, rest_api_tool=rest_api_tool, + auth_scheme=connector_auth_scheme, + auth_credential=connector_auth_credential, ) ) diff --git a/src/google/adk/tools/application_integration_tool/clients/connections_client.py b/src/google/adk/tools/application_integration_tool/clients/connections_client.py index 3fed5f2..b56b5cf 100644 --- a/src/google/adk/tools/application_integration_tool/clients/connections_client.py +++ b/src/google/adk/tools/application_integration_tool/clients/connections_client.py @@ -554,6 +554,9 @@ class ConnectionsClient: "serviceName": {"$ref": "#/components/schemas/serviceName"}, "host": {"$ref": "#/components/schemas/host"}, "entity": {"$ref": "#/components/schemas/entity"}, + "dynamicAuthConfig": { + "$ref": "#/components/schemas/dynamicAuthConfig" + }, }, } @@ -580,6 +583,9 @@ class ConnectionsClient: "serviceName": {"$ref": "#/components/schemas/serviceName"}, "host": {"$ref": "#/components/schemas/host"}, "entity": {"$ref": "#/components/schemas/entity"}, + "dynamicAuthConfig": { + "$ref": "#/components/schemas/dynamicAuthConfig" + }, "filterClause": {"$ref": "#/components/schemas/filterClause"}, }, } @@ -603,6 +609,9 @@ class ConnectionsClient: "serviceName": {"$ref": "#/components/schemas/serviceName"}, "host": {"$ref": "#/components/schemas/host"}, "entity": {"$ref": "#/components/schemas/entity"}, + "dynamicAuthConfig": { + "$ref": "#/components/schemas/dynamicAuthConfig" + }, }, } @@ -625,6 +634,9 @@ class ConnectionsClient: "serviceName": {"$ref": "#/components/schemas/serviceName"}, "host": {"$ref": "#/components/schemas/host"}, "entity": {"$ref": "#/components/schemas/entity"}, + "dynamicAuthConfig": { + "$ref": "#/components/schemas/dynamicAuthConfig" + }, "filterClause": {"$ref": "#/components/schemas/filterClause"}, }, } @@ -649,6 +661,9 @@ class ConnectionsClient: "serviceName": {"$ref": "#/components/schemas/serviceName"}, "host": {"$ref": "#/components/schemas/host"}, "entity": {"$ref": "#/components/schemas/entity"}, + "dynamicAuthConfig": { + "$ref": "#/components/schemas/dynamicAuthConfig" + }, }, } @@ -673,6 +688,9 @@ class ConnectionsClient: "connectorInputPayload": { "$ref": f"#/components/schemas/connectorInputPayload_{action}" }, + "dynamicAuthConfig": { + "$ref": "#/components/schemas/dynamicAuthConfig" + }, }, } diff --git a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py index 1a112a4..58ec637 100644 --- a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py +++ b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Any -from typing import Dict -from typing import Optional +from typing import Any, Dict, Optional, Union -from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool -from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema from google.genai.types import FunctionDeclaration from typing_extensions import override +from ...auth.auth_credential import AuthCredential +from ...auth.auth_schemes import AuthScheme from .. import BaseTool +from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool +from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema +from ..openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler from ..tool_context import ToolContext + logger = logging.getLogger(__name__) @@ -56,6 +57,7 @@ class IntegrationConnectorTool(BaseTool): 'entity', 'operation', 'action', + 'dynamic_auth_config', ] OPTIONAL_FIELDS = [ @@ -75,6 +77,8 @@ class IntegrationConnectorTool(BaseTool): operation: str, action: str, rest_api_tool: RestApiTool, + auth_scheme: Optional[Union[AuthScheme, str]] = None, + auth_credential: Optional[Union[AuthCredential, str]] = None, ): """Initializes the ApplicationIntegrationTool. @@ -108,6 +112,8 @@ class IntegrationConnectorTool(BaseTool): self._operation = operation self._action = action self._rest_api_tool = rest_api_tool + self._auth_scheme = auth_scheme + self._auth_credential = auth_credential @override def _get_declaration(self) -> FunctionDeclaration: @@ -126,10 +132,45 @@ class IntegrationConnectorTool(BaseTool): ) return function_decl + def _prepare_dynamic_euc(self, auth_credential: AuthCredential) -> str: + if ( + auth_credential + and auth_credential.http + and auth_credential.http.credentials + and auth_credential.http.credentials.token + ): + return auth_credential.http.credentials.token + return None + @override async def run_async( self, *, args: dict[str, Any], tool_context: Optional[ToolContext] ) -> Dict[str, Any]: + + tool_auth_handler = ToolAuthHandler.from_tool_context( + tool_context, self._auth_scheme, self._auth_credential + ) + auth_result = tool_auth_handler.prepare_auth_credentials() + + if auth_result.state == 'pending': + return { + 'pending': True, + 'message': 'Needs your authorization to access your data.', + } + + # Attach parameters from auth into main parameters list + if auth_result.auth_credential: + # Attach parameters from auth into main parameters list + auth_credential_token = self._prepare_dynamic_euc( + auth_result.auth_credential + ) + if auth_credential_token: + args['dynamic_auth_config'] = { + 'oauth2_auth_code_flow.access_token': auth_credential_token + } + else: + args['dynamic_auth_config'] = {'oauth2_auth_code_flow.access_token': {}} + args['connection_name'] = self._connection_name args['service_name'] = self._connection_service_name args['host'] = self._connection_host diff --git a/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py b/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py index 3dc2393..a95e29f 100644 --- a/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py +++ b/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py @@ -16,9 +16,12 @@ import json from unittest import mock from fastapi.openapi.models import Operation from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.auth import AuthCredentialTypes +from google.adk.auth import OAuth2Auth from google.adk.auth.auth_credential import AuthCredential from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool +from google.adk.tools.openapi_tool.auth.auth_helpers import dict_to_auth_scheme from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint import pytest @@ -162,6 +165,16 @@ def connection_details(): } +@pytest.fixture +def connection_details_auth_override_enabled(): + return { + "serviceName": "test-service", + "host": "test.host", + "name": "test-connection", + "authOverrideEnabled": True, + } + + @pytest.mark.asyncio async def test_initialization_with_integration_and_trigger( project, @@ -474,3 +487,139 @@ def test_initialization_with_connection_details( mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( tool_name, tool_instructions ) + + +def test_init_with_connection_and_custom_auth( + mock_integration_client, + mock_connections_client, + mock_openapi_action_spec_parser, + connection_details_auth_override_enabled, +): + connection_name = "test-connection" + actions_list = ["create", "delete"] + tool_name = "My Actions Tool" + tool_instructions = "Perform actions using this tool." + mock_connections_client.return_value.get_connection_details.return_value = ( + connection_details_auth_override_enabled + ) + + oauth2_data_google_cloud = { + "type": "oauth2", + "flows": { + "authorizationCode": { + "authorizationUrl": "https://test-url/o/oauth2/auth", + "tokenUrl": "https://test-url/token", + "scopes": { + "https://test-url/auth/test-scope": "test scope", + "https://www.test-url.com/auth/test-scope2": "test scope 2", + }, + } + }, + } + + oauth2_scheme = dict_to_auth_scheme(oauth2_data_google_cloud) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test-client-id", + client_secret="test-client-secret", + ), + ) + + toolset = ApplicationIntegrationToolset( + project, + location, + connection=connection_name, + actions=actions_list, + tool_name=tool_name, + tool_instructions=tool_instructions, + auth_scheme=oauth2_scheme, + auth_credential=auth_credential, + ) + mock_integration_client.assert_called_once_with( + project, location, None, None, connection_name, None, actions_list, None + ) + mock_connections_client.assert_called_once_with( + project, location, connection_name, None + ) + mock_connections_client.return_value.get_connection_details.assert_called_once() + mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( + tool_name, tool_instructions + ) + mock_openapi_action_spec_parser.return_value.parse.assert_called_once() + assert len(toolset.get_tools()) == 1 + assert toolset.get_tools()[0].name == "list_issues_operation" + assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool) + assert toolset.get_tools()[0].action == "CustomAction" + assert toolset.get_tools()[0].operation == "EXECUTE_ACTION" + assert toolset.get_tools()[0].auth_scheme == oauth2_scheme + assert toolset.get_tools()[0].auth_credential == auth_credential + + +def test_init_with_connection_with_auth_override_disabled_and_custom_auth( + mock_integration_client, + mock_connections_client, + mock_openapi_action_spec_parser, + connection_details, +): + connection_name = "test-connection" + actions_list = ["create", "delete"] + tool_name = "My Actions Tool" + tool_instructions = "Perform actions using this tool." + mock_connections_client.return_value.get_connection_details.return_value = ( + connection_details + ) + + oauth2_data_google_cloud = { + "type": "oauth2", + "flows": { + "authorizationCode": { + "authorizationUrl": "https://test-url/o/oauth2/auth", + "tokenUrl": "https://test-url/token", + "scopes": { + "https://test-url/auth/test-scope": "test scope", + "https://www.test-url.com/auth/test-scope2": "test scope 2", + }, + } + }, + } + + oauth2_scheme = dict_to_auth_scheme(oauth2_data_google_cloud) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test-client-id", + client_secret="test-client-secret", + ), + ) + + toolset = ApplicationIntegrationToolset( + project, + location, + connection=connection_name, + actions=actions_list, + tool_name=tool_name, + tool_instructions=tool_instructions, + auth_scheme=oauth2_scheme, + auth_credential=auth_credential, + ) + mock_integration_client.assert_called_once_with( + project, location, None, None, connection_name, None, actions_list, None + ) + mock_connections_client.assert_called_once_with( + project, location, connection_name, None + ) + mock_connections_client.return_value.get_connection_details.assert_called_once() + mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( + tool_name, tool_instructions + ) + mock_openapi_action_spec_parser.return_value.parse.assert_called_once() + assert len(toolset.get_tools()) == 1 + assert toolset.get_tools()[0].name == "list_issues_operation" + assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool) + assert toolset.get_tools()[0].action == "CustomAction" + assert toolset.get_tools()[0].operation == "EXECUTE_ACTION" + assert not toolset.get_tools()[0].auth_scheme + assert not toolset.get_tools()[0].auth_credential diff --git a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py index 93ed4bc..16aef90 100644 --- a/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py +++ b/tests/unittests/tools/application_integration_tool/test_integration_connector_tool.py @@ -14,11 +14,14 @@ from unittest import mock +from google.adk.auth import AuthCredential +from google.adk.auth import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool from google.genai.types import FunctionDeclaration from google.genai.types import Schema -from google.genai.types import Tool from google.genai.types import Type import pytest @@ -67,6 +70,30 @@ def integration_tool(mock_rest_api_tool): ) +@pytest.fixture +def integration_tool_with_auth(mock_rest_api_tool): + """Fixture for an IntegrationConnectorTool instance.""" + return IntegrationConnectorTool( + name="test_integration_tool", + description="Test integration tool description.", + connection_name="test-conn", + connection_host="test.example.com", + connection_service_name="test-service", + entity="TestEntity", + operation="LIST", + action="TestAction", + rest_api_tool=mock_rest_api_tool, + auth_scheme=None, + auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="mocked_token"), + ), + ), + ) + + def test_get_declaration(integration_tool): """Tests the generation of the function declaration.""" declaration = integration_tool._get_declaration() @@ -123,3 +150,98 @@ async def test_run_async(integration_tool, mock_rest_api_tool): # Assert the result is what the mocked call returned assert result == {"status": "success", "data": "mock_data"} + + +@pytest.mark.asyncio +async def test_run_with_auth_async_none_token( + integration_tool_with_auth, mock_rest_api_tool +): + """Tests run_async when auth credential token is None.""" + input_args = {"user_id": "user456", "filter": "some_filter"} + expected_call_args = { + "user_id": "user456", + "filter": "some_filter", + "dynamic_auth_config": {"oauth2_auth_code_flow.access_token": {}}, + "connection_name": "test-conn", + "service_name": "test-service", + "host": "test.example.com", + "entity": "TestEntity", + "operation": "LIST", + "action": "TestAction", + } + + with mock.patch( + "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" + ) as mock_from_tool_context: + mock_tool_auth_handler_instance = mock.MagicMock() + mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( + "done" + ) + # Simulate an AuthCredential that would cause _prepare_dynamic_euc to return None + mock_auth_credential_without_token = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token=None), # Token is None + ), + ) + mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = ( + mock_auth_credential_without_token + ) + mock_from_tool_context.return_value = mock_tool_auth_handler_instance + + result = await integration_tool_with_auth.run_async( + args=input_args, tool_context={} + ) + + mock_rest_api_tool.call.assert_called_once_with( + args=expected_call_args, tool_context={} + ) + assert result == {"status": "success", "data": "mock_data"} + + +@pytest.mark.asyncio +async def test_run_with_auth_async( + integration_tool_with_auth, mock_rest_api_tool +): + """Tests the async execution with auth delegates correctly to the RestApiTool.""" + input_args = {"user_id": "user123", "page_size": 10} + expected_call_args = { + "user_id": "user123", + "page_size": 10, + "dynamic_auth_config": { + "oauth2_auth_code_flow.access_token": "mocked_token" + }, + "connection_name": "test-conn", + "service_name": "test-service", + "host": "test.example.com", + "entity": "TestEntity", + "operation": "LIST", + "action": "TestAction", + } + + with mock.patch( + "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context" + ) as mock_from_tool_context: + mock_tool_auth_handler_instance = mock.MagicMock() + mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( + "done" + ) + mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = ( + "done" + ) + mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token="mocked_token"), + ), + ) + mock_from_tool_context.return_value = mock_tool_auth_handler_instance + result = await integration_tool_with_auth.run_async( + args=input_args, tool_context={} + ) + mock_rest_api_tool.call.assert_called_once_with( + args=expected_call_args, tool_context={} + ) + assert result == {"status": "success", "data": "mock_data"}