9608 for writable serializers

This commit is contained in:
Arthur 2023-03-16 11:28:57 -07:00
parent 7c5aeab347
commit 6fc5edda56
5 changed files with 94 additions and 10 deletions

View File

@ -1,4 +1,3 @@
import logging
import re
import typing
@ -17,9 +16,14 @@ from drf_spectacular.plumbing import (
)
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema
from rest_framework.relations import ManyRelatedField
from netbox.api.fields import ChoiceField, SerializedPKRelatedField
from netbox.api.serializers import WritableNestedSerializer
# see netbox.api.routers.NetBoxRouter
BULK_ACTIONS = ["bulk_destroy", "bulk_partial_update", "bulk_update"]
BULK_ACTIONS = ("bulk_destroy", "bulk_partial_update", "bulk_update")
WRITABLE_ACTIONS = ("PATCH", "POST", "PUT")
class FixTimeZoneSerializerField(OpenApiSerializerFieldExtension):
@ -54,6 +58,7 @@ class NetBoxAutoSchema(AutoSchema):
4. bulk operations don't have pagination
5. bulk delete should specify input
"""
writable_serializers = {}
@property
def is_bulk_action(self):
@ -100,6 +105,17 @@ class NetBoxAutoSchema(AutoSchema):
if self.is_bulk_action:
return type(serializer)(many=True)
# handle mapping for Writable serializers - adapted from dansheps original code
# for drf-yasg
if serializer is not None and self.method in WRITABLE_ACTIONS:
writable_class = self.get_writable_class(serializer)
if writable_class is not None:
if hasattr(serializer, "child"):
child_serializer = self.get_writable_class(serializer.child)
serializer = writable_class(context=serializer.context, child=child_serializer)
else:
serializer = writable_class(context=serializer.context)
return serializer
def get_response_serializers(self) -> typing.Any:
@ -111,6 +127,51 @@ class NetBoxAutoSchema(AutoSchema):
return response_serializers
def get_serializer_ref_name(self, serializer):
# from drf-yasg.utils
"""Get serializer's ref_name (or None for ModelSerializer if it is named 'NestedSerializer')
:param serializer: Serializer instance
:return: Serializer's ``ref_name`` or ``None`` for inline serializer
:rtype: str or None
"""
serializer_meta = getattr(serializer, 'Meta', None)
serializer_name = type(serializer).__name__
if hasattr(serializer_meta, 'ref_name'):
ref_name = serializer_meta.ref_name
elif serializer_name == 'NestedSerializer' and isinstance(serializer, serializers.ModelSerializer):
ref_name = None
else:
ref_name = serializer_name
if ref_name.endswith('Serializer'):
ref_name = ref_name[:-len('Serializer')]
return ref_name
def get_writable_class(self, serializer):
properties = {}
fields = {} if hasattr(serializer, 'child') else serializer.fields
for child_name, child in fields.items():
if isinstance(child, (ChoiceField, WritableNestedSerializer)):
properties[child_name] = None
elif isinstance(child, ManyRelatedField) and isinstance(child.child_relation, SerializedPKRelatedField):
properties[child_name] = None
if not properties:
return None
if type(serializer) not in self.writable_serializers:
writable_name = 'Writable' + type(serializer).__name__
meta_class = getattr(type(serializer), 'Meta', None)
if meta_class:
ref_name = 'Writable' + self.get_serializer_ref_name(serializer)
writable_meta = type('Meta', (meta_class,), {'ref_name': ref_name})
properties['Meta'] = writable_meta
self.writable_serializers[type(serializer)] = type(writable_name, (type(serializer),), properties)
writable_class = self.writable_serializers[type(serializer)]
return writable_class
def get_filter_backends(self):
# bulk operations don't have filter params
if self.is_bulk_action:

View File

@ -618,7 +618,7 @@ class ConnectedDeviceViewSet(ViewSet):
required=True,
type=OpenApiTypes.STR
)
serializer_class = serializers.DeviceSerializer # for drf-spectacular
serializer_class = serializers.DeviceSerializer
def get_view_name(self):
return "Connected Device Locator"

View File

@ -209,9 +209,8 @@ def get_results_limit(request):
class AvailableASNsView(ObjectValidationMixin, APIView):
queryset = ASN.objects.all()
serializer_class = serializers.AvailableASNSerializer # drf-spectacular
@extend_schema(methods=["get"], responses={200: serializers.AvailablePrefixSerializer(many=True)})
@extend_schema(methods=["get"], responses={200: serializers.AvailableASNSerializer(many=True)})
def get(self, request, pk):
asnrange = get_object_or_404(ASNRange.objects.restrict(request.user), pk=pk)
limit = get_results_limit(request)
@ -272,10 +271,15 @@ class AvailableASNsView(ObjectValidationMixin, APIView):
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get_serializer_class(self):
if self.request.method == "GET":
return serializers.AvailableASNSerializer
return serializers.ASNSerializer
class AvailablePrefixesView(ObjectValidationMixin, APIView):
queryset = Prefix.objects.all()
serializer_class = serializers.PrefixSerializer # for drf-spectacular
@extend_schema(methods=["get"], responses={200: serializers.AvailablePrefixSerializer(many=True)})
def get(self, request, pk):
@ -352,10 +356,15 @@ class AvailablePrefixesView(ObjectValidationMixin, APIView):
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get_serializer_class(self):
if self.request.method == "GET":
return serializers.AvailablePrefixSerializer
return serializers.PrefixLengthSerializer
class AvailableIPAddressesView(ObjectValidationMixin, APIView):
queryset = IPAddress.objects.all()
serializer_class = serializers.IPAddressSerializer # for drf-spectacular
def get_parent(self, request, pk):
raise NotImplemented()
@ -424,6 +433,12 @@ class AvailableIPAddressesView(ObjectValidationMixin, APIView):
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get_serializer_class(self):
if self.request.method == "GET":
return serializers.AvailableIPSerializer
return serializers.IPAddressSerializer
class PrefixAvailableIPAddressesView(AvailableIPAddressesView):
@ -439,7 +454,6 @@ class IPRangeAvailableIPAddressesView(AvailableIPAddressesView):
class AvailableVLANsView(ObjectValidationMixin, APIView):
queryset = VLAN.objects.all()
serializer_class = serializers.VLANSerializer # for drf-spectacular
@extend_schema(methods=["get"], responses={200: serializers.AvailableVLANSerializer(many=True)})
def get(self, request, pk):
@ -506,3 +520,9 @@ class AvailableVLANsView(ObjectValidationMixin, APIView):
return Response(serializer.data, status=status.HTTP_201_CREATED)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get_serializer_class(self):
if self.request.method == "GET":
return serializers.AvailableVLANSerializer
return serializers.VLANSerializer

View File

@ -14,6 +14,7 @@ __all__ = (
)
@extend_schema_field(OpenApiTypes.STR)
class ChoiceField(serializers.Field):
"""
Represent a ChoiceField as {'value': <DB value>, 'label': <string>}. Accepts a single value on write.

View File

@ -69,9 +69,8 @@ class TokenProvisionView(APIView):
Non-authenticated REST API endpoint via which a user may create a Token.
"""
permission_classes = []
serializer_class = serializers.TokenSerializer # for drf-spectacular
@extend_schema(methods=["post"], responses={201: serializers.TokenSerializer})
# @extend_schema(methods=["post"], responses={201: serializers.TokenSerializer})
def post(self, request):
serializer = serializers.TokenProvisionSerializer(data=request.data)
serializer.is_valid()
@ -94,6 +93,9 @@ class TokenProvisionView(APIView):
return Response(data, status=HTTP_201_CREATED)
def get_serializer_class(self):
return serializers.TokenSerializer
#
# ObjectPermissions