refactor: refactor openapi toolset and tool parser to hide non public field

PiperOrigin-RevId: 758436303
This commit is contained in:
Xiang (Sean) Zhou
2025-05-13 17:13:53 -07:00
committed by Copybara-Service
parent 1f0fd7bfce
commit 9647426500
4 changed files with 87 additions and 87 deletions

View File

@@ -105,7 +105,7 @@ class OpenAPIToolset(BaseToolset):
"""
if not spec_dict:
spec_dict = self._load_spec(spec_str, spec_str_type)
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
if auth_scheme or auth_credential:
self._configure_auth_all(auth_scheme, auth_credential)
self.tool_filter = tool_filter
@@ -115,7 +115,7 @@ class OpenAPIToolset(BaseToolset):
):
"""Configure auth scheme and credential for all tools."""
for tool in self.tools:
for tool in self._tools:
if auth_scheme:
tool.configure_auth_scheme(auth_scheme)
if auth_credential:
@@ -128,7 +128,7 @@ class OpenAPIToolset(BaseToolset):
"""Get all tools in the toolset."""
return [
tool
for tool in self.tools
for tool in self._tools
if self.tool_filter is None
or (
self.tool_filter(tool, readonly_context)
@@ -139,7 +139,7 @@ class OpenAPIToolset(BaseToolset):
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
"""Get a tool by name."""
matching_tool = filter(lambda t: t.name == tool_name, self.tools)
matching_tool = filter(lambda t: t.name == tool_name, self._tools)
return next(matching_tool, None)
def _load_spec(

View File

@@ -45,14 +45,14 @@ class OperationParser:
should_parse: Whether to parse the operation during initialization.
"""
if isinstance(operation, dict):
self.operation = Operation.model_validate(operation)
self._operation = Operation.model_validate(operation)
elif isinstance(operation, str):
self.operation = Operation.model_validate_json(operation)
self._operation = Operation.model_validate_json(operation)
else:
self.operation = operation
self._operation = operation
self.params: List[ApiParameter] = []
self.return_value: Optional[ApiParameter] = None
self._params: List[ApiParameter] = []
self._return_value: Optional[ApiParameter] = None
if should_parse:
self._process_operation_parameters()
self._process_request_body()
@@ -67,13 +67,13 @@ class OperationParser:
return_value: Optional[ApiParameter] = None,
) -> 'OperationParser':
parser = cls(operation, should_parse=False)
parser.params = params
parser.return_value = return_value
parser._params = params
parser._return_value = return_value
return parser
def _process_operation_parameters(self):
"""Processes parameters from the OpenAPI operation."""
parameters = self.operation.parameters or []
parameters = self._operation.parameters or []
for param in parameters:
if isinstance(param, Parameter):
original_name = param.name
@@ -86,7 +86,7 @@ class OperationParser:
# param.required can be None
required = param.required if param.required is not None else False
self.params.append(
self._params.append(
ApiParameter(
original_name=original_name,
param_location=location,
@@ -98,7 +98,7 @@ class OperationParser:
def _process_request_body(self):
"""Processes the request body from the OpenAPI operation."""
request_body = self.operation.requestBody
request_body = self._operation.requestBody
if not request_body:
return
@@ -114,7 +114,7 @@ class OperationParser:
if schema and schema.type == 'object':
properties = schema.properties or {}
for prop_name, prop_details in properties.items():
self.params.append(
self._params.append(
ApiParameter(
original_name=prop_name,
param_location='body',
@@ -124,7 +124,7 @@ class OperationParser:
)
elif schema and schema.type == 'array':
self.params.append(
self._params.append(
ApiParameter(
original_name='array',
param_location='body',
@@ -133,7 +133,7 @@ class OperationParser:
)
)
else:
self.params.append(
self._params.append(
# Empty name for unnamed body param
ApiParameter(
original_name='',
@@ -147,7 +147,7 @@ class OperationParser:
def _dedupe_param_names(self):
"""Deduplicates parameter names to avoid conflicts."""
params_cnt = {}
for param in self.params:
for param in self._params:
name = param.py_name
if name not in params_cnt:
params_cnt[name] = 0
@@ -157,7 +157,7 @@ class OperationParser:
def _process_return_value(self) -> Parameter:
"""Returns a Parameter object representing the return type."""
responses = self.operation.responses or {}
responses = self._operation.responses or {}
# Default to Any if no 2xx response or if schema is missing
return_schema = Schema(type='Any')
@@ -174,7 +174,7 @@ class OperationParser:
return_schema = content[mime_type].schema_
break
self.return_value = ApiParameter(
self._return_value = ApiParameter(
original_name='',
param_location='',
param_schema=return_schema,
@@ -182,42 +182,42 @@ class OperationParser:
def get_function_name(self) -> str:
"""Returns the generated function name."""
operation_id = self.operation.operationId
operation_id = self._operation.operationId
if not operation_id:
raise ValueError('Operation ID is missing')
return to_snake_case(operation_id)[:60]
def get_return_type_hint(self) -> str:
"""Returns the return type hint string (like 'str', 'int', etc.)."""
return self.return_value.type_hint
return self._return_value.type_hint
def get_return_type_value(self) -> Any:
"""Returns the return type value (like str, int, List[str], etc.)."""
return self.return_value.type_value
return self._return_value.type_value
def get_parameters(self) -> List[ApiParameter]:
"""Returns the list of Parameter objects."""
return self.params
return self._params
def get_return_value(self) -> ApiParameter:
"""Returns the list of Parameter objects."""
return self.return_value
return self._return_value
def get_auth_scheme_name(self) -> str:
"""Returns the name of the auth scheme for this operation from the spec."""
if self.operation.security:
scheme_name = list(self.operation.security[0].keys())[0]
if self._operation.security:
scheme_name = list(self._operation.security[0].keys())[0]
return scheme_name
return ''
def get_pydoc_string(self) -> str:
"""Returns the generated PyDoc string."""
pydoc_params = [param.to_pydoc_string() for param in self.params]
pydoc_params = [param.to_pydoc_string() for param in self._params]
pydoc_description = (
self.operation.summary or self.operation.description or ''
self._operation.summary or self._operation.description or ''
)
pydoc_return = PydocHelper.generate_return_doc(
self.operation.responses or {}
self._operation.responses or {}
)
pydoc_arg_list = chr(10).join(
f' {param_doc}' for param_doc in pydoc_params
@@ -236,12 +236,12 @@ class OperationParser:
"""Returns the JSON schema for the function arguments."""
properties = {
p.py_name: jsonable_encoder(p.param_schema, exclude_none=True)
for p in self.params
for p in self._params
}
return {
'properties': properties,
'required': [p.py_name for p in self.params if p.required],
'title': f"{self.operation.operationId or 'unnamed'}_Arguments",
'required': [p.py_name for p in self._params if p.required],
'title': f"{self._operation.operationId or 'unnamed'}_Arguments",
'type': 'object',
}
@@ -253,11 +253,11 @@ class OperationParser:
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=param.type_value,
)
for param in self.params
for param in self._params
]
def get_annotations(self) -> Dict[str, Any]:
"""Returns a dictionary of parameter annotations for the function."""
annotations = {p.py_name: p.type_value for p in self.params}
annotations = {p.py_name: p.type_value for p in self._params}
annotations['return'] = self.get_return_type_value()
return annotations