mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 15:14:50 -06:00
Added support for dynamic auth in integration connector tool
PiperOrigin-RevId: 759676602
This commit is contained in:
parent
2f006264ce
commit
6e0ea01fcb
@ -12,9 +12,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List
|
import logging
|
||||||
from typing import Optional
|
from typing import List, Optional, Union
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from fastapi.openapi.models import HTTPBearer
|
from fastapi.openapi.models import HTTPBearer
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
@ -24,6 +23,7 @@ from ...auth.auth_credential import AuthCredential
|
|||||||
from ...auth.auth_credential import AuthCredentialTypes
|
from ...auth.auth_credential import AuthCredentialTypes
|
||||||
from ...auth.auth_credential import ServiceAccount
|
from ...auth.auth_credential import ServiceAccount
|
||||||
from ...auth.auth_credential import ServiceAccountCredential
|
from ...auth.auth_credential import ServiceAccountCredential
|
||||||
|
from ...auth.auth_schemes import AuthScheme
|
||||||
from ..base_toolset import BaseToolset
|
from ..base_toolset import BaseToolset
|
||||||
from ..base_toolset import ToolPredicate
|
from ..base_toolset import ToolPredicate
|
||||||
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
||||||
@ -35,6 +35,9 @@ from .clients.integration_client import IntegrationClient
|
|||||||
from .integration_connector_tool import IntegrationConnectorTool
|
from .integration_connector_tool import IntegrationConnectorTool
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# TODO(cheliu): Apply a common toolset interface
|
# TODO(cheliu): Apply a common toolset interface
|
||||||
class ApplicationIntegrationToolset(BaseToolset):
|
class ApplicationIntegrationToolset(BaseToolset):
|
||||||
"""ApplicationIntegrationToolset generates tools from a given Application
|
"""ApplicationIntegrationToolset generates tools from a given Application
|
||||||
@ -93,6 +96,8 @@ class ApplicationIntegrationToolset(BaseToolset):
|
|||||||
# tool/python function description.
|
# tool/python function description.
|
||||||
tool_instructions: Optional[str] = "",
|
tool_instructions: Optional[str] = "",
|
||||||
service_account_json: Optional[str] = None,
|
service_account_json: Optional[str] = None,
|
||||||
|
auth_scheme: Optional[AuthScheme] = None,
|
||||||
|
auth_credential: Optional[AuthCredential] = None,
|
||||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||||
):
|
):
|
||||||
"""Args:
|
"""Args:
|
||||||
@ -132,6 +137,8 @@ class ApplicationIntegrationToolset(BaseToolset):
|
|||||||
self._tool_name_prefix = tool_name_prefix
|
self._tool_name_prefix = tool_name_prefix
|
||||||
self._tool_instructions = tool_instructions
|
self._tool_instructions = tool_instructions
|
||||||
self._service_account_json = service_account_json
|
self._service_account_json = service_account_json
|
||||||
|
self._auth_scheme = auth_scheme
|
||||||
|
self._auth_credential = auth_credential
|
||||||
self.tool_filter = tool_filter
|
self.tool_filter = tool_filter
|
||||||
|
|
||||||
integration_client = IntegrationClient(
|
integration_client = IntegrationClient(
|
||||||
@ -212,6 +219,27 @@ class ApplicationIntegrationToolset(BaseToolset):
|
|||||||
rest_api_tool.configure_auth_scheme(auth_scheme)
|
rest_api_tool.configure_auth_scheme(auth_scheme)
|
||||||
if auth_credential:
|
if auth_credential:
|
||||||
rest_api_tool.configure_auth_credential(auth_credential)
|
rest_api_tool.configure_auth_credential(auth_credential)
|
||||||
|
|
||||||
|
auth_override_enabled = connection_details.get(
|
||||||
|
"authOverrideEnabled", False
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self._auth_scheme
|
||||||
|
and self._auth_credential
|
||||||
|
and not auth_override_enabled
|
||||||
|
):
|
||||||
|
# Case: Auth provided, but override is OFF. Don't use provided auth.
|
||||||
|
logger.warning(
|
||||||
|
"Authentication schema and credentials are not used because"
|
||||||
|
" authOverrideEnabled is not enabled in the connection."
|
||||||
|
)
|
||||||
|
connector_auth_scheme = None
|
||||||
|
connector_auth_credential = None
|
||||||
|
else:
|
||||||
|
connector_auth_scheme = self._auth_scheme
|
||||||
|
connector_auth_credential = self._auth_credential
|
||||||
|
|
||||||
self._tools.append(
|
self._tools.append(
|
||||||
IntegrationConnectorTool(
|
IntegrationConnectorTool(
|
||||||
name=rest_api_tool.name,
|
name=rest_api_tool.name,
|
||||||
@ -223,6 +251,8 @@ class ApplicationIntegrationToolset(BaseToolset):
|
|||||||
action=action,
|
action=action,
|
||||||
operation=operation,
|
operation=operation,
|
||||||
rest_api_tool=rest_api_tool,
|
rest_api_tool=rest_api_tool,
|
||||||
|
auth_scheme=connector_auth_scheme,
|
||||||
|
auth_credential=connector_auth_credential,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -554,6 +554,9 @@ class ConnectionsClient:
|
|||||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||||
"host": {"$ref": "#/components/schemas/host"},
|
"host": {"$ref": "#/components/schemas/host"},
|
||||||
"entity": {"$ref": "#/components/schemas/entity"},
|
"entity": {"$ref": "#/components/schemas/entity"},
|
||||||
|
"dynamicAuthConfig": {
|
||||||
|
"$ref": "#/components/schemas/dynamicAuthConfig"
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -580,6 +583,9 @@ class ConnectionsClient:
|
|||||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||||
"host": {"$ref": "#/components/schemas/host"},
|
"host": {"$ref": "#/components/schemas/host"},
|
||||||
"entity": {"$ref": "#/components/schemas/entity"},
|
"entity": {"$ref": "#/components/schemas/entity"},
|
||||||
|
"dynamicAuthConfig": {
|
||||||
|
"$ref": "#/components/schemas/dynamicAuthConfig"
|
||||||
|
},
|
||||||
"filterClause": {"$ref": "#/components/schemas/filterClause"},
|
"filterClause": {"$ref": "#/components/schemas/filterClause"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -603,6 +609,9 @@ class ConnectionsClient:
|
|||||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||||
"host": {"$ref": "#/components/schemas/host"},
|
"host": {"$ref": "#/components/schemas/host"},
|
||||||
"entity": {"$ref": "#/components/schemas/entity"},
|
"entity": {"$ref": "#/components/schemas/entity"},
|
||||||
|
"dynamicAuthConfig": {
|
||||||
|
"$ref": "#/components/schemas/dynamicAuthConfig"
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -625,6 +634,9 @@ class ConnectionsClient:
|
|||||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||||
"host": {"$ref": "#/components/schemas/host"},
|
"host": {"$ref": "#/components/schemas/host"},
|
||||||
"entity": {"$ref": "#/components/schemas/entity"},
|
"entity": {"$ref": "#/components/schemas/entity"},
|
||||||
|
"dynamicAuthConfig": {
|
||||||
|
"$ref": "#/components/schemas/dynamicAuthConfig"
|
||||||
|
},
|
||||||
"filterClause": {"$ref": "#/components/schemas/filterClause"},
|
"filterClause": {"$ref": "#/components/schemas/filterClause"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -649,6 +661,9 @@ class ConnectionsClient:
|
|||||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||||
"host": {"$ref": "#/components/schemas/host"},
|
"host": {"$ref": "#/components/schemas/host"},
|
||||||
"entity": {"$ref": "#/components/schemas/entity"},
|
"entity": {"$ref": "#/components/schemas/entity"},
|
||||||
|
"dynamicAuthConfig": {
|
||||||
|
"$ref": "#/components/schemas/dynamicAuthConfig"
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -673,6 +688,9 @@ class ConnectionsClient:
|
|||||||
"connectorInputPayload": {
|
"connectorInputPayload": {
|
||||||
"$ref": f"#/components/schemas/connectorInputPayload_{action}"
|
"$ref": f"#/components/schemas/connectorInputPayload_{action}"
|
||||||
},
|
},
|
||||||
|
"dynamicAuthConfig": {
|
||||||
|
"$ref": "#/components/schemas/dynamicAuthConfig"
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,20 +12,21 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, Dict, Optional, Union
|
||||||
from typing import Dict
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
|
||||||
from google.genai.types import FunctionDeclaration
|
from google.genai.types import FunctionDeclaration
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from ...auth.auth_credential import AuthCredential
|
||||||
|
from ...auth.auth_schemes import AuthScheme
|
||||||
from .. import BaseTool
|
from .. import BaseTool
|
||||||
|
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||||
|
from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
||||||
|
from ..openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler
|
||||||
from ..tool_context import ToolContext
|
from ..tool_context import ToolContext
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -56,6 +57,7 @@ class IntegrationConnectorTool(BaseTool):
|
|||||||
'entity',
|
'entity',
|
||||||
'operation',
|
'operation',
|
||||||
'action',
|
'action',
|
||||||
|
'dynamic_auth_config',
|
||||||
]
|
]
|
||||||
|
|
||||||
OPTIONAL_FIELDS = [
|
OPTIONAL_FIELDS = [
|
||||||
@ -75,6 +77,8 @@ class IntegrationConnectorTool(BaseTool):
|
|||||||
operation: str,
|
operation: str,
|
||||||
action: str,
|
action: str,
|
||||||
rest_api_tool: RestApiTool,
|
rest_api_tool: RestApiTool,
|
||||||
|
auth_scheme: Optional[Union[AuthScheme, str]] = None,
|
||||||
|
auth_credential: Optional[Union[AuthCredential, str]] = None,
|
||||||
):
|
):
|
||||||
"""Initializes the ApplicationIntegrationTool.
|
"""Initializes the ApplicationIntegrationTool.
|
||||||
|
|
||||||
@ -108,6 +112,8 @@ class IntegrationConnectorTool(BaseTool):
|
|||||||
self._operation = operation
|
self._operation = operation
|
||||||
self._action = action
|
self._action = action
|
||||||
self._rest_api_tool = rest_api_tool
|
self._rest_api_tool = rest_api_tool
|
||||||
|
self._auth_scheme = auth_scheme
|
||||||
|
self._auth_credential = auth_credential
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _get_declaration(self) -> FunctionDeclaration:
|
def _get_declaration(self) -> FunctionDeclaration:
|
||||||
@ -126,10 +132,45 @@ class IntegrationConnectorTool(BaseTool):
|
|||||||
)
|
)
|
||||||
return function_decl
|
return function_decl
|
||||||
|
|
||||||
|
def _prepare_dynamic_euc(self, auth_credential: AuthCredential) -> str:
|
||||||
|
if (
|
||||||
|
auth_credential
|
||||||
|
and auth_credential.http
|
||||||
|
and auth_credential.http.credentials
|
||||||
|
and auth_credential.http.credentials.token
|
||||||
|
):
|
||||||
|
return auth_credential.http.credentials.token
|
||||||
|
return None
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def run_async(
|
async def run_async(
|
||||||
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
|
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
|
|
||||||
|
tool_auth_handler = ToolAuthHandler.from_tool_context(
|
||||||
|
tool_context, self._auth_scheme, self._auth_credential
|
||||||
|
)
|
||||||
|
auth_result = tool_auth_handler.prepare_auth_credentials()
|
||||||
|
|
||||||
|
if auth_result.state == 'pending':
|
||||||
|
return {
|
||||||
|
'pending': True,
|
||||||
|
'message': 'Needs your authorization to access your data.',
|
||||||
|
}
|
||||||
|
|
||||||
|
# Attach parameters from auth into main parameters list
|
||||||
|
if auth_result.auth_credential:
|
||||||
|
# Attach parameters from auth into main parameters list
|
||||||
|
auth_credential_token = self._prepare_dynamic_euc(
|
||||||
|
auth_result.auth_credential
|
||||||
|
)
|
||||||
|
if auth_credential_token:
|
||||||
|
args['dynamic_auth_config'] = {
|
||||||
|
'oauth2_auth_code_flow.access_token': auth_credential_token
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
args['dynamic_auth_config'] = {'oauth2_auth_code_flow.access_token': {}}
|
||||||
|
|
||||||
args['connection_name'] = self._connection_name
|
args['connection_name'] = self._connection_name
|
||||||
args['service_name'] = self._connection_service_name
|
args['service_name'] = self._connection_service_name
|
||||||
args['host'] = self._connection_host
|
args['host'] = self._connection_host
|
||||||
|
@ -16,9 +16,12 @@ import json
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
from fastapi.openapi.models import Operation
|
from fastapi.openapi.models import Operation
|
||||||
from google.adk.agents.readonly_context import ReadonlyContext
|
from google.adk.agents.readonly_context import ReadonlyContext
|
||||||
|
from google.adk.auth import AuthCredentialTypes
|
||||||
|
from google.adk.auth import OAuth2Auth
|
||||||
from google.adk.auth.auth_credential import AuthCredential
|
from google.adk.auth.auth_credential import AuthCredential
|
||||||
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
|
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
|
||||||
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
|
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
|
||||||
|
from google.adk.tools.openapi_tool.auth.auth_helpers import dict_to_auth_scheme
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool
|
from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
|
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
|
||||||
import pytest
|
import pytest
|
||||||
@ -162,6 +165,16 @@ def connection_details():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def connection_details_auth_override_enabled():
|
||||||
|
return {
|
||||||
|
"serviceName": "test-service",
|
||||||
|
"host": "test.host",
|
||||||
|
"name": "test-connection",
|
||||||
|
"authOverrideEnabled": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_initialization_with_integration_and_trigger(
|
async def test_initialization_with_integration_and_trigger(
|
||||||
project,
|
project,
|
||||||
@ -474,3 +487,139 @@ def test_initialization_with_connection_details(
|
|||||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||||
tool_name, tool_instructions
|
tool_name, tool_instructions
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_with_connection_and_custom_auth(
|
||||||
|
mock_integration_client,
|
||||||
|
mock_connections_client,
|
||||||
|
mock_openapi_action_spec_parser,
|
||||||
|
connection_details_auth_override_enabled,
|
||||||
|
):
|
||||||
|
connection_name = "test-connection"
|
||||||
|
actions_list = ["create", "delete"]
|
||||||
|
tool_name = "My Actions Tool"
|
||||||
|
tool_instructions = "Perform actions using this tool."
|
||||||
|
mock_connections_client.return_value.get_connection_details.return_value = (
|
||||||
|
connection_details_auth_override_enabled
|
||||||
|
)
|
||||||
|
|
||||||
|
oauth2_data_google_cloud = {
|
||||||
|
"type": "oauth2",
|
||||||
|
"flows": {
|
||||||
|
"authorizationCode": {
|
||||||
|
"authorizationUrl": "https://test-url/o/oauth2/auth",
|
||||||
|
"tokenUrl": "https://test-url/token",
|
||||||
|
"scopes": {
|
||||||
|
"https://test-url/auth/test-scope": "test scope",
|
||||||
|
"https://www.test-url.com/auth/test-scope2": "test scope 2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth2_scheme = dict_to_auth_scheme(oauth2_data_google_cloud)
|
||||||
|
|
||||||
|
auth_credential = AuthCredential(
|
||||||
|
auth_type=AuthCredentialTypes.OAUTH2,
|
||||||
|
oauth2=OAuth2Auth(
|
||||||
|
client_id="test-client-id",
|
||||||
|
client_secret="test-client-secret",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
toolset = ApplicationIntegrationToolset(
|
||||||
|
project,
|
||||||
|
location,
|
||||||
|
connection=connection_name,
|
||||||
|
actions=actions_list,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_instructions=tool_instructions,
|
||||||
|
auth_scheme=oauth2_scheme,
|
||||||
|
auth_credential=auth_credential,
|
||||||
|
)
|
||||||
|
mock_integration_client.assert_called_once_with(
|
||||||
|
project, location, None, None, connection_name, None, actions_list, None
|
||||||
|
)
|
||||||
|
mock_connections_client.assert_called_once_with(
|
||||||
|
project, location, connection_name, None
|
||||||
|
)
|
||||||
|
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
||||||
|
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||||
|
tool_name, tool_instructions
|
||||||
|
)
|
||||||
|
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
|
||||||
|
assert len(toolset.get_tools()) == 1
|
||||||
|
assert toolset.get_tools()[0].name == "list_issues_operation"
|
||||||
|
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
|
||||||
|
assert toolset.get_tools()[0].action == "CustomAction"
|
||||||
|
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
|
||||||
|
assert toolset.get_tools()[0].auth_scheme == oauth2_scheme
|
||||||
|
assert toolset.get_tools()[0].auth_credential == auth_credential
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_with_connection_with_auth_override_disabled_and_custom_auth(
|
||||||
|
mock_integration_client,
|
||||||
|
mock_connections_client,
|
||||||
|
mock_openapi_action_spec_parser,
|
||||||
|
connection_details,
|
||||||
|
):
|
||||||
|
connection_name = "test-connection"
|
||||||
|
actions_list = ["create", "delete"]
|
||||||
|
tool_name = "My Actions Tool"
|
||||||
|
tool_instructions = "Perform actions using this tool."
|
||||||
|
mock_connections_client.return_value.get_connection_details.return_value = (
|
||||||
|
connection_details
|
||||||
|
)
|
||||||
|
|
||||||
|
oauth2_data_google_cloud = {
|
||||||
|
"type": "oauth2",
|
||||||
|
"flows": {
|
||||||
|
"authorizationCode": {
|
||||||
|
"authorizationUrl": "https://test-url/o/oauth2/auth",
|
||||||
|
"tokenUrl": "https://test-url/token",
|
||||||
|
"scopes": {
|
||||||
|
"https://test-url/auth/test-scope": "test scope",
|
||||||
|
"https://www.test-url.com/auth/test-scope2": "test scope 2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth2_scheme = dict_to_auth_scheme(oauth2_data_google_cloud)
|
||||||
|
|
||||||
|
auth_credential = AuthCredential(
|
||||||
|
auth_type=AuthCredentialTypes.OAUTH2,
|
||||||
|
oauth2=OAuth2Auth(
|
||||||
|
client_id="test-client-id",
|
||||||
|
client_secret="test-client-secret",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
toolset = ApplicationIntegrationToolset(
|
||||||
|
project,
|
||||||
|
location,
|
||||||
|
connection=connection_name,
|
||||||
|
actions=actions_list,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_instructions=tool_instructions,
|
||||||
|
auth_scheme=oauth2_scheme,
|
||||||
|
auth_credential=auth_credential,
|
||||||
|
)
|
||||||
|
mock_integration_client.assert_called_once_with(
|
||||||
|
project, location, None, None, connection_name, None, actions_list, None
|
||||||
|
)
|
||||||
|
mock_connections_client.assert_called_once_with(
|
||||||
|
project, location, connection_name, None
|
||||||
|
)
|
||||||
|
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
||||||
|
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||||
|
tool_name, tool_instructions
|
||||||
|
)
|
||||||
|
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
|
||||||
|
assert len(toolset.get_tools()) == 1
|
||||||
|
assert toolset.get_tools()[0].name == "list_issues_operation"
|
||||||
|
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
|
||||||
|
assert toolset.get_tools()[0].action == "CustomAction"
|
||||||
|
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
|
||||||
|
assert not toolset.get_tools()[0].auth_scheme
|
||||||
|
assert not toolset.get_tools()[0].auth_credential
|
||||||
|
@ -14,11 +14,14 @@
|
|||||||
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
from google.adk.auth import AuthCredential
|
||||||
|
from google.adk.auth import AuthCredentialTypes
|
||||||
|
from google.adk.auth.auth_credential import HttpAuth
|
||||||
|
from google.adk.auth.auth_credential import HttpCredentials
|
||||||
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
|
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||||
from google.genai.types import FunctionDeclaration
|
from google.genai.types import FunctionDeclaration
|
||||||
from google.genai.types import Schema
|
from google.genai.types import Schema
|
||||||
from google.genai.types import Tool
|
|
||||||
from google.genai.types import Type
|
from google.genai.types import Type
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -67,6 +70,30 @@ def integration_tool(mock_rest_api_tool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def integration_tool_with_auth(mock_rest_api_tool):
|
||||||
|
"""Fixture for an IntegrationConnectorTool instance."""
|
||||||
|
return IntegrationConnectorTool(
|
||||||
|
name="test_integration_tool",
|
||||||
|
description="Test integration tool description.",
|
||||||
|
connection_name="test-conn",
|
||||||
|
connection_host="test.example.com",
|
||||||
|
connection_service_name="test-service",
|
||||||
|
entity="TestEntity",
|
||||||
|
operation="LIST",
|
||||||
|
action="TestAction",
|
||||||
|
rest_api_tool=mock_rest_api_tool,
|
||||||
|
auth_scheme=None,
|
||||||
|
auth_credential=AuthCredential(
|
||||||
|
auth_type=AuthCredentialTypes.HTTP,
|
||||||
|
http=HttpAuth(
|
||||||
|
scheme="bearer",
|
||||||
|
credentials=HttpCredentials(token="mocked_token"),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_declaration(integration_tool):
|
def test_get_declaration(integration_tool):
|
||||||
"""Tests the generation of the function declaration."""
|
"""Tests the generation of the function declaration."""
|
||||||
declaration = integration_tool._get_declaration()
|
declaration = integration_tool._get_declaration()
|
||||||
@ -123,3 +150,98 @@ async def test_run_async(integration_tool, mock_rest_api_tool):
|
|||||||
|
|
||||||
# Assert the result is what the mocked call returned
|
# Assert the result is what the mocked call returned
|
||||||
assert result == {"status": "success", "data": "mock_data"}
|
assert result == {"status": "success", "data": "mock_data"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_with_auth_async_none_token(
|
||||||
|
integration_tool_with_auth, mock_rest_api_tool
|
||||||
|
):
|
||||||
|
"""Tests run_async when auth credential token is None."""
|
||||||
|
input_args = {"user_id": "user456", "filter": "some_filter"}
|
||||||
|
expected_call_args = {
|
||||||
|
"user_id": "user456",
|
||||||
|
"filter": "some_filter",
|
||||||
|
"dynamic_auth_config": {"oauth2_auth_code_flow.access_token": {}},
|
||||||
|
"connection_name": "test-conn",
|
||||||
|
"service_name": "test-service",
|
||||||
|
"host": "test.example.com",
|
||||||
|
"entity": "TestEntity",
|
||||||
|
"operation": "LIST",
|
||||||
|
"action": "TestAction",
|
||||||
|
}
|
||||||
|
|
||||||
|
with mock.patch(
|
||||||
|
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
|
||||||
|
) as mock_from_tool_context:
|
||||||
|
mock_tool_auth_handler_instance = mock.MagicMock()
|
||||||
|
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
|
||||||
|
"done"
|
||||||
|
)
|
||||||
|
# Simulate an AuthCredential that would cause _prepare_dynamic_euc to return None
|
||||||
|
mock_auth_credential_without_token = AuthCredential(
|
||||||
|
auth_type=AuthCredentialTypes.HTTP,
|
||||||
|
http=HttpAuth(
|
||||||
|
scheme="bearer",
|
||||||
|
credentials=HttpCredentials(token=None), # Token is None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = (
|
||||||
|
mock_auth_credential_without_token
|
||||||
|
)
|
||||||
|
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
|
||||||
|
|
||||||
|
result = await integration_tool_with_auth.run_async(
|
||||||
|
args=input_args, tool_context={}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_rest_api_tool.call.assert_called_once_with(
|
||||||
|
args=expected_call_args, tool_context={}
|
||||||
|
)
|
||||||
|
assert result == {"status": "success", "data": "mock_data"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_with_auth_async(
|
||||||
|
integration_tool_with_auth, mock_rest_api_tool
|
||||||
|
):
|
||||||
|
"""Tests the async execution with auth delegates correctly to the RestApiTool."""
|
||||||
|
input_args = {"user_id": "user123", "page_size": 10}
|
||||||
|
expected_call_args = {
|
||||||
|
"user_id": "user123",
|
||||||
|
"page_size": 10,
|
||||||
|
"dynamic_auth_config": {
|
||||||
|
"oauth2_auth_code_flow.access_token": "mocked_token"
|
||||||
|
},
|
||||||
|
"connection_name": "test-conn",
|
||||||
|
"service_name": "test-service",
|
||||||
|
"host": "test.example.com",
|
||||||
|
"entity": "TestEntity",
|
||||||
|
"operation": "LIST",
|
||||||
|
"action": "TestAction",
|
||||||
|
}
|
||||||
|
|
||||||
|
with mock.patch(
|
||||||
|
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
|
||||||
|
) as mock_from_tool_context:
|
||||||
|
mock_tool_auth_handler_instance = mock.MagicMock()
|
||||||
|
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
|
||||||
|
"done"
|
||||||
|
)
|
||||||
|
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
|
||||||
|
"done"
|
||||||
|
)
|
||||||
|
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.auth_credential = AuthCredential(
|
||||||
|
auth_type=AuthCredentialTypes.HTTP,
|
||||||
|
http=HttpAuth(
|
||||||
|
scheme="bearer",
|
||||||
|
credentials=HttpCredentials(token="mocked_token"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
|
||||||
|
result = await integration_tool_with_auth.run_async(
|
||||||
|
args=input_args, tool_context={}
|
||||||
|
)
|
||||||
|
mock_rest_api_tool.call.assert_called_once_with(
|
||||||
|
args=expected_call_args, tool_context={}
|
||||||
|
)
|
||||||
|
assert result == {"status": "success", "data": "mock_data"}
|
||||||
|
Loading…
Reference in New Issue
Block a user