refactor: refactor application integration toolset to hide non-public field

PiperOrigin-RevId: 758469938
This commit is contained in:
Xiang (Sean) Zhou 2025-05-13 19:27:30 -07:00 committed by Copybara-Service
parent 00e0035c8a
commit 14cf910ce6
4 changed files with 55 additions and 53 deletions

View File

@ -33,7 +33,7 @@ jira_toolset = ApplicationIntegrationToolset(
location=connection_location,
connection=connection_name,
entity_operations={"Issues": [], "Projects": []},
tool_name="jira_issue_manager",
tool_name_prefix="jira_issue_manager",
)
root_agent = LlmAgent(

View File

@ -86,7 +86,7 @@ class ApplicationIntegrationToolset:
actions: Optional[str] = None,
# Optional parameter for the toolset. This is prepended to the generated
# tool/python function name.
tool_name: Optional[str] = "",
tool_name_prefix: Optional[str] = "",
# Optional parameter for the toolset. This is appended to the generated
# tool/python function description.
tool_instructions: Optional[str] = "",
@ -103,7 +103,7 @@ class ApplicationIntegrationToolset:
connection: The connection name.
entity_operations: The entity operations supported by the connection.
actions: The actions supported by the connection.
tool_name: The name of the tool.
tool_name_prefix: The name prefix of the generated tools.
tool_instructions: The instructions for the tool.
service_account_json: The service account configuration as a dictionary.
Required if not using default service credential. Used for fetching
@ -122,15 +122,15 @@ class ApplicationIntegrationToolset:
"""
self.project = project
self.location = location
self.integration = integration
self.triggers = triggers
self.connection = connection
self.entity_operations = entity_operations
self.actions = actions
self.tool_name = tool_name
self.tool_instructions = tool_instructions
self.service_account_json = service_account_json
self._tool_filter = tool_filter
self._integration = integration
self._triggers = triggers
self._connection = connection
self._entity_operations = entity_operations
self._actions = actions
self._tool_name_prefix = tool_name_prefix
self._tool_instructions = tool_instructions
self._service_account_json = service_account_json
self.tool_filter = tool_filter
integration_client = IntegrationClient(
project,
@ -151,7 +151,7 @@ class ApplicationIntegrationToolset:
)
connection_details = connections_client.get_connection_details()
spec = integration_client.get_openapi_spec_for_connection(
tool_name,
tool_name_prefix,
tool_instructions,
)
else:
@ -159,15 +159,15 @@ class ApplicationIntegrationToolset:
"Invalid request, Either integration or (connection and"
" (entity_operations or actions)) should be provided."
)
self.openapi_toolset = None
self.tool = None
self._openapi_toolset = None
self._tool = None
self._parse_spec_to_toolset(spec, connection_details)
def _parse_spec_to_toolset(self, spec_dict, connection_details):
"""Parses the spec dict to OpenAPI toolset."""
if self.service_account_json:
if self._service_account_json:
sa_credential = ServiceAccountCredential.model_validate_json(
self.service_account_json
self._service_account_json
)
service_account = ServiceAccount(
service_account_credential=sa_credential,
@ -186,12 +186,12 @@ class ApplicationIntegrationToolset:
)
auth_scheme = HTTPBearer(bearerFormat="JWT")
if self.integration:
self.openapi_toolset = OpenAPIToolset(
if self._integration:
self._openapi_toolset = OpenAPIToolset(
spec_dict=spec_dict,
auth_credential=auth_credential,
auth_scheme=auth_scheme,
tool_filter=self._tool_filter,
tool_filter=self.tool_filter,
)
return
@ -210,7 +210,7 @@ class ApplicationIntegrationToolset:
rest_api_tool.configure_auth_scheme(auth_scheme)
if auth_credential:
rest_api_tool.configure_auth_credential(auth_credential)
self.tool = IntegrationConnectorTool(
self._tool = IntegrationConnectorTool(
name=rest_api_tool.name,
description=rest_api_tool.description,
connection_name=connection_details["name"],
@ -224,9 +224,11 @@ class ApplicationIntegrationToolset:
@override
async def get_tools(self) -> List[RestApiTool]:
return [self.tool] if self.tool else await self.openapi_toolset.get_tools()
return (
[self._tool] if self._tool else await self._openapi_toolset.get_tools()
)
@override
async def close(self) -> None:
if self.openapi_toolset:
await self.openapi_toolset.close()
if self._openapi_toolset:
await self._openapi_toolset.close()

View File

@ -101,18 +101,18 @@ class IntegrationConnectorTool(BaseTool):
name=name,
description=description,
)
self.connection_name = connection_name
self.connection_host = connection_host
self.connection_service_name = connection_service_name
self.entity = entity
self.operation = operation
self.action = action
self.rest_api_tool = rest_api_tool
self._connection_name = connection_name
self._connection_host = connection_host
self._connection_service_name = connection_service_name
self._entity = entity
self._operation = operation
self._action = action
self._rest_api_tool = rest_api_tool
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Returns the function declaration in the Gemini Schema format."""
schema_dict = self.rest_api_tool._operation_parser.get_json_schema()
schema_dict = self._rest_api_tool._operation_parser.get_json_schema()
for field in self.EXCLUDE_FIELDS:
if field in schema_dict['properties']:
del schema_dict['properties'][field]
@ -130,30 +130,30 @@ class IntegrationConnectorTool(BaseTool):
async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
args['connection_name'] = self.connection_name
args['service_name'] = self.connection_service_name
args['host'] = self.connection_host
args['entity'] = self.entity
args['operation'] = self.operation
args['action'] = self.action
args['connection_name'] = self._connection_name
args['service_name'] = self._connection_service_name
args['host'] = self._connection_host
args['entity'] = self._entity
args['operation'] = self._operation
args['action'] = self._action
logger.info('Running tool: %s with args: %s', self.name, args)
return self.rest_api_tool.call(args=args, tool_context=tool_context)
return self._rest_api_tool.call(args=args, tool_context=tool_context)
def __str__(self):
return (
f'ApplicationIntegrationTool(name="{self.name}",'
f' description="{self.description}",'
f' connection_name="{self.connection_name}", entity="{self.entity}",'
f' operation="{self.operation}", action="{self.action}")'
f' connection_name="{self._connection_name}", entity="{self._entity}",'
f' operation="{self._operation}", action="{self._action}")'
)
def __repr__(self):
return (
f'ApplicationIntegrationTool(name="{self.name}",'
f' description="{self.description}",'
f' connection_name="{self.connection_name}",'
f' connection_host="{self.connection_host}",'
f' connection_service_name="{self.connection_service_name}",'
f' entity="{self.entity}", operation="{self.operation}",'
f' action="{self.action}", rest_api_tool={repr(self.rest_api_tool)})'
f' connection_name="{self._connection_name}",'
f' connection_host="{self._connection_host}",'
f' connection_service_name="{self._connection_service_name}",'
f' entity="{self._entity}", operation="{self._operation}",'
f' action="{self._action}", rest_api_tool={repr(self._rest_api_tool)})'
)

View File

@ -262,7 +262,7 @@ async def test_initialization_with_connection_and_entity_operations(
location,
connection=connection_name,
entity_operations=entity_operations_list,
tool_name=tool_name,
tool_name_prefix=tool_name,
tool_instructions=tool_instructions,
)
mock_integration_client.assert_called_once_with(
@ -289,8 +289,8 @@ async def test_initialization_with_connection_and_entity_operations(
assert len(tools) == 1
assert tools[0].name == "list_issues"
assert isinstance(tools[0], IntegrationConnectorTool)
assert tools[0].entity == "Issues"
assert tools[0].operation == "LIST_ENTITIES"
assert tools[0]._entity == "Issues"
assert tools[0]._operation == "LIST_ENTITIES"
@pytest.mark.asyncio
@ -314,7 +314,7 @@ async def test_initialization_with_connection_and_actions(
location,
connection=connection_name,
actions=actions_list,
tool_name=tool_name,
tool_name_prefix=tool_name,
tool_instructions=tool_instructions,
)
mock_integration_client.assert_called_once_with(
@ -332,8 +332,8 @@ async def test_initialization_with_connection_and_actions(
assert len(tools) == 1
assert tools[0].name == "list_issues_operation"
assert isinstance(tools[0], IntegrationConnectorTool)
assert tools[0].action == "CustomAction"
assert tools[0].operation == "EXECUTE_ACTION"
assert tools[0]._action == "CustomAction"
assert tools[0]._operation == "EXECUTE_ACTION"
def test_initialization_without_required_params(project, location):
@ -467,7 +467,7 @@ def test_initialization_with_connection_details(
location,
connection=connection_name,
entity_operations=entity_operations_list,
tool_name=tool_name,
tool_name_prefix=tool_name,
tool_instructions=tool_instructions,
)
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(