ADK changes

PiperOrigin-RevId: 750763037
This commit is contained in:
Google Team Member 2025-04-23 16:28:22 -07:00 committed by Copybara-Service
parent a49d339251
commit ca993277de
6 changed files with 435 additions and 105 deletions

View File

@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
from .application_integration_toolset import ApplicationIntegrationToolset from .application_integration_toolset import ApplicationIntegrationToolset
from .integration_connector_tool import IntegrationConnectorTool
__all__ = [ __all__ = [
'ApplicationIntegrationToolset', 'ApplicationIntegrationToolset',
'IntegrationConnectorTool',
] ]

View File

@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict from typing import Dict, List, Optional
from typing import List
from typing import Optional
from fastapi.openapi.models import HTTPBearer from fastapi.openapi.models import HTTPBearer
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from ...auth.auth_credential import AuthCredential from ...auth.auth_credential import AuthCredential
from ...auth.auth_credential import AuthCredentialTypes from ...auth.auth_credential import AuthCredentialTypes
from ...auth.auth_credential import ServiceAccount from ...auth.auth_credential import ServiceAccount
from ...auth.auth_credential import ServiceAccountCredential from ...auth.auth_credential import ServiceAccountCredential
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
from ..openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from .clients.connections_client import ConnectionsClient
from .clients.integration_client import IntegrationClient
from .integration_connector_tool import IntegrationConnectorTool
# TODO(cheliu): Apply a common toolset interface # TODO(cheliu): Apply a common toolset interface
@ -168,6 +168,7 @@ class ApplicationIntegrationToolset:
actions, actions,
service_account_json, service_account_json,
) )
connection_details = {}
if integration and trigger: if integration and trigger:
spec = integration_client.get_openapi_spec_for_integration() spec = integration_client.get_openapi_spec_for_integration()
elif connection and (entity_operations or actions): elif connection and (entity_operations or actions):
@ -175,16 +176,6 @@ class ApplicationIntegrationToolset:
project, location, connection, service_account_json project, location, connection, service_account_json
) )
connection_details = connections_client.get_connection_details() connection_details = connections_client.get_connection_details()
tool_instructions += (
"ALWAYS use serviceName = "
+ connection_details["serviceName"]
+ ", host = "
+ connection_details["host"]
+ " and the connection name = "
+ f"projects/{project}/locations/{location}/connections/{connection} when"
" using this tool"
+ ". DONOT ask the user for these values as you already have those."
)
spec = integration_client.get_openapi_spec_for_connection( spec = integration_client.get_openapi_spec_for_connection(
tool_name, tool_name,
tool_instructions, tool_instructions,
@ -194,9 +185,9 @@ class ApplicationIntegrationToolset:
"Either (integration and trigger) or (connection and" "Either (integration and trigger) or (connection and"
" (entity_operations or actions)) should be provided." " (entity_operations or actions)) should be provided."
) )
self._parse_spec_to_tools(spec) self._parse_spec_to_tools(spec, connection_details)
def _parse_spec_to_tools(self, spec_dict): def _parse_spec_to_tools(self, spec_dict, connection_details):
"""Parses the spec dict to a list of RestApiTool.""" """Parses the spec dict to a list of RestApiTool."""
if self.service_account_json: if self.service_account_json:
sa_credential = ServiceAccountCredential.model_validate_json( sa_credential = ServiceAccountCredential.model_validate_json(
@ -218,6 +209,8 @@ class ApplicationIntegrationToolset:
), ),
) )
auth_scheme = HTTPBearer(bearerFormat="JWT") auth_scheme = HTTPBearer(bearerFormat="JWT")
if self.integration and self.trigger:
tools = OpenAPIToolset( tools = OpenAPIToolset(
spec_dict=spec_dict, spec_dict=spec_dict,
auth_credential=auth_credential, auth_credential=auth_credential,
@ -225,6 +218,35 @@ class ApplicationIntegrationToolset:
).get_tools() ).get_tools()
for tool in tools: for tool in tools:
self.generated_tools[tool.name] = tool self.generated_tools[tool.name] = tool
return
operations = OpenApiSpecParser().parse(spec_dict)
for open_api_operation in operations:
operation = getattr(open_api_operation.operation, "x-operation")
entity = None
action = None
if hasattr(open_api_operation.operation, "x-entity"):
entity = getattr(open_api_operation.operation, "x-entity")
elif hasattr(open_api_operation.operation, "x-action"):
action = getattr(open_api_operation.operation, "x-action")
rest_api_tool = RestApiTool.from_parsed_operation(open_api_operation)
if auth_scheme:
rest_api_tool.configure_auth_scheme(auth_scheme)
if auth_credential:
rest_api_tool.configure_auth_credential(auth_credential)
tool = IntegrationConnectorTool(
name=rest_api_tool.name,
description=rest_api_tool.description,
connection_name=connection_details["name"],
connection_host=connection_details["host"],
connection_service_name=connection_details["serviceName"],
entity=entity,
action=action,
operation=operation,
rest_api_tool=rest_api_tool,
)
self.generated_tools[tool.name] = tool
def get_tools(self) -> List[RestApiTool]: def get_tools(self) -> List[RestApiTool]:
return list(self.generated_tools.values()) return list(self.generated_tools.values())

View File

@ -68,12 +68,14 @@ class ConnectionsClient:
response = self._execute_api_call(url) response = self._execute_api_call(url)
connection_data = response.json() connection_data = response.json()
connection_name = connection_data.get("name", "")
service_name = connection_data.get("serviceDirectory", "") service_name = connection_data.get("serviceDirectory", "")
host = connection_data.get("host", "") host = connection_data.get("host", "")
if host: if host:
service_name = connection_data.get("tlsServiceDirectory", "") service_name = connection_data.get("tlsServiceDirectory", "")
auth_override_enabled = connection_data.get("authOverrideEnabled", False) auth_override_enabled = connection_data.get("authOverrideEnabled", False)
return { return {
"name": connection_name,
"serviceName": service_name, "serviceName": service_name,
"host": host, "host": host,
"authOverrideEnabled": auth_override_enabled, "authOverrideEnabled": auth_override_enabled,
@ -291,13 +293,9 @@ class ConnectionsClient:
tool_name: str = "", tool_name: str = "",
tool_instructions: str = "", tool_instructions: str = "",
) -> Dict[str, Any]: ) -> Dict[str, Any]:
description = ( description = f"Use this tool to execute {action}"
f"Use this tool with" f' action = "{action}" and'
) + f' operation = "{operation}" only. Dont ask these values from user.'
if operation == "EXECUTE_QUERY": if operation == "EXECUTE_QUERY":
description = ( description += (
(f"Use this tool with" f' action = "{action}" and')
+ f' operation = "{operation}" only. Dont ask these values from user.'
" Use pageSize = 50 and timeout = 120 until user specifies a" " Use pageSize = 50 and timeout = 120 until user specifies a"
" different value otherwise. If user provides a query in natural" " different value otherwise. If user provides a query in natural"
" language, convert it to SQL query and then execute it using the" " language, convert it to SQL query and then execute it using the"
@ -308,6 +306,8 @@ class ConnectionsClient:
"summary": f"{action_display_name}", "summary": f"{action_display_name}",
"description": f"{description} {tool_instructions}", "description": f"{description} {tool_instructions}",
"operationId": f"{tool_name}_{action_display_name}", "operationId": f"{tool_name}_{action_display_name}",
"x-action": f"{action}",
"x-operation": f"{operation}",
"requestBody": { "requestBody": {
"content": { "content": {
"application/json": { "application/json": {
@ -347,16 +347,12 @@ class ConnectionsClient:
"post": { "post": {
"summary": f"List {entity}", "summary": f"List {entity}",
"description": ( "description": (
f"Returns all entities of type {entity}. Use this tool with" f"""Returns the list of {entity} data. If the page token was available in the response, let users know there are more records available. Ask if the user wants to fetch the next page of results. When passing filter use the
+ f' entity = "{entity}" and' following format: `field_name1='value1' AND field_name2='value2'
+ ' operation = "LIST_ENTITIES" only. Dont ask these values' `. {tool_instructions}"""
" from"
+ ' user. Always use ""'
+ ' as filter clause and ""'
+ " as page token and 50 as page size until user specifies a"
" different value otherwise. Use single quotes for strings in"
f" filter clause. {tool_instructions}"
), ),
"x-operation": "LIST_ENTITIES",
"x-entity": f"{entity}",
"operationId": f"{tool_name}_list_{entity}", "operationId": f"{tool_name}_list_{entity}",
"requestBody": { "requestBody": {
"content": { "content": {
@ -401,14 +397,11 @@ class ConnectionsClient:
"post": { "post": {
"summary": f"Get {entity}", "summary": f"Get {entity}",
"description": ( "description": (
( f"Returns the details of the {entity}. {tool_instructions}"
f"Returns the details of the {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "GET_ENTITY" only. Dont ask these values from'
f" user. {tool_instructions}"
), ),
"operationId": f"{tool_name}_get_{entity}", "operationId": f"{tool_name}_get_{entity}",
"x-operation": "GET_ENTITY",
"x-entity": f"{entity}",
"requestBody": { "requestBody": {
"content": { "content": {
"application/json": { "application/json": {
@ -445,17 +438,10 @@ class ConnectionsClient:
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return { return {
"post": { "post": {
"summary": f"Create {entity}", "summary": f"Creates a new {entity}",
"description": ( "description": f"Creates a new {entity}. {tool_instructions}",
( "x-operation": "CREATE_ENTITY",
f"Creates a new entity of type {entity}. Use this tool with" "x-entity": f"{entity}",
f' entity = "{entity}" and'
)
+ ' operation = "CREATE_ENTITY" only. Dont ask these values'
" from"
+ " user. Follow the schema of the entity provided in the"
f" instructions to create {entity}. {tool_instructions}"
),
"operationId": f"{tool_name}_create_{entity}", "operationId": f"{tool_name}_create_{entity}",
"requestBody": { "requestBody": {
"content": { "content": {
@ -491,18 +477,10 @@ class ConnectionsClient:
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return { return {
"post": { "post": {
"summary": f"Update {entity}", "summary": f"Updates the {entity}",
"description": ( "description": f"Updates the {entity}. {tool_instructions}",
( "x-operation": "UPDATE_ENTITY",
f"Updates an entity of type {entity}. Use this tool with" "x-entity": f"{entity}",
f' entity = "{entity}" and'
)
+ ' operation = "UPDATE_ENTITY" only. Dont ask these values'
" from"
+ " user. Use entityId to uniquely identify the entity to"
" update. Follow the schema of the entity provided in the"
f" instructions to update {entity}. {tool_instructions}"
),
"operationId": f"{tool_name}_update_{entity}", "operationId": f"{tool_name}_update_{entity}",
"requestBody": { "requestBody": {
"content": { "content": {
@ -538,16 +516,10 @@ class ConnectionsClient:
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return { return {
"post": { "post": {
"summary": f"Delete {entity}", "summary": f"Delete the {entity}",
"description": ( "description": f"Deletes the {entity}. {tool_instructions}",
( "x-operation": "DELETE_ENTITY",
f"Deletes an entity of type {entity}. Use this tool with" "x-entity": f"{entity}",
f' entity = "{entity}" and'
)
+ ' operation = "DELETE_ENTITY" only. Dont ask these values'
" from"
f" user. {tool_instructions}"
),
"operationId": f"{tool_name}_delete_{entity}", "operationId": f"{tool_name}_delete_{entity}",
"requestBody": { "requestBody": {
"content": { "content": {

View File

@ -0,0 +1,159 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
from typing import Dict
from typing import Optional
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from google.genai.types import FunctionDeclaration
from typing_extensions import override
from .. import BaseTool
from ..tool_context import ToolContext
logger = logging.getLogger(__name__)
class IntegrationConnectorTool(BaseTool):
"""A tool that wraps a RestApiTool to interact with a specific Application Integration endpoint.
This tool adds Application Integration specific context like connection
details, entity, operation, and action to the underlying REST API call
handled by RestApiTool. It prepares the arguments and then delegates the
actual API call execution to the contained RestApiTool instance.
* Generates request params and body
* Attaches auth credentials to API call.
Example:
```
# Each API operation in the spec will be turned into its own tool
# Name of the tool is the operationId of that operation, in snake case
operations = OperationGenerator().parse(openapi_spec_dict)
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
```
"""
EXCLUDE_FIELDS = [
'connection_name',
'service_name',
'host',
'entity',
'operation',
'action',
]
OPTIONAL_FIELDS = [
'page_size',
'page_token',
'filter',
]
def __init__(
self,
name: str,
description: str,
connection_name: str,
connection_host: str,
connection_service_name: str,
entity: str,
operation: str,
action: str,
rest_api_tool: RestApiTool,
):
"""Initializes the ApplicationIntegrationTool.
Args:
name: The name of the tool, typically derived from the API operation.
Should be unique and adhere to Gemini function naming conventions
(e.g., less than 64 characters).
description: A description of what the tool does, usually based on the
API operation's summary or description.
connection_name: The name of the Integration Connector connection.
connection_host: The hostname or IP address for the connection.
connection_service_name: The specific service name within the host.
entity: The Integration Connector entity being targeted.
operation: The specific operation being performed on the entity.
action: The action associated with the operation (e.g., 'execute').
rest_api_tool: An initialized RestApiTool instance that handles the
underlying REST API communication based on an OpenAPI specification
operation. This tool will be called by ApplicationIntegrationTool with
added connection and context arguments. tool =
[RestApiTool.from_parsed_operation(o) for o in operations]
"""
# Gemini restrict the length of function name to be less than 64 characters
super().__init__(
name=name,
description=description,
)
self.connection_name = connection_name
self.connection_host = connection_host
self.connection_service_name = connection_service_name
self.entity = entity
self.operation = operation
self.action = action
self.rest_api_tool = rest_api_tool
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Returns the function declaration in the Gemini Schema format."""
schema_dict = self.rest_api_tool._operation_parser.get_json_schema()
for field in self.EXCLUDE_FIELDS:
if field in schema_dict['properties']:
del schema_dict['properties'][field]
for field in self.OPTIONAL_FIELDS + self.EXCLUDE_FIELDS:
if field in schema_dict['required']:
schema_dict['required'].remove(field)
parameters = to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
return function_decl
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
args['connection_name'] = self.connection_name
args['service_name'] = self.connection_service_name
args['host'] = self.connection_host
args['entity'] = self.entity
args['operation'] = self.operation
args['action'] = self.action
logger.info('Running tool: %s with args: %s', self.name, args)
return self.rest_api_tool.call(args=args, tool_context=tool_context)
def __str__(self):
return (
f'ApplicationIntegrationTool(name="{self.name}",'
f' description="{self.description}",'
f' connection_name="{self.connection_name}", entity="{self.entity}",'
f' operation="{self.operation}", action="{self.action}")'
)
def __repr__(self):
return (
f'ApplicationIntegrationTool(name="{self.name}",'
f' description="{self.description}",'
f' connection_name="{self.connection_name}",'
f' connection_host="{self.connection_host}",'
f' connection_service_name="{self.connection_service_name}",'
f' entity="{self.entity}", operation="{self.operation}",'
f' action="{self.action}", rest_api_tool={repr(self.rest_api_tool)})'
)

View File

@ -14,10 +14,12 @@
import json import json
from unittest import mock from unittest import mock
from fastapi.openapi.models import Operation
from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredential
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
from google.adk.tools.openapi_tool.openapi_spec_parser import rest_api_tool from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
import pytest import pytest
@ -50,6 +52,59 @@ def mock_openapi_toolset():
yield mock_toolset yield mock_toolset
def get_mocked_parsed_operation(operation_id, attributes):
mock_openapi_spec_parser_instance = mock.MagicMock()
mock_parsed_operation = mock.MagicMock(spec=ParsedOperation)
mock_parsed_operation.name = "list_issues"
mock_parsed_operation.description = "list_issues_description"
mock_parsed_operation.endpoint = OperationEndpoint(
base_url="http://localhost:8080",
path="/v1/issues",
method="GET",
)
mock_parsed_operation.auth_scheme = None
mock_parsed_operation.auth_credential = None
mock_parsed_operation.additional_context = {}
mock_parsed_operation.parameters = []
mock_operation = mock.MagicMock(spec=Operation)
mock_operation.operationId = operation_id
mock_operation.description = "list_issues_description"
mock_operation.parameters = []
mock_operation.requestBody = None
mock_operation.responses = {}
mock_operation.callbacks = {}
for key, value in attributes.items():
setattr(mock_operation, key, value)
mock_parsed_operation.operation = mock_operation
mock_openapi_spec_parser_instance.parse.return_value = [mock_parsed_operation]
return mock_openapi_spec_parser_instance
@pytest.fixture
def mock_openapi_entity_spec_parser():
with mock.patch(
"google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser"
) as mock_spec_parser:
mock_openapi_spec_parser_instance = get_mocked_parsed_operation(
"list_issues", {"x-entity": "Issues", "x-operation": "LIST_ENTITIES"}
)
mock_spec_parser.return_value = mock_openapi_spec_parser_instance
yield mock_spec_parser
@pytest.fixture
def mock_openapi_action_spec_parser():
with mock.patch(
"google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser"
) as mock_spec_parser:
mock_openapi_action_spec_parser_instance = get_mocked_parsed_operation(
"list_issues_operation",
{"x-action": "CustomAction", "x-operation": "EXECUTE_ACTION"},
)
mock_spec_parser.return_value = mock_openapi_action_spec_parser_instance
yield mock_spec_parser
@pytest.fixture @pytest.fixture
def project(): def project():
return "test-project" return "test-project"
@ -72,7 +127,11 @@ def connection_spec():
@pytest.fixture @pytest.fixture
def connection_details(): def connection_details():
return {"serviceName": "test-service", "host": "test.host"} return {
"serviceName": "test-service",
"host": "test.host",
"name": "test-connection",
}
def test_initialization_with_integration_and_trigger( def test_initialization_with_integration_and_trigger(
@ -102,7 +161,7 @@ def test_initialization_with_connection_and_entity_operations(
location, location,
mock_integration_client, mock_integration_client,
mock_connections_client, mock_connections_client,
mock_openapi_toolset, mock_openapi_entity_spec_parser,
connection_details, connection_details,
): ):
connection_name = "test-connection" connection_name = "test-connection"
@ -133,19 +192,17 @@ def test_initialization_with_connection_and_entity_operations(
mock_connections_client.assert_called_once_with( mock_connections_client.assert_called_once_with(
project, location, connection_name, None project, location, connection_name, None
) )
mock_openapi_entity_spec_parser.return_value.parse.assert_called_once()
mock_connections_client.return_value.get_connection_details.assert_called_once() mock_connections_client.return_value.get_connection_details.assert_called_once()
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_name,
tool_instructions tool_instructions,
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
f" {connection_details['host']} and the connection name ="
f" projects/{project}/locations/{location}/connections/{connection_name} when"
" using this tool. DONOT ask the user for these values as you already"
" have those.",
) )
mock_openapi_toolset.assert_called_once()
assert len(toolset.get_tools()) == 1 assert len(toolset.get_tools()) == 1
assert toolset.get_tools()[0].name == "Test Tool" assert toolset.get_tools()[0].name == "list_issues"
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
assert toolset.get_tools()[0].entity == "Issues"
assert toolset.get_tools()[0].operation == "LIST_ENTITIES"
def test_initialization_with_connection_and_actions( def test_initialization_with_connection_and_actions(
@ -153,7 +210,7 @@ def test_initialization_with_connection_and_actions(
location, location,
mock_integration_client, mock_integration_client,
mock_connections_client, mock_connections_client,
mock_openapi_toolset, mock_openapi_action_spec_parser,
connection_details, connection_details,
): ):
connection_name = "test-connection" connection_name = "test-connection"
@ -181,15 +238,13 @@ def test_initialization_with_connection_and_actions(
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_name,
tool_instructions tool_instructions
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
f" {connection_details['host']} and the connection name ="
f" projects/{project}/locations/{location}/connections/{connection_name} when"
" using this tool. DONOT ask the user for these values as you already"
" have those.",
) )
mock_openapi_toolset.assert_called_once() mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
assert len(toolset.get_tools()) == 1 assert len(toolset.get_tools()) == 1
assert toolset.get_tools()[0].name == "Test Tool" assert toolset.get_tools()[0].name == "list_issues_operation"
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
assert toolset.get_tools()[0].action == "CustomAction"
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
def test_initialization_without_required_params(project, location): def test_initialization_without_required_params(project, location):
@ -337,9 +392,4 @@ def test_initialization_with_connection_details(
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_name,
tool_instructions tool_instructions
+ "ALWAYS use serviceName = custom-service, host = custom.host and the"
" connection name ="
" projects/test-project/locations/us-central1/connections/test-connection"
" when using this tool. DONOT ask the user for these values as you"
" already have those.",
) )

View File

@ -0,0 +1,125 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
from google.genai.types import Tool
from google.genai.types import Type
import pytest
@pytest.fixture
def mock_rest_api_tool():
"""Fixture for a mocked RestApiTool."""
mock_tool = mock.MagicMock(spec=RestApiTool)
mock_tool.name = "mock_rest_tool"
mock_tool.description = "Mock REST tool description."
# Mock the internal parser needed for _get_declaration
mock_parser = mock.MagicMock()
mock_parser.get_json_schema.return_value = {
"type": "object",
"properties": {
"user_id": {"type": "string", "description": "User ID"},
"connection_name": {"type": "string"},
"host": {"type": "string"},
"service_name": {"type": "string"},
"entity": {"type": "string"},
"operation": {"type": "string"},
"action": {"type": "string"},
"page_size": {"type": "integer"},
"filter": {"type": "string"},
},
"required": ["user_id", "page_size", "filter", "connection_name"],
}
mock_tool._operation_parser = mock_parser
mock_tool.call.return_value = {"status": "success", "data": "mock_data"}
return mock_tool
@pytest.fixture
def integration_tool(mock_rest_api_tool):
"""Fixture for an IntegrationConnectorTool instance."""
return IntegrationConnectorTool(
name="test_integration_tool",
description="Test integration tool description.",
connection_name="test-conn",
connection_host="test.example.com",
connection_service_name="test-service",
entity="TestEntity",
operation="LIST",
action="TestAction",
rest_api_tool=mock_rest_api_tool,
)
def test_get_declaration(integration_tool):
"""Tests the generation of the function declaration."""
declaration = integration_tool._get_declaration()
assert isinstance(declaration, FunctionDeclaration)
assert declaration.name == "test_integration_tool"
assert declaration.description == "Test integration tool description."
# Check parameters schema
params = declaration.parameters
assert isinstance(params, Schema)
print(f"params: {params}")
assert params.type == Type.OBJECT
# Check properties (excluded fields should not be present)
assert "user_id" in params.properties
assert "connection_name" not in params.properties
assert "host" not in params.properties
assert "service_name" not in params.properties
assert "entity" not in params.properties
assert "operation" not in params.properties
assert "action" not in params.properties
assert "page_size" in params.properties
assert "filter" in params.properties
# Check required fields (optional and excluded fields should not be required)
assert "user_id" in params.required
assert "page_size" not in params.required
assert "filter" not in params.required
assert "connection_name" not in params.required
@pytest.mark.asyncio
async def test_run_async(integration_tool, mock_rest_api_tool):
"""Tests the async execution delegates correctly to the RestApiTool."""
input_args = {"user_id": "user123", "page_size": 10}
expected_call_args = {
"user_id": "user123",
"page_size": 10,
"connection_name": "test-conn",
"host": "test.example.com",
"service_name": "test-service",
"entity": "TestEntity",
"operation": "LIST",
"action": "TestAction",
}
result = await integration_tool.run_async(args=input_args, tool_context=None)
# Assert the underlying rest_api_tool.call was called correctly
mock_rest_api_tool.call.assert_called_once_with(
args=expected_call_args, tool_context=None
)
# Assert the result is what the mocked call returned
assert result == {"status": "success", "data": "mock_data"}