mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-21 12:52:18 -06:00
ADK changes
PiperOrigin-RevId: 750763037
This commit is contained in:
committed by
Copybara-Service
parent
a49d339251
commit
ca993277de
@@ -13,7 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .application_integration_toolset import ApplicationIntegrationToolset
|
||||
from .integration_connector_tool import IntegrationConnectorTool
|
||||
|
||||
__all__ = [
|
||||
'ApplicationIntegrationToolset',
|
||||
'IntegrationConnectorTool',
|
||||
]
|
||||
|
||||
@@ -12,21 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi.openapi.models import HTTPBearer
|
||||
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
|
||||
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_credential import AuthCredentialTypes
|
||||
from ...auth.auth_credential import ServiceAccount
|
||||
from ...auth.auth_credential import ServiceAccountCredential
|
||||
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
||||
from ..openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
|
||||
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from .clients.connections_client import ConnectionsClient
|
||||
from .clients.integration_client import IntegrationClient
|
||||
from .integration_connector_tool import IntegrationConnectorTool
|
||||
|
||||
|
||||
# TODO(cheliu): Apply a common toolset interface
|
||||
@@ -168,6 +168,7 @@ class ApplicationIntegrationToolset:
|
||||
actions,
|
||||
service_account_json,
|
||||
)
|
||||
connection_details = {}
|
||||
if integration and trigger:
|
||||
spec = integration_client.get_openapi_spec_for_integration()
|
||||
elif connection and (entity_operations or actions):
|
||||
@@ -175,16 +176,6 @@ class ApplicationIntegrationToolset:
|
||||
project, location, connection, service_account_json
|
||||
)
|
||||
connection_details = connections_client.get_connection_details()
|
||||
tool_instructions += (
|
||||
"ALWAYS use serviceName = "
|
||||
+ connection_details["serviceName"]
|
||||
+ ", host = "
|
||||
+ connection_details["host"]
|
||||
+ " and the connection name = "
|
||||
+ f"projects/{project}/locations/{location}/connections/{connection} when"
|
||||
" using this tool"
|
||||
+ ". DONOT ask the user for these values as you already have those."
|
||||
)
|
||||
spec = integration_client.get_openapi_spec_for_connection(
|
||||
tool_name,
|
||||
tool_instructions,
|
||||
@@ -194,9 +185,9 @@ class ApplicationIntegrationToolset:
|
||||
"Either (integration and trigger) or (connection and"
|
||||
" (entity_operations or actions)) should be provided."
|
||||
)
|
||||
self._parse_spec_to_tools(spec)
|
||||
self._parse_spec_to_tools(spec, connection_details)
|
||||
|
||||
def _parse_spec_to_tools(self, spec_dict):
|
||||
def _parse_spec_to_tools(self, spec_dict, connection_details):
|
||||
"""Parses the spec dict to a list of RestApiTool."""
|
||||
if self.service_account_json:
|
||||
sa_credential = ServiceAccountCredential.model_validate_json(
|
||||
@@ -218,12 +209,43 @@ class ApplicationIntegrationToolset:
|
||||
),
|
||||
)
|
||||
auth_scheme = HTTPBearer(bearerFormat="JWT")
|
||||
tools = OpenAPIToolset(
|
||||
spec_dict=spec_dict,
|
||||
auth_credential=auth_credential,
|
||||
auth_scheme=auth_scheme,
|
||||
).get_tools()
|
||||
for tool in tools:
|
||||
|
||||
if self.integration and self.trigger:
|
||||
tools = OpenAPIToolset(
|
||||
spec_dict=spec_dict,
|
||||
auth_credential=auth_credential,
|
||||
auth_scheme=auth_scheme,
|
||||
).get_tools()
|
||||
for tool in tools:
|
||||
self.generated_tools[tool.name] = tool
|
||||
return
|
||||
|
||||
operations = OpenApiSpecParser().parse(spec_dict)
|
||||
|
||||
for open_api_operation in operations:
|
||||
operation = getattr(open_api_operation.operation, "x-operation")
|
||||
entity = None
|
||||
action = None
|
||||
if hasattr(open_api_operation.operation, "x-entity"):
|
||||
entity = getattr(open_api_operation.operation, "x-entity")
|
||||
elif hasattr(open_api_operation.operation, "x-action"):
|
||||
action = getattr(open_api_operation.operation, "x-action")
|
||||
rest_api_tool = RestApiTool.from_parsed_operation(open_api_operation)
|
||||
if auth_scheme:
|
||||
rest_api_tool.configure_auth_scheme(auth_scheme)
|
||||
if auth_credential:
|
||||
rest_api_tool.configure_auth_credential(auth_credential)
|
||||
tool = IntegrationConnectorTool(
|
||||
name=rest_api_tool.name,
|
||||
description=rest_api_tool.description,
|
||||
connection_name=connection_details["name"],
|
||||
connection_host=connection_details["host"],
|
||||
connection_service_name=connection_details["serviceName"],
|
||||
entity=entity,
|
||||
action=action,
|
||||
operation=operation,
|
||||
rest_api_tool=rest_api_tool,
|
||||
)
|
||||
self.generated_tools[tool.name] = tool
|
||||
|
||||
def get_tools(self) -> List[RestApiTool]:
|
||||
|
||||
@@ -68,12 +68,14 @@ class ConnectionsClient:
|
||||
response = self._execute_api_call(url)
|
||||
|
||||
connection_data = response.json()
|
||||
connection_name = connection_data.get("name", "")
|
||||
service_name = connection_data.get("serviceDirectory", "")
|
||||
host = connection_data.get("host", "")
|
||||
if host:
|
||||
service_name = connection_data.get("tlsServiceDirectory", "")
|
||||
auth_override_enabled = connection_data.get("authOverrideEnabled", False)
|
||||
return {
|
||||
"name": connection_name,
|
||||
"serviceName": service_name,
|
||||
"host": host,
|
||||
"authOverrideEnabled": auth_override_enabled,
|
||||
@@ -291,13 +293,9 @@ class ConnectionsClient:
|
||||
tool_name: str = "",
|
||||
tool_instructions: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
description = (
|
||||
f"Use this tool with" f' action = "{action}" and'
|
||||
) + f' operation = "{operation}" only. Dont ask these values from user.'
|
||||
description = f"Use this tool to execute {action}"
|
||||
if operation == "EXECUTE_QUERY":
|
||||
description = (
|
||||
(f"Use this tool with" f' action = "{action}" and')
|
||||
+ f' operation = "{operation}" only. Dont ask these values from user.'
|
||||
description += (
|
||||
" Use pageSize = 50 and timeout = 120 until user specifies a"
|
||||
" different value otherwise. If user provides a query in natural"
|
||||
" language, convert it to SQL query and then execute it using the"
|
||||
@@ -308,6 +306,8 @@ class ConnectionsClient:
|
||||
"summary": f"{action_display_name}",
|
||||
"description": f"{description} {tool_instructions}",
|
||||
"operationId": f"{tool_name}_{action_display_name}",
|
||||
"x-action": f"{action}",
|
||||
"x-operation": f"{operation}",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
@@ -347,16 +347,12 @@ class ConnectionsClient:
|
||||
"post": {
|
||||
"summary": f"List {entity}",
|
||||
"description": (
|
||||
f"Returns all entities of type {entity}. Use this tool with"
|
||||
+ f' entity = "{entity}" and'
|
||||
+ ' operation = "LIST_ENTITIES" only. Dont ask these values'
|
||||
" from"
|
||||
+ ' user. Always use ""'
|
||||
+ ' as filter clause and ""'
|
||||
+ " as page token and 50 as page size until user specifies a"
|
||||
" different value otherwise. Use single quotes for strings in"
|
||||
f" filter clause. {tool_instructions}"
|
||||
f"""Returns the list of {entity} data. If the page token was available in the response, let users know there are more records available. Ask if the user wants to fetch the next page of results. When passing filter use the
|
||||
following format: `field_name1='value1' AND field_name2='value2'
|
||||
`. {tool_instructions}"""
|
||||
),
|
||||
"x-operation": "LIST_ENTITIES",
|
||||
"x-entity": f"{entity}",
|
||||
"operationId": f"{tool_name}_list_{entity}",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
@@ -401,14 +397,11 @@ class ConnectionsClient:
|
||||
"post": {
|
||||
"summary": f"Get {entity}",
|
||||
"description": (
|
||||
(
|
||||
f"Returns the details of the {entity}. Use this tool with"
|
||||
f' entity = "{entity}" and'
|
||||
)
|
||||
+ ' operation = "GET_ENTITY" only. Dont ask these values from'
|
||||
f" user. {tool_instructions}"
|
||||
f"Returns the details of the {entity}. {tool_instructions}"
|
||||
),
|
||||
"operationId": f"{tool_name}_get_{entity}",
|
||||
"x-operation": "GET_ENTITY",
|
||||
"x-entity": f"{entity}",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
@@ -445,17 +438,10 @@ class ConnectionsClient:
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"post": {
|
||||
"summary": f"Create {entity}",
|
||||
"description": (
|
||||
(
|
||||
f"Creates a new entity of type {entity}. Use this tool with"
|
||||
f' entity = "{entity}" and'
|
||||
)
|
||||
+ ' operation = "CREATE_ENTITY" only. Dont ask these values'
|
||||
" from"
|
||||
+ " user. Follow the schema of the entity provided in the"
|
||||
f" instructions to create {entity}. {tool_instructions}"
|
||||
),
|
||||
"summary": f"Creates a new {entity}",
|
||||
"description": f"Creates a new {entity}. {tool_instructions}",
|
||||
"x-operation": "CREATE_ENTITY",
|
||||
"x-entity": f"{entity}",
|
||||
"operationId": f"{tool_name}_create_{entity}",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
@@ -491,18 +477,10 @@ class ConnectionsClient:
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"post": {
|
||||
"summary": f"Update {entity}",
|
||||
"description": (
|
||||
(
|
||||
f"Updates an entity of type {entity}. Use this tool with"
|
||||
f' entity = "{entity}" and'
|
||||
)
|
||||
+ ' operation = "UPDATE_ENTITY" only. Dont ask these values'
|
||||
" from"
|
||||
+ " user. Use entityId to uniquely identify the entity to"
|
||||
" update. Follow the schema of the entity provided in the"
|
||||
f" instructions to update {entity}. {tool_instructions}"
|
||||
),
|
||||
"summary": f"Updates the {entity}",
|
||||
"description": f"Updates the {entity}. {tool_instructions}",
|
||||
"x-operation": "UPDATE_ENTITY",
|
||||
"x-entity": f"{entity}",
|
||||
"operationId": f"{tool_name}_update_{entity}",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
@@ -538,16 +516,10 @@ class ConnectionsClient:
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"post": {
|
||||
"summary": f"Delete {entity}",
|
||||
"description": (
|
||||
(
|
||||
f"Deletes an entity of type {entity}. Use this tool with"
|
||||
f' entity = "{entity}" and'
|
||||
)
|
||||
+ ' operation = "DELETE_ENTITY" only. Dont ask these values'
|
||||
" from"
|
||||
f" user. {tool_instructions}"
|
||||
),
|
||||
"summary": f"Delete the {entity}",
|
||||
"description": f"Deletes the {entity}. {tool_instructions}",
|
||||
"x-operation": "DELETE_ENTITY",
|
||||
"x-entity": f"{entity}",
|
||||
"operationId": f"{tool_name}_delete_{entity}",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from typing_extensions import override
|
||||
|
||||
from .. import BaseTool
|
||||
from ..tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IntegrationConnectorTool(BaseTool):
|
||||
"""A tool that wraps a RestApiTool to interact with a specific Application Integration endpoint.
|
||||
|
||||
This tool adds Application Integration specific context like connection
|
||||
details, entity, operation, and action to the underlying REST API call
|
||||
handled by RestApiTool. It prepares the arguments and then delegates the
|
||||
actual API call execution to the contained RestApiTool instance.
|
||||
|
||||
* Generates request params and body
|
||||
* Attaches auth credentials to API call.
|
||||
|
||||
Example:
|
||||
```
|
||||
# Each API operation in the spec will be turned into its own tool
|
||||
# Name of the tool is the operationId of that operation, in snake case
|
||||
operations = OperationGenerator().parse(openapi_spec_dict)
|
||||
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
|
||||
```
|
||||
"""
|
||||
|
||||
EXCLUDE_FIELDS = [
|
||||
'connection_name',
|
||||
'service_name',
|
||||
'host',
|
||||
'entity',
|
||||
'operation',
|
||||
'action',
|
||||
]
|
||||
|
||||
OPTIONAL_FIELDS = [
|
||||
'page_size',
|
||||
'page_token',
|
||||
'filter',
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
connection_name: str,
|
||||
connection_host: str,
|
||||
connection_service_name: str,
|
||||
entity: str,
|
||||
operation: str,
|
||||
action: str,
|
||||
rest_api_tool: RestApiTool,
|
||||
):
|
||||
"""Initializes the ApplicationIntegrationTool.
|
||||
|
||||
Args:
|
||||
name: The name of the tool, typically derived from the API operation.
|
||||
Should be unique and adhere to Gemini function naming conventions
|
||||
(e.g., less than 64 characters).
|
||||
description: A description of what the tool does, usually based on the
|
||||
API operation's summary or description.
|
||||
connection_name: The name of the Integration Connector connection.
|
||||
connection_host: The hostname or IP address for the connection.
|
||||
connection_service_name: The specific service name within the host.
|
||||
entity: The Integration Connector entity being targeted.
|
||||
operation: The specific operation being performed on the entity.
|
||||
action: The action associated with the operation (e.g., 'execute').
|
||||
rest_api_tool: An initialized RestApiTool instance that handles the
|
||||
underlying REST API communication based on an OpenAPI specification
|
||||
operation. This tool will be called by ApplicationIntegrationTool with
|
||||
added connection and context arguments. tool =
|
||||
[RestApiTool.from_parsed_operation(o) for o in operations]
|
||||
"""
|
||||
# Gemini restrict the length of function name to be less than 64 characters
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
self.connection_name = connection_name
|
||||
self.connection_host = connection_host
|
||||
self.connection_service_name = connection_service_name
|
||||
self.entity = entity
|
||||
self.operation = operation
|
||||
self.action = action
|
||||
self.rest_api_tool = rest_api_tool
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> FunctionDeclaration:
|
||||
"""Returns the function declaration in the Gemini Schema format."""
|
||||
schema_dict = self.rest_api_tool._operation_parser.get_json_schema()
|
||||
for field in self.EXCLUDE_FIELDS:
|
||||
if field in schema_dict['properties']:
|
||||
del schema_dict['properties'][field]
|
||||
for field in self.OPTIONAL_FIELDS + self.EXCLUDE_FIELDS:
|
||||
if field in schema_dict['required']:
|
||||
schema_dict['required'].remove(field)
|
||||
|
||||
parameters = to_gemini_schema(schema_dict)
|
||||
function_decl = FunctionDeclaration(
|
||||
name=self.name, description=self.description, parameters=parameters
|
||||
)
|
||||
return function_decl
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
|
||||
) -> Dict[str, Any]:
|
||||
args['connection_name'] = self.connection_name
|
||||
args['service_name'] = self.connection_service_name
|
||||
args['host'] = self.connection_host
|
||||
args['entity'] = self.entity
|
||||
args['operation'] = self.operation
|
||||
args['action'] = self.action
|
||||
logger.info('Running tool: %s with args: %s', self.name, args)
|
||||
return self.rest_api_tool.call(args=args, tool_context=tool_context)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f'ApplicationIntegrationTool(name="{self.name}",'
|
||||
f' description="{self.description}",'
|
||||
f' connection_name="{self.connection_name}", entity="{self.entity}",'
|
||||
f' operation="{self.operation}", action="{self.action}")'
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'ApplicationIntegrationTool(name="{self.name}",'
|
||||
f' description="{self.description}",'
|
||||
f' connection_name="{self.connection_name}",'
|
||||
f' connection_host="{self.connection_host}",'
|
||||
f' connection_service_name="{self.connection_service_name}",'
|
||||
f' entity="{self.entity}", operation="{self.operation}",'
|
||||
f' action="{self.action}", rest_api_tool={repr(self.rest_api_tool)})'
|
||||
)
|
||||
Reference in New Issue
Block a user