The id__in field is a csv-separated string of ids

drf_yasg is interpreting it as a number because NumericInFilter inherits
from django's NumberFilter which explicitly identifies as being a
DecimalField.
This commit is contained in:
Dave Cameron 2018-03-15 16:51:57 -04:00
parent 53e4e74930
commit e071b7dfd5
2 changed files with 15 additions and 1 deletions

View File

@ -260,6 +260,10 @@ SWAGGER_SETTINGS = {
'drf_yasg.inspectors.SimpleFieldInspector', 'drf_yasg.inspectors.SimpleFieldInspector',
'drf_yasg.inspectors.StringDefaultFieldInspector', 'drf_yasg.inspectors.StringDefaultFieldInspector',
], ],
'DEFAULT_FILTER_INSPECTORS': [
'utilities.custom_inspectors.IdInFilterInspector',
'drf_yasg.inspectors.CoreAPICompatInspector',
],
'DEFAULT_PAGINATOR_INSPECTORS': [ 'DEFAULT_PAGINATOR_INSPECTORS': [
'utilities.custom_inspectors.NullablePaginatorInspector', 'utilities.custom_inspectors.NullablePaginatorInspector',
'drf_yasg.inspectors.DjangoRestResponsePagination', 'drf_yasg.inspectors.DjangoRestResponsePagination',

View File

@ -1,5 +1,5 @@
from drf_yasg import openapi from drf_yasg import openapi
from drf_yasg.inspectors import FieldInspector, NotHandled, PaginatorInspector from drf_yasg.inspectors import FieldInspector, NotHandled, PaginatorInspector, FilterInspector
from rest_framework.fields import ChoiceField from rest_framework.fields import ChoiceField
from extras.api.customfields import CustomFieldsSerializer from extras.api.customfields import CustomFieldsSerializer
@ -53,6 +53,16 @@ class NullableBooleanFieldInspector(FieldInspector):
return result return result
class IdInFilterInspector(FilterInspector):
def process_result(self, result, method_name, obj, **kwargs):
if isinstance(result, list):
params = [p for p in result if isinstance(p, openapi.Parameter) and p.name == 'id__in']
for p in params:
p.type = 'string'
return result
class NullablePaginatorInspector(PaginatorInspector): class NullablePaginatorInspector(PaginatorInspector):
def process_result(self, result, method_name, obj, **kwargs): def process_result(self, result, method_name, obj, **kwargs):
if method_name == 'get_paginated_response' and isinstance(result, openapi.Schema): if method_name == 'get_paginated_response' and isinstance(result, openapi.Schema):