Add custom get_operation_id() method to avoid monkey-patching coreapi

This commit is contained in:
Jeremy Stretch 2020-11-11 14:25:43 -05:00
parent 0e1fb87153
commit dd1aac32f2
2 changed files with 12 additions and 15 deletions

View File

@ -1,5 +1,3 @@
from rest_framework.schemas import coreapi
from .fields import ChoiceField, ContentTypeField, SerializedPKRelatedField, TimeZoneField from .fields import ChoiceField, ContentTypeField, SerializedPKRelatedField, TimeZoneField
from .routers import OrderedDefaultRouter from .routers import OrderedDefaultRouter
from .serializers import BulkOperationSerializer, ValidatedModelSerializer, WritableNestedSerializer from .serializers import BulkOperationSerializer, ValidatedModelSerializer, WritableNestedSerializer
@ -15,16 +13,3 @@ __all__ = (
'ValidatedModelSerializer', 'ValidatedModelSerializer',
'WritableNestedSerializer', 'WritableNestedSerializer',
) )
def is_custom_action(action):
return action not in {
# Default actions
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy',
# Bulk operations
'bulk_update', 'bulk_partial_update', 'bulk_destroy',
}
# Monkey-patch DRF to treat bulk_destroy() as a non-custom action (see #3436)
coreapi.is_custom_action = is_custom_action

View File

@ -12,6 +12,18 @@ from netbox.api import ChoiceField, SerializedPKRelatedField, WritableNestedSeri
class NetBoxSwaggerAutoSchema(SwaggerAutoSchema): class NetBoxSwaggerAutoSchema(SwaggerAutoSchema):
writable_serializers = {} writable_serializers = {}
def get_operation_id(self, operation_keys=None):
operation_keys = operation_keys or self.operation_keys
operation_id = self.overrides.get('operation_id', '')
if not operation_id:
# Overwrite the action for bulk update/bulk delete views to ensure they get an operation ID that's
# unique from their single-object counterparts (see #3436)
if operation_keys[-1] in ('delete', 'partial_update', 'update') and not self.view.detail:
operation_keys[-1] = f'bulk_{operation_keys[-1]}'
operation_id = '_'.join(operation_keys)
return operation_id
def get_request_serializer(self): def get_request_serializer(self):
serializer = super().get_request_serializer() serializer = super().get_request_serializer()