mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 07:04:51 -06:00
Added support for dynamic auth in integration connector tool
PiperOrigin-RevId: 759676602
This commit is contained in:
parent
2f006264ce
commit
6e0ea01fcb
@ -12,9 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from fastapi.openapi.models import HTTPBearer
|
||||
from typing_extensions import override
|
||||
@ -24,6 +23,7 @@ from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_credential import AuthCredentialTypes
|
||||
from ...auth.auth_credential import ServiceAccount
|
||||
from ...auth.auth_credential import ServiceAccountCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
from ..base_toolset import BaseToolset
|
||||
from ..base_toolset import ToolPredicate
|
||||
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
||||
@ -35,6 +35,9 @@ from .clients.integration_client import IntegrationClient
|
||||
from .integration_connector_tool import IntegrationConnectorTool
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO(cheliu): Apply a common toolset interface
|
||||
class ApplicationIntegrationToolset(BaseToolset):
|
||||
"""ApplicationIntegrationToolset generates tools from a given Application
|
||||
@ -93,6 +96,8 @@ class ApplicationIntegrationToolset(BaseToolset):
|
||||
# tool/python function description.
|
||||
tool_instructions: Optional[str] = "",
|
||||
service_account_json: Optional[str] = None,
|
||||
auth_scheme: Optional[AuthScheme] = None,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||
):
|
||||
"""Args:
|
||||
@ -132,6 +137,8 @@ class ApplicationIntegrationToolset(BaseToolset):
|
||||
self._tool_name_prefix = tool_name_prefix
|
||||
self._tool_instructions = tool_instructions
|
||||
self._service_account_json = service_account_json
|
||||
self._auth_scheme = auth_scheme
|
||||
self._auth_credential = auth_credential
|
||||
self.tool_filter = tool_filter
|
||||
|
||||
integration_client = IntegrationClient(
|
||||
@ -212,6 +219,27 @@ class ApplicationIntegrationToolset(BaseToolset):
|
||||
rest_api_tool.configure_auth_scheme(auth_scheme)
|
||||
if auth_credential:
|
||||
rest_api_tool.configure_auth_credential(auth_credential)
|
||||
|
||||
auth_override_enabled = connection_details.get(
|
||||
"authOverrideEnabled", False
|
||||
)
|
||||
|
||||
if (
|
||||
self._auth_scheme
|
||||
and self._auth_credential
|
||||
and not auth_override_enabled
|
||||
):
|
||||
# Case: Auth provided, but override is OFF. Don't use provided auth.
|
||||
logger.warning(
|
||||
"Authentication schema and credentials are not used because"
|
||||
" authOverrideEnabled is not enabled in the connection."
|
||||
)
|
||||
connector_auth_scheme = None
|
||||
connector_auth_credential = None
|
||||
else:
|
||||
connector_auth_scheme = self._auth_scheme
|
||||
connector_auth_credential = self._auth_credential
|
||||
|
||||
self._tools.append(
|
||||
IntegrationConnectorTool(
|
||||
name=rest_api_tool.name,
|
||||
@ -223,6 +251,8 @@ class ApplicationIntegrationToolset(BaseToolset):
|
||||
action=action,
|
||||
operation=operation,
|
||||
rest_api_tool=rest_api_tool,
|
||||
auth_scheme=connector_auth_scheme,
|
||||
auth_credential=connector_auth_credential,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -554,6 +554,9 @@ class ConnectionsClient:
|
||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||
"host": {"$ref": "#/components/schemas/host"},
|
||||
"entity": {"$ref": "#/components/schemas/entity"},
|
||||
"dynamicAuthConfig": {
|
||||
"$ref": "#/components/schemas/dynamicAuthConfig"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -580,6 +583,9 @@ class ConnectionsClient:
|
||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||
"host": {"$ref": "#/components/schemas/host"},
|
||||
"entity": {"$ref": "#/components/schemas/entity"},
|
||||
"dynamicAuthConfig": {
|
||||
"$ref": "#/components/schemas/dynamicAuthConfig"
|
||||
},
|
||||
"filterClause": {"$ref": "#/components/schemas/filterClause"},
|
||||
},
|
||||
}
|
||||
@ -603,6 +609,9 @@ class ConnectionsClient:
|
||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||
"host": {"$ref": "#/components/schemas/host"},
|
||||
"entity": {"$ref": "#/components/schemas/entity"},
|
||||
"dynamicAuthConfig": {
|
||||
"$ref": "#/components/schemas/dynamicAuthConfig"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -625,6 +634,9 @@ class ConnectionsClient:
|
||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||
"host": {"$ref": "#/components/schemas/host"},
|
||||
"entity": {"$ref": "#/components/schemas/entity"},
|
||||
"dynamicAuthConfig": {
|
||||
"$ref": "#/components/schemas/dynamicAuthConfig"
|
||||
},
|
||||
"filterClause": {"$ref": "#/components/schemas/filterClause"},
|
||||
},
|
||||
}
|
||||
@ -649,6 +661,9 @@ class ConnectionsClient:
|
||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||
"host": {"$ref": "#/components/schemas/host"},
|
||||
"entity": {"$ref": "#/components/schemas/entity"},
|
||||
"dynamicAuthConfig": {
|
||||
"$ref": "#/components/schemas/dynamicAuthConfig"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -673,6 +688,9 @@ class ConnectionsClient:
|
||||
"connectorInputPayload": {
|
||||
"$ref": f"#/components/schemas/connectorInputPayload_{action}"
|
||||
},
|
||||
"dynamicAuthConfig": {
|
||||
"$ref": "#/components/schemas/dynamicAuthConfig"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -12,20 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from typing_extensions import override
|
||||
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
from .. import BaseTool
|
||||
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
||||
from ..openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler
|
||||
from ..tool_context import ToolContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -56,6 +57,7 @@ class IntegrationConnectorTool(BaseTool):
|
||||
'entity',
|
||||
'operation',
|
||||
'action',
|
||||
'dynamic_auth_config',
|
||||
]
|
||||
|
||||
OPTIONAL_FIELDS = [
|
||||
@ -75,6 +77,8 @@ class IntegrationConnectorTool(BaseTool):
|
||||
operation: str,
|
||||
action: str,
|
||||
rest_api_tool: RestApiTool,
|
||||
auth_scheme: Optional[Union[AuthScheme, str]] = None,
|
||||
auth_credential: Optional[Union[AuthCredential, str]] = None,
|
||||
):
|
||||
"""Initializes the ApplicationIntegrationTool.
|
||||
|
||||
@ -108,6 +112,8 @@ class IntegrationConnectorTool(BaseTool):
|
||||
self._operation = operation
|
||||
self._action = action
|
||||
self._rest_api_tool = rest_api_tool
|
||||
self._auth_scheme = auth_scheme
|
||||
self._auth_credential = auth_credential
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> FunctionDeclaration:
|
||||
@ -126,10 +132,45 @@ class IntegrationConnectorTool(BaseTool):
|
||||
)
|
||||
return function_decl
|
||||
|
||||
def _prepare_dynamic_euc(self, auth_credential: AuthCredential) -> str:
|
||||
if (
|
||||
auth_credential
|
||||
and auth_credential.http
|
||||
and auth_credential.http.credentials
|
||||
and auth_credential.http.credentials.token
|
||||
):
|
||||
return auth_credential.http.credentials.token
|
||||
return None
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
tool_auth_handler = ToolAuthHandler.from_tool_context(
|
||||
tool_context, self._auth_scheme, self._auth_credential
|
||||
)
|
||||
auth_result = tool_auth_handler.prepare_auth_credentials()
|
||||
|
||||
if auth_result.state == 'pending':
|
||||
return {
|
||||
'pending': True,
|
||||
'message': 'Needs your authorization to access your data.',
|
||||
}
|
||||
|
||||
# Attach parameters from auth into main parameters list
|
||||
if auth_result.auth_credential:
|
||||
# Attach parameters from auth into main parameters list
|
||||
auth_credential_token = self._prepare_dynamic_euc(
|
||||
auth_result.auth_credential
|
||||
)
|
||||
if auth_credential_token:
|
||||
args['dynamic_auth_config'] = {
|
||||
'oauth2_auth_code_flow.access_token': auth_credential_token
|
||||
}
|
||||
else:
|
||||
args['dynamic_auth_config'] = {'oauth2_auth_code_flow.access_token': {}}
|
||||
|
||||
args['connection_name'] = self._connection_name
|
||||
args['service_name'] = self._connection_service_name
|
||||
args['host'] = self._connection_host
|
||||
|
@ -16,9 +16,12 @@ import json
|
||||
from unittest import mock
|
||||
from fastapi.openapi.models import Operation
|
||||
from google.adk.agents.readonly_context import ReadonlyContext
|
||||
from google.adk.auth import AuthCredentialTypes
|
||||
from google.adk.auth import OAuth2Auth
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
|
||||
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import dict_to_auth_scheme
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
|
||||
import pytest
|
||||
@ -162,6 +165,16 @@ def connection_details():
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_details_auth_override_enabled():
|
||||
return {
|
||||
"serviceName": "test-service",
|
||||
"host": "test.host",
|
||||
"name": "test-connection",
|
||||
"authOverrideEnabled": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization_with_integration_and_trigger(
|
||||
project,
|
||||
@ -474,3 +487,139 @@ def test_initialization_with_connection_details(
|
||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||
tool_name, tool_instructions
|
||||
)
|
||||
|
||||
|
||||
def test_init_with_connection_and_custom_auth(
|
||||
mock_integration_client,
|
||||
mock_connections_client,
|
||||
mock_openapi_action_spec_parser,
|
||||
connection_details_auth_override_enabled,
|
||||
):
|
||||
connection_name = "test-connection"
|
||||
actions_list = ["create", "delete"]
|
||||
tool_name = "My Actions Tool"
|
||||
tool_instructions = "Perform actions using this tool."
|
||||
mock_connections_client.return_value.get_connection_details.return_value = (
|
||||
connection_details_auth_override_enabled
|
||||
)
|
||||
|
||||
oauth2_data_google_cloud = {
|
||||
"type": "oauth2",
|
||||
"flows": {
|
||||
"authorizationCode": {
|
||||
"authorizationUrl": "https://test-url/o/oauth2/auth",
|
||||
"tokenUrl": "https://test-url/token",
|
||||
"scopes": {
|
||||
"https://test-url/auth/test-scope": "test scope",
|
||||
"https://www.test-url.com/auth/test-scope2": "test scope 2",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
oauth2_scheme = dict_to_auth_scheme(oauth2_data_google_cloud)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test-client-id",
|
||||
client_secret="test-client-secret",
|
||||
),
|
||||
)
|
||||
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project,
|
||||
location,
|
||||
connection=connection_name,
|
||||
actions=actions_list,
|
||||
tool_name=tool_name,
|
||||
tool_instructions=tool_instructions,
|
||||
auth_scheme=oauth2_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
mock_integration_client.assert_called_once_with(
|
||||
project, location, None, None, connection_name, None, actions_list, None
|
||||
)
|
||||
mock_connections_client.assert_called_once_with(
|
||||
project, location, connection_name, None
|
||||
)
|
||||
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||
tool_name, tool_instructions
|
||||
)
|
||||
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
|
||||
assert len(toolset.get_tools()) == 1
|
||||
assert toolset.get_tools()[0].name == "list_issues_operation"
|
||||
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
|
||||
assert toolset.get_tools()[0].action == "CustomAction"
|
||||
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
|
||||
assert toolset.get_tools()[0].auth_scheme == oauth2_scheme
|
||||
assert toolset.get_tools()[0].auth_credential == auth_credential
|
||||
|
||||
|
||||
def test_init_with_connection_with_auth_override_disabled_and_custom_auth(
|
||||
mock_integration_client,
|
||||
mock_connections_client,
|
||||
mock_openapi_action_spec_parser,
|
||||
connection_details,
|
||||
):
|
||||
connection_name = "test-connection"
|
||||
actions_list = ["create", "delete"]
|
||||
tool_name = "My Actions Tool"
|
||||
tool_instructions = "Perform actions using this tool."
|
||||
mock_connections_client.return_value.get_connection_details.return_value = (
|
||||
connection_details
|
||||
)
|
||||
|
||||
oauth2_data_google_cloud = {
|
||||
"type": "oauth2",
|
||||
"flows": {
|
||||
"authorizationCode": {
|
||||
"authorizationUrl": "https://test-url/o/oauth2/auth",
|
||||
"tokenUrl": "https://test-url/token",
|
||||
"scopes": {
|
||||
"https://test-url/auth/test-scope": "test scope",
|
||||
"https://www.test-url.com/auth/test-scope2": "test scope 2",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
oauth2_scheme = dict_to_auth_scheme(oauth2_data_google_cloud)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test-client-id",
|
||||
client_secret="test-client-secret",
|
||||
),
|
||||
)
|
||||
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project,
|
||||
location,
|
||||
connection=connection_name,
|
||||
actions=actions_list,
|
||||
tool_name=tool_name,
|
||||
tool_instructions=tool_instructions,
|
||||
auth_scheme=oauth2_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
mock_integration_client.assert_called_once_with(
|
||||
project, location, None, None, connection_name, None, actions_list, None
|
||||
)
|
||||
mock_connections_client.assert_called_once_with(
|
||||
project, location, connection_name, None
|
||||
)
|
||||
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||
tool_name, tool_instructions
|
||||
)
|
||||
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
|
||||
assert len(toolset.get_tools()) == 1
|
||||
assert toolset.get_tools()[0].name == "list_issues_operation"
|
||||
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
|
||||
assert toolset.get_tools()[0].action == "CustomAction"
|
||||
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
|
||||
assert not toolset.get_tools()[0].auth_scheme
|
||||
assert not toolset.get_tools()[0].auth_credential
|
||||
|
@ -14,11 +14,14 @@
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.auth import AuthCredential
|
||||
from google.adk.auth import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from google.genai.types import Schema
|
||||
from google.genai.types import Tool
|
||||
from google.genai.types import Type
|
||||
import pytest
|
||||
|
||||
@ -67,6 +70,30 @@ def integration_tool(mock_rest_api_tool):
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def integration_tool_with_auth(mock_rest_api_tool):
|
||||
"""Fixture for an IntegrationConnectorTool instance."""
|
||||
return IntegrationConnectorTool(
|
||||
name="test_integration_tool",
|
||||
description="Test integration tool description.",
|
||||
connection_name="test-conn",
|
||||
connection_host="test.example.com",
|
||||
connection_service_name="test-service",
|
||||
entity="TestEntity",
|
||||
operation="LIST",
|
||||
action="TestAction",
|
||||
rest_api_tool=mock_rest_api_tool,
|
||||
auth_scheme=None,
|
||||
auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer",
|
||||
credentials=HttpCredentials(token="mocked_token"),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_get_declaration(integration_tool):
|
||||
"""Tests the generation of the function declaration."""
|
||||
declaration = integration_tool._get_declaration()
|
||||
@ -123,3 +150,98 @@ async def test_run_async(integration_tool, mock_rest_api_tool):
|
||||
|
||||
# Assert the result is what the mocked call returned
|
||||
assert result == {"status": "success", "data": "mock_data"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_auth_async_none_token(
|
||||
integration_tool_with_auth, mock_rest_api_tool
|
||||
):
|
||||
"""Tests run_async when auth credential token is None."""
|
||||
input_args = {"user_id": "user456", "filter": "some_filter"}
|
||||
expected_call_args = {
|
||||
"user_id": "user456",
|
||||
"filter": "some_filter",
|
||||
"dynamic_auth_config": {"oauth2_auth_code_flow.access_token": {}},
|
||||
"connection_name": "test-conn",
|
||||
"service_name": "test-service",
|
||||
"host": "test.example.com",
|
||||
"entity": "TestEntity",
|
||||
"operation": "LIST",
|
||||
"action": "TestAction",
|
||||
}
|
||||
|
||||
with mock.patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
|
||||
) as mock_from_tool_context:
|
||||
mock_tool_auth_handler_instance = mock.MagicMock()
|
||||
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
|
||||
"done"
|
||||
)
|
||||
# Simulate an AuthCredential that would cause _prepare_dynamic_euc to return None
|
||||
mock_auth_credential_without_token = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer",
|
||||
credentials=HttpCredentials(token=None), # Token is None
|
||||
),
|
||||
)
|
||||
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = (
|
||||
mock_auth_credential_without_token
|
||||
)
|
||||
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
|
||||
|
||||
result = await integration_tool_with_auth.run_async(
|
||||
args=input_args, tool_context={}
|
||||
)
|
||||
|
||||
mock_rest_api_tool.call.assert_called_once_with(
|
||||
args=expected_call_args, tool_context={}
|
||||
)
|
||||
assert result == {"status": "success", "data": "mock_data"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_auth_async(
|
||||
integration_tool_with_auth, mock_rest_api_tool
|
||||
):
|
||||
"""Tests the async execution with auth delegates correctly to the RestApiTool."""
|
||||
input_args = {"user_id": "user123", "page_size": 10}
|
||||
expected_call_args = {
|
||||
"user_id": "user123",
|
||||
"page_size": 10,
|
||||
"dynamic_auth_config": {
|
||||
"oauth2_auth_code_flow.access_token": "mocked_token"
|
||||
},
|
||||
"connection_name": "test-conn",
|
||||
"service_name": "test-service",
|
||||
"host": "test.example.com",
|
||||
"entity": "TestEntity",
|
||||
"operation": "LIST",
|
||||
"action": "TestAction",
|
||||
}
|
||||
|
||||
with mock.patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
|
||||
) as mock_from_tool_context:
|
||||
mock_tool_auth_handler_instance = mock.MagicMock()
|
||||
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
|
||||
"done"
|
||||
)
|
||||
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
|
||||
"done"
|
||||
)
|
||||
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer",
|
||||
credentials=HttpCredentials(token="mocked_token"),
|
||||
),
|
||||
)
|
||||
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
|
||||
result = await integration_tool_with_auth.run_async(
|
||||
args=input_args, tool_context={}
|
||||
)
|
||||
mock_rest_api_tool.call.assert_called_once_with(
|
||||
args=expected_call_args, tool_context={}
|
||||
)
|
||||
assert result == {"status": "success", "data": "mock_data"}
|
||||
|
Loading…
Reference in New Issue
Block a user