refactor: refactor openapi toolset and tool parser to hide non public field

PiperOrigin-RevId: 758436303
This commit is contained in:
Xiang (Sean) Zhou 2025-05-13 17:13:53 -07:00 committed by Copybara-Service
parent 1f0fd7bfce
commit 9647426500
4 changed files with 87 additions and 87 deletions

View File

@ -105,7 +105,7 @@ class OpenAPIToolset(BaseToolset):
""" """
if not spec_dict: if not spec_dict:
spec_dict = self._load_spec(spec_str, spec_str_type) spec_dict = self._load_spec(spec_str, spec_str_type)
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict)) self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
if auth_scheme or auth_credential: if auth_scheme or auth_credential:
self._configure_auth_all(auth_scheme, auth_credential) self._configure_auth_all(auth_scheme, auth_credential)
self.tool_filter = tool_filter self.tool_filter = tool_filter
@ -115,7 +115,7 @@ class OpenAPIToolset(BaseToolset):
): ):
"""Configure auth scheme and credential for all tools.""" """Configure auth scheme and credential for all tools."""
for tool in self.tools: for tool in self._tools:
if auth_scheme: if auth_scheme:
tool.configure_auth_scheme(auth_scheme) tool.configure_auth_scheme(auth_scheme)
if auth_credential: if auth_credential:
@ -128,7 +128,7 @@ class OpenAPIToolset(BaseToolset):
"""Get all tools in the toolset.""" """Get all tools in the toolset."""
return [ return [
tool tool
for tool in self.tools for tool in self._tools
if self.tool_filter is None if self.tool_filter is None
or ( or (
self.tool_filter(tool, readonly_context) self.tool_filter(tool, readonly_context)
@ -139,7 +139,7 @@ class OpenAPIToolset(BaseToolset):
def get_tool(self, tool_name: str) -> Optional[RestApiTool]: def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
"""Get a tool by name.""" """Get a tool by name."""
matching_tool = filter(lambda t: t.name == tool_name, self.tools) matching_tool = filter(lambda t: t.name == tool_name, self._tools)
return next(matching_tool, None) return next(matching_tool, None)
def _load_spec( def _load_spec(

View File

@ -45,14 +45,14 @@ class OperationParser:
should_parse: Whether to parse the operation during initialization. should_parse: Whether to parse the operation during initialization.
""" """
if isinstance(operation, dict): if isinstance(operation, dict):
self.operation = Operation.model_validate(operation) self._operation = Operation.model_validate(operation)
elif isinstance(operation, str): elif isinstance(operation, str):
self.operation = Operation.model_validate_json(operation) self._operation = Operation.model_validate_json(operation)
else: else:
self.operation = operation self._operation = operation
self.params: List[ApiParameter] = [] self._params: List[ApiParameter] = []
self.return_value: Optional[ApiParameter] = None self._return_value: Optional[ApiParameter] = None
if should_parse: if should_parse:
self._process_operation_parameters() self._process_operation_parameters()
self._process_request_body() self._process_request_body()
@ -67,13 +67,13 @@ class OperationParser:
return_value: Optional[ApiParameter] = None, return_value: Optional[ApiParameter] = None,
) -> 'OperationParser': ) -> 'OperationParser':
parser = cls(operation, should_parse=False) parser = cls(operation, should_parse=False)
parser.params = params parser._params = params
parser.return_value = return_value parser._return_value = return_value
return parser return parser
def _process_operation_parameters(self): def _process_operation_parameters(self):
"""Processes parameters from the OpenAPI operation.""" """Processes parameters from the OpenAPI operation."""
parameters = self.operation.parameters or [] parameters = self._operation.parameters or []
for param in parameters: for param in parameters:
if isinstance(param, Parameter): if isinstance(param, Parameter):
original_name = param.name original_name = param.name
@ -86,7 +86,7 @@ class OperationParser:
# param.required can be None # param.required can be None
required = param.required if param.required is not None else False required = param.required if param.required is not None else False
self.params.append( self._params.append(
ApiParameter( ApiParameter(
original_name=original_name, original_name=original_name,
param_location=location, param_location=location,
@ -98,7 +98,7 @@ class OperationParser:
def _process_request_body(self): def _process_request_body(self):
"""Processes the request body from the OpenAPI operation.""" """Processes the request body from the OpenAPI operation."""
request_body = self.operation.requestBody request_body = self._operation.requestBody
if not request_body: if not request_body:
return return
@ -114,7 +114,7 @@ class OperationParser:
if schema and schema.type == 'object': if schema and schema.type == 'object':
properties = schema.properties or {} properties = schema.properties or {}
for prop_name, prop_details in properties.items(): for prop_name, prop_details in properties.items():
self.params.append( self._params.append(
ApiParameter( ApiParameter(
original_name=prop_name, original_name=prop_name,
param_location='body', param_location='body',
@ -124,7 +124,7 @@ class OperationParser:
) )
elif schema and schema.type == 'array': elif schema and schema.type == 'array':
self.params.append( self._params.append(
ApiParameter( ApiParameter(
original_name='array', original_name='array',
param_location='body', param_location='body',
@ -133,7 +133,7 @@ class OperationParser:
) )
) )
else: else:
self.params.append( self._params.append(
# Empty name for unnamed body param # Empty name for unnamed body param
ApiParameter( ApiParameter(
original_name='', original_name='',
@ -147,7 +147,7 @@ class OperationParser:
def _dedupe_param_names(self): def _dedupe_param_names(self):
"""Deduplicates parameter names to avoid conflicts.""" """Deduplicates parameter names to avoid conflicts."""
params_cnt = {} params_cnt = {}
for param in self.params: for param in self._params:
name = param.py_name name = param.py_name
if name not in params_cnt: if name not in params_cnt:
params_cnt[name] = 0 params_cnt[name] = 0
@ -157,7 +157,7 @@ class OperationParser:
def _process_return_value(self) -> Parameter: def _process_return_value(self) -> Parameter:
"""Returns a Parameter object representing the return type.""" """Returns a Parameter object representing the return type."""
responses = self.operation.responses or {} responses = self._operation.responses or {}
# Default to Any if no 2xx response or if schema is missing # Default to Any if no 2xx response or if schema is missing
return_schema = Schema(type='Any') return_schema = Schema(type='Any')
@ -174,7 +174,7 @@ class OperationParser:
return_schema = content[mime_type].schema_ return_schema = content[mime_type].schema_
break break
self.return_value = ApiParameter( self._return_value = ApiParameter(
original_name='', original_name='',
param_location='', param_location='',
param_schema=return_schema, param_schema=return_schema,
@ -182,42 +182,42 @@ class OperationParser:
def get_function_name(self) -> str: def get_function_name(self) -> str:
"""Returns the generated function name.""" """Returns the generated function name."""
operation_id = self.operation.operationId operation_id = self._operation.operationId
if not operation_id: if not operation_id:
raise ValueError('Operation ID is missing') raise ValueError('Operation ID is missing')
return to_snake_case(operation_id)[:60] return to_snake_case(operation_id)[:60]
def get_return_type_hint(self) -> str: def get_return_type_hint(self) -> str:
"""Returns the return type hint string (like 'str', 'int', etc.).""" """Returns the return type hint string (like 'str', 'int', etc.)."""
return self.return_value.type_hint return self._return_value.type_hint
def get_return_type_value(self) -> Any: def get_return_type_value(self) -> Any:
"""Returns the return type value (like str, int, List[str], etc.).""" """Returns the return type value (like str, int, List[str], etc.)."""
return self.return_value.type_value return self._return_value.type_value
def get_parameters(self) -> List[ApiParameter]: def get_parameters(self) -> List[ApiParameter]:
"""Returns the list of Parameter objects.""" """Returns the list of Parameter objects."""
return self.params return self._params
def get_return_value(self) -> ApiParameter: def get_return_value(self) -> ApiParameter:
"""Returns the list of Parameter objects.""" """Returns the list of Parameter objects."""
return self.return_value return self._return_value
def get_auth_scheme_name(self) -> str: def get_auth_scheme_name(self) -> str:
"""Returns the name of the auth scheme for this operation from the spec.""" """Returns the name of the auth scheme for this operation from the spec."""
if self.operation.security: if self._operation.security:
scheme_name = list(self.operation.security[0].keys())[0] scheme_name = list(self._operation.security[0].keys())[0]
return scheme_name return scheme_name
return '' return ''
def get_pydoc_string(self) -> str: def get_pydoc_string(self) -> str:
"""Returns the generated PyDoc string.""" """Returns the generated PyDoc string."""
pydoc_params = [param.to_pydoc_string() for param in self.params] pydoc_params = [param.to_pydoc_string() for param in self._params]
pydoc_description = ( pydoc_description = (
self.operation.summary or self.operation.description or '' self._operation.summary or self._operation.description or ''
) )
pydoc_return = PydocHelper.generate_return_doc( pydoc_return = PydocHelper.generate_return_doc(
self.operation.responses or {} self._operation.responses or {}
) )
pydoc_arg_list = chr(10).join( pydoc_arg_list = chr(10).join(
f' {param_doc}' for param_doc in pydoc_params f' {param_doc}' for param_doc in pydoc_params
@ -236,12 +236,12 @@ class OperationParser:
"""Returns the JSON schema for the function arguments.""" """Returns the JSON schema for the function arguments."""
properties = { properties = {
p.py_name: jsonable_encoder(p.param_schema, exclude_none=True) p.py_name: jsonable_encoder(p.param_schema, exclude_none=True)
for p in self.params for p in self._params
} }
return { return {
'properties': properties, 'properties': properties,
'required': [p.py_name for p in self.params if p.required], 'required': [p.py_name for p in self._params if p.required],
'title': f"{self.operation.operationId or 'unnamed'}_Arguments", 'title': f"{self._operation.operationId or 'unnamed'}_Arguments",
'type': 'object', 'type': 'object',
} }
@ -253,11 +253,11 @@ class OperationParser:
inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=param.type_value, annotation=param.type_value,
) )
for param in self.params for param in self._params
] ]
def get_annotations(self) -> Dict[str, Any]: def get_annotations(self) -> Dict[str, Any]:
"""Returns a dictionary of parameter annotations for the function.""" """Returns a dictionary of parameter annotations for the function."""
annotations = {p.py_name: p.type_value for p in self.params} annotations = {p.py_name: p.type_value for p in self._params}
annotations['return'] = self.get_return_type_value() annotations['return'] = self.get_return_type_value()
return annotations return annotations

View File

@ -47,18 +47,18 @@ def openapi_spec() -> Dict:
def test_openapi_toolset_initialization_from_dict(openapi_spec: Dict): def test_openapi_toolset_initialization_from_dict(openapi_spec: Dict):
"""Test initialization of OpenAPIToolset with a dictionary.""" """Test initialization of OpenAPIToolset with a dictionary."""
toolset = OpenAPIToolset(spec_dict=openapi_spec) toolset = OpenAPIToolset(spec_dict=openapi_spec)
assert isinstance(toolset.tools, list) assert isinstance(toolset._tools, list)
assert len(toolset.tools) == 5 assert len(toolset._tools) == 5
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools) assert all(isinstance(tool, RestApiTool) for tool in toolset._tools)
def test_openapi_toolset_initialization_from_yaml_string(openapi_spec: Dict): def test_openapi_toolset_initialization_from_yaml_string(openapi_spec: Dict):
"""Test initialization of OpenAPIToolset with a YAML string.""" """Test initialization of OpenAPIToolset with a YAML string."""
spec_str = yaml.dump(openapi_spec) spec_str = yaml.dump(openapi_spec)
toolset = OpenAPIToolset(spec_str=spec_str, spec_str_type="yaml") toolset = OpenAPIToolset(spec_str=spec_str, spec_str_type="yaml")
assert isinstance(toolset.tools, list) assert isinstance(toolset._tools, list)
assert len(toolset.tools) == 5 assert len(toolset._tools) == 5
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools) assert all(isinstance(tool, RestApiTool) for tool in toolset._tools)
def test_openapi_toolset_tool_existing(openapi_spec: Dict): def test_openapi_toolset_tool_existing(openapi_spec: Dict):
@ -134,6 +134,6 @@ def test_openapi_toolset_configure_auth_on_init(openapi_spec: Dict):
auth_scheme=auth_scheme, auth_scheme=auth_scheme,
auth_credential=auth_credential, auth_credential=auth_credential,
) )
for tool in toolset.tools: for tool in toolset._tools:
assert tool.auth_scheme == auth_scheme assert tool.auth_scheme == auth_scheme
assert tool.auth_credential == auth_credential assert tool.auth_credential == auth_credential

View File

@ -78,31 +78,31 @@ def sample_operation() -> Operation:
def test_operation_parser_initialization(sample_operation): def test_operation_parser_initialization(sample_operation):
"""Test initialization of OperationParser.""" """Test initialization of OperationParser."""
parser = OperationParser(sample_operation) parser = OperationParser(sample_operation)
assert parser.operation == sample_operation assert parser._operation == sample_operation
assert len(parser.params) == 4 # 2 params + 2 request body props assert len(parser._params) == 4 # 2 params + 2 request body props
assert parser.return_value is not None assert parser._return_value is not None
def test_process_operation_parameters(sample_operation): def test_process_operation_parameters(sample_operation):
"""Test _process_operation_parameters method.""" """Test _process_operation_parameters method."""
parser = OperationParser(sample_operation, should_parse=False) parser = OperationParser(sample_operation, should_parse=False)
parser._process_operation_parameters() parser._process_operation_parameters()
assert len(parser.params) == 2 assert len(parser._params) == 2
assert parser.params[0].original_name == 'param1' assert parser._params[0].original_name == 'param1'
assert parser.params[0].param_location == 'query' assert parser._params[0].param_location == 'query'
assert parser.params[1].original_name == 'param2' assert parser._params[1].original_name == 'param2'
assert parser.params[1].param_location == 'header' assert parser._params[1].param_location == 'header'
def test_process_request_body(sample_operation): def test_process_request_body(sample_operation):
"""Test _process_request_body method.""" """Test _process_request_body method."""
parser = OperationParser(sample_operation, should_parse=False) parser = OperationParser(sample_operation, should_parse=False)
parser._process_request_body() parser._process_request_body()
assert len(parser.params) == 2 # 2 properties in request body assert len(parser._params) == 2 # 2 properties in request body
assert parser.params[0].original_name == 'prop1' assert parser._params[0].original_name == 'prop1'
assert parser.params[0].param_location == 'body' assert parser._params[0].param_location == 'body'
assert parser.params[1].original_name == 'prop2' assert parser._params[1].original_name == 'prop2'
assert parser.params[1].param_location == 'body' assert parser._params[1].param_location == 'body'
def test_process_request_body_array(): def test_process_request_body_array():
@ -132,20 +132,20 @@ def test_process_request_body_array():
parser = OperationParser(operation, should_parse=False) parser = OperationParser(operation, should_parse=False)
parser._process_request_body() parser._process_request_body()
assert len(parser.params) == 1 assert len(parser._params) == 1
assert parser.params[0].original_name == 'array' assert parser._params[0].original_name == 'array'
assert parser.params[0].param_location == 'body' assert parser._params[0].param_location == 'body'
# Check that schema is correctly propagated and is a dictionary # Check that schema is correctly propagated and is a dictionary
assert parser.params[0].param_schema.type == 'array' assert parser._params[0].param_schema.type == 'array'
assert parser.params[0].param_schema.items.type == 'object' assert parser._params[0].param_schema.items.type == 'object'
assert 'item_prop1' in parser.params[0].param_schema.items.properties assert 'item_prop1' in parser._params[0].param_schema.items.properties
assert 'item_prop2' in parser.params[0].param_schema.items.properties assert 'item_prop2' in parser._params[0].param_schema.items.properties
assert ( assert (
parser.params[0].param_schema.items.properties['item_prop1'].description parser._params[0].param_schema.items.properties['item_prop1'].description
== 'Item Property 1' == 'Item Property 1'
) )
assert ( assert (
parser.params[0].param_schema.items.properties['item_prop2'].description parser._params[0].param_schema.items.properties['item_prop2'].description
== 'Item Property 2' == 'Item Property 2'
) )
@ -159,9 +159,9 @@ def test_process_request_body_no_name():
) )
parser = OperationParser(operation, should_parse=False) parser = OperationParser(operation, should_parse=False)
parser._process_request_body() parser._process_request_body()
assert len(parser.params) == 1 assert len(parser._params) == 1
assert parser.params[0].original_name == '' # No name assert parser._params[0].original_name == '' # No name
assert parser.params[0].param_location == 'body' assert parser._params[0].param_location == 'body'
def test_process_request_body_empty_object(): def test_process_request_body_empty_object():
@ -173,30 +173,30 @@ def test_process_request_body_empty_object():
) )
parser = OperationParser(operation, should_parse=False) parser = OperationParser(operation, should_parse=False)
parser._process_request_body() parser._process_request_body()
assert len(parser.params) == 0 assert len(parser._params) == 0
def test_dedupe_param_names(sample_operation): def test_dedupe_param_names(sample_operation):
"""Test _dedupe_param_names method.""" """Test _dedupe_param_names method."""
parser = OperationParser(sample_operation, should_parse=False) parser = OperationParser(sample_operation, should_parse=False)
# Add duplicate named parameters. # Add duplicate named parameters.
parser.params = [ parser._params = [
ApiParameter(original_name='test', param_location='', param_schema={}), ApiParameter(original_name='test', param_location='', param_schema={}),
ApiParameter(original_name='test', param_location='', param_schema={}), ApiParameter(original_name='test', param_location='', param_schema={}),
ApiParameter(original_name='test', param_location='', param_schema={}), ApiParameter(original_name='test', param_location='', param_schema={}),
] ]
parser._dedupe_param_names() parser._dedupe_param_names()
assert parser.params[0].py_name == 'test' assert parser._params[0].py_name == 'test'
assert parser.params[1].py_name == 'test_0' assert parser._params[1].py_name == 'test_0'
assert parser.params[2].py_name == 'test_1' assert parser._params[2].py_name == 'test_1'
def test_process_return_value(sample_operation): def test_process_return_value(sample_operation):
"""Test _process_return_value method.""" """Test _process_return_value method."""
parser = OperationParser(sample_operation, should_parse=False) parser = OperationParser(sample_operation, should_parse=False)
parser._process_return_value() parser._process_return_value()
assert parser.return_value is not None assert parser._return_value is not None
assert parser.return_value.type_hint == 'str' assert parser._return_value.type_hint == 'str'
def test_process_return_value_no_2xx(sample_operation): def test_process_return_value_no_2xx(sample_operation):
@ -206,8 +206,8 @@ def test_process_return_value_no_2xx(sample_operation):
) )
parser = OperationParser(operation_no_2xx, should_parse=False) parser = OperationParser(operation_no_2xx, should_parse=False)
parser._process_return_value() parser._process_return_value()
assert parser.return_value is not None assert parser._return_value is not None
assert parser.return_value.type_hint == 'Any' assert parser._return_value.type_hint == 'Any'
def test_process_return_value_multiple_2xx(sample_operation): def test_process_return_value_multiple_2xx(sample_operation):
@ -242,10 +242,10 @@ def test_process_return_value_multiple_2xx(sample_operation):
parser = OperationParser(operation_multi_2xx, should_parse=False) parser = OperationParser(operation_multi_2xx, should_parse=False)
parser._process_return_value() parser._process_return_value()
assert parser.return_value is not None assert parser._return_value is not None
# Take the content type of the 200 response since it's the smallest response # Take the content type of the 200 response since it's the smallest response
# code # code
assert parser.return_value.param_schema.type == 'boolean' assert parser._return_value.param_schema.type == 'boolean'
def test_process_return_value_no_content(sample_operation): def test_process_return_value_no_content(sample_operation):
@ -255,7 +255,7 @@ def test_process_return_value_no_content(sample_operation):
) )
parser = OperationParser(operation_no_content, should_parse=False) parser = OperationParser(operation_no_content, should_parse=False)
parser._process_return_value() parser._process_return_value()
assert parser.return_value.type_hint == 'Any' assert parser._return_value.type_hint == 'Any'
def test_process_return_value_no_schema(sample_operation): def test_process_return_value_no_schema(sample_operation):
@ -270,7 +270,7 @@ def test_process_return_value_no_schema(sample_operation):
) )
parser = OperationParser(operation_no_schema, should_parse=False) parser = OperationParser(operation_no_schema, should_parse=False)
parser._process_return_value() parser._process_return_value()
assert parser.return_value.type_hint == 'Any' assert parser._return_value.type_hint == 'Any'
def test_get_function_name(sample_operation): def test_get_function_name(sample_operation):
@ -389,9 +389,9 @@ def test_load():
parser = OperationParser.load(operation, params, return_value) parser = OperationParser.load(operation, params, return_value)
assert isinstance(parser, OperationParser) assert isinstance(parser, OperationParser)
assert parser.operation == operation assert parser._operation == operation
assert parser.params == params assert parser._params == params
assert parser.return_value == return_value assert parser._return_value == return_value
assert ( assert (
parser.get_function_name() == 'my_op' parser.get_function_name() == 'my_op'
) # Check that the operation is loaded ) # Check that the operation is loaded
@ -412,7 +412,7 @@ def test_operation_parser_with_dict():
}, },
} }
parser = OperationParser(operation_dict) parser = OperationParser(operation_dict)
assert parser.operation.operationId == 'test_dict_operation' assert parser._operation.operationId == 'test_dict_operation'
assert len(parser.params) == 1 assert len(parser._params) == 1
assert parser.params[0].original_name == 'dict_param' assert parser._params[0].original_name == 'dict_param'
assert parser.return_value.type_hint == 'str' assert parser._return_value.type_hint == 'str'