refactor: refactor application integration toolset to hide non-public field

PiperOrigin-RevId: 758469938
This commit is contained in:
Xiang (Sean) Zhou 2025-05-13 19:27:30 -07:00 committed by Copybara-Service
parent 00e0035c8a
commit 14cf910ce6
4 changed files with 55 additions and 53 deletions

View File

@ -33,7 +33,7 @@ jira_toolset = ApplicationIntegrationToolset(
location=connection_location, location=connection_location,
connection=connection_name, connection=connection_name,
entity_operations={"Issues": [], "Projects": []}, entity_operations={"Issues": [], "Projects": []},
tool_name="jira_issue_manager", tool_name_prefix="jira_issue_manager",
) )
root_agent = LlmAgent( root_agent = LlmAgent(

View File

@ -86,7 +86,7 @@ class ApplicationIntegrationToolset:
actions: Optional[str] = None, actions: Optional[str] = None,
# Optional parameter for the toolset. This is prepended to the generated # Optional parameter for the toolset. This is prepended to the generated
# tool/python function name. # tool/python function name.
tool_name: Optional[str] = "", tool_name_prefix: Optional[str] = "",
# Optional parameter for the toolset. This is appended to the generated # Optional parameter for the toolset. This is appended to the generated
# tool/python function description. # tool/python function description.
tool_instructions: Optional[str] = "", tool_instructions: Optional[str] = "",
@ -103,7 +103,7 @@ class ApplicationIntegrationToolset:
connection: The connection name. connection: The connection name.
entity_operations: The entity operations supported by the connection. entity_operations: The entity operations supported by the connection.
actions: The actions supported by the connection. actions: The actions supported by the connection.
tool_name: The name of the tool. tool_name_prefix: The name prefix of the generated tools.
tool_instructions: The instructions for the tool. tool_instructions: The instructions for the tool.
service_account_json: The service account configuration as a dictionary. service_account_json: The service account configuration as a dictionary.
Required if not using default service credential. Used for fetching Required if not using default service credential. Used for fetching
@ -122,15 +122,15 @@ class ApplicationIntegrationToolset:
""" """
self.project = project self.project = project
self.location = location self.location = location
self.integration = integration self._integration = integration
self.triggers = triggers self._triggers = triggers
self.connection = connection self._connection = connection
self.entity_operations = entity_operations self._entity_operations = entity_operations
self.actions = actions self._actions = actions
self.tool_name = tool_name self._tool_name_prefix = tool_name_prefix
self.tool_instructions = tool_instructions self._tool_instructions = tool_instructions
self.service_account_json = service_account_json self._service_account_json = service_account_json
self._tool_filter = tool_filter self.tool_filter = tool_filter
integration_client = IntegrationClient( integration_client = IntegrationClient(
project, project,
@ -151,7 +151,7 @@ class ApplicationIntegrationToolset:
) )
connection_details = connections_client.get_connection_details() connection_details = connections_client.get_connection_details()
spec = integration_client.get_openapi_spec_for_connection( spec = integration_client.get_openapi_spec_for_connection(
tool_name, tool_name_prefix,
tool_instructions, tool_instructions,
) )
else: else:
@ -159,15 +159,15 @@ class ApplicationIntegrationToolset:
"Invalid request, Either integration or (connection and" "Invalid request, Either integration or (connection and"
" (entity_operations or actions)) should be provided." " (entity_operations or actions)) should be provided."
) )
self.openapi_toolset = None self._openapi_toolset = None
self.tool = None self._tool = None
self._parse_spec_to_toolset(spec, connection_details) self._parse_spec_to_toolset(spec, connection_details)
def _parse_spec_to_toolset(self, spec_dict, connection_details): def _parse_spec_to_toolset(self, spec_dict, connection_details):
"""Parses the spec dict to OpenAPI toolset.""" """Parses the spec dict to OpenAPI toolset."""
if self.service_account_json: if self._service_account_json:
sa_credential = ServiceAccountCredential.model_validate_json( sa_credential = ServiceAccountCredential.model_validate_json(
self.service_account_json self._service_account_json
) )
service_account = ServiceAccount( service_account = ServiceAccount(
service_account_credential=sa_credential, service_account_credential=sa_credential,
@ -186,12 +186,12 @@ class ApplicationIntegrationToolset:
) )
auth_scheme = HTTPBearer(bearerFormat="JWT") auth_scheme = HTTPBearer(bearerFormat="JWT")
if self.integration: if self._integration:
self.openapi_toolset = OpenAPIToolset( self._openapi_toolset = OpenAPIToolset(
spec_dict=spec_dict, spec_dict=spec_dict,
auth_credential=auth_credential, auth_credential=auth_credential,
auth_scheme=auth_scheme, auth_scheme=auth_scheme,
tool_filter=self._tool_filter, tool_filter=self.tool_filter,
) )
return return
@ -210,7 +210,7 @@ class ApplicationIntegrationToolset:
rest_api_tool.configure_auth_scheme(auth_scheme) rest_api_tool.configure_auth_scheme(auth_scheme)
if auth_credential: if auth_credential:
rest_api_tool.configure_auth_credential(auth_credential) rest_api_tool.configure_auth_credential(auth_credential)
self.tool = IntegrationConnectorTool( self._tool = IntegrationConnectorTool(
name=rest_api_tool.name, name=rest_api_tool.name,
description=rest_api_tool.description, description=rest_api_tool.description,
connection_name=connection_details["name"], connection_name=connection_details["name"],
@ -224,9 +224,11 @@ class ApplicationIntegrationToolset:
@override @override
async def get_tools(self) -> List[RestApiTool]: async def get_tools(self) -> List[RestApiTool]:
return [self.tool] if self.tool else await self.openapi_toolset.get_tools() return (
[self._tool] if self._tool else await self._openapi_toolset.get_tools()
)
@override @override
async def close(self) -> None: async def close(self) -> None:
if self.openapi_toolset: if self._openapi_toolset:
await self.openapi_toolset.close() await self._openapi_toolset.close()

View File

@ -101,18 +101,18 @@ class IntegrationConnectorTool(BaseTool):
name=name, name=name,
description=description, description=description,
) )
self.connection_name = connection_name self._connection_name = connection_name
self.connection_host = connection_host self._connection_host = connection_host
self.connection_service_name = connection_service_name self._connection_service_name = connection_service_name
self.entity = entity self._entity = entity
self.operation = operation self._operation = operation
self.action = action self._action = action
self.rest_api_tool = rest_api_tool self._rest_api_tool = rest_api_tool
@override @override
def _get_declaration(self) -> FunctionDeclaration: def _get_declaration(self) -> FunctionDeclaration:
"""Returns the function declaration in the Gemini Schema format.""" """Returns the function declaration in the Gemini Schema format."""
schema_dict = self.rest_api_tool._operation_parser.get_json_schema() schema_dict = self._rest_api_tool._operation_parser.get_json_schema()
for field in self.EXCLUDE_FIELDS: for field in self.EXCLUDE_FIELDS:
if field in schema_dict['properties']: if field in schema_dict['properties']:
del schema_dict['properties'][field] del schema_dict['properties'][field]
@ -130,30 +130,30 @@ class IntegrationConnectorTool(BaseTool):
async def run_async( async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext] self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
args['connection_name'] = self.connection_name args['connection_name'] = self._connection_name
args['service_name'] = self.connection_service_name args['service_name'] = self._connection_service_name
args['host'] = self.connection_host args['host'] = self._connection_host
args['entity'] = self.entity args['entity'] = self._entity
args['operation'] = self.operation args['operation'] = self._operation
args['action'] = self.action args['action'] = self._action
logger.info('Running tool: %s with args: %s', self.name, args) logger.info('Running tool: %s with args: %s', self.name, args)
return self.rest_api_tool.call(args=args, tool_context=tool_context) return self._rest_api_tool.call(args=args, tool_context=tool_context)
def __str__(self): def __str__(self):
return ( return (
f'ApplicationIntegrationTool(name="{self.name}",' f'ApplicationIntegrationTool(name="{self.name}",'
f' description="{self.description}",' f' description="{self.description}",'
f' connection_name="{self.connection_name}", entity="{self.entity}",' f' connection_name="{self._connection_name}", entity="{self._entity}",'
f' operation="{self.operation}", action="{self.action}")' f' operation="{self._operation}", action="{self._action}")'
) )
def __repr__(self): def __repr__(self):
return ( return (
f'ApplicationIntegrationTool(name="{self.name}",' f'ApplicationIntegrationTool(name="{self.name}",'
f' description="{self.description}",' f' description="{self.description}",'
f' connection_name="{self.connection_name}",' f' connection_name="{self._connection_name}",'
f' connection_host="{self.connection_host}",' f' connection_host="{self._connection_host}",'
f' connection_service_name="{self.connection_service_name}",' f' connection_service_name="{self._connection_service_name}",'
f' entity="{self.entity}", operation="{self.operation}",' f' entity="{self._entity}", operation="{self._operation}",'
f' action="{self.action}", rest_api_tool={repr(self.rest_api_tool)})' f' action="{self._action}", rest_api_tool={repr(self._rest_api_tool)})'
) )

View File

@ -262,7 +262,7 @@ async def test_initialization_with_connection_and_entity_operations(
location, location,
connection=connection_name, connection=connection_name,
entity_operations=entity_operations_list, entity_operations=entity_operations_list,
tool_name=tool_name, tool_name_prefix=tool_name,
tool_instructions=tool_instructions, tool_instructions=tool_instructions,
) )
mock_integration_client.assert_called_once_with( mock_integration_client.assert_called_once_with(
@ -289,8 +289,8 @@ async def test_initialization_with_connection_and_entity_operations(
assert len(tools) == 1 assert len(tools) == 1
assert tools[0].name == "list_issues" assert tools[0].name == "list_issues"
assert isinstance(tools[0], IntegrationConnectorTool) assert isinstance(tools[0], IntegrationConnectorTool)
assert tools[0].entity == "Issues" assert tools[0]._entity == "Issues"
assert tools[0].operation == "LIST_ENTITIES" assert tools[0]._operation == "LIST_ENTITIES"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -314,7 +314,7 @@ async def test_initialization_with_connection_and_actions(
location, location,
connection=connection_name, connection=connection_name,
actions=actions_list, actions=actions_list,
tool_name=tool_name, tool_name_prefix=tool_name,
tool_instructions=tool_instructions, tool_instructions=tool_instructions,
) )
mock_integration_client.assert_called_once_with( mock_integration_client.assert_called_once_with(
@ -332,8 +332,8 @@ async def test_initialization_with_connection_and_actions(
assert len(tools) == 1 assert len(tools) == 1
assert tools[0].name == "list_issues_operation" assert tools[0].name == "list_issues_operation"
assert isinstance(tools[0], IntegrationConnectorTool) assert isinstance(tools[0], IntegrationConnectorTool)
assert tools[0].action == "CustomAction" assert tools[0]._action == "CustomAction"
assert tools[0].operation == "EXECUTE_ACTION" assert tools[0]._operation == "EXECUTE_ACTION"
def test_initialization_without_required_params(project, location): def test_initialization_without_required_params(project, location):
@ -467,7 +467,7 @@ def test_initialization_with_connection_details(
location, location,
connection=connection_name, connection=connection_name,
entity_operations=entity_operations_list, entity_operations=entity_operations_list,
tool_name=tool_name, tool_name_prefix=tool_name,
tool_instructions=tool_instructions, tool_instructions=tool_instructions,
) )
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(