diff --git a/netbox/core/api/schema.py b/netbox/core/api/schema.py index f1aec514a..c475e770c 100644 --- a/netbox/core/api/schema.py +++ b/netbox/core/api/schema.py @@ -1,4 +1,3 @@ -import logging import re import typing @@ -17,9 +16,14 @@ from drf_spectacular.plumbing import ( ) from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import extend_schema +from rest_framework.relations import ManyRelatedField + +from netbox.api.fields import ChoiceField, SerializedPKRelatedField +from netbox.api.serializers import WritableNestedSerializer # see netbox.api.routers.NetBoxRouter -BULK_ACTIONS = ["bulk_destroy", "bulk_partial_update", "bulk_update"] +BULK_ACTIONS = ("bulk_destroy", "bulk_partial_update", "bulk_update") +WRITABLE_ACTIONS = ("PATCH", "POST", "PUT") class FixTimeZoneSerializerField(OpenApiSerializerFieldExtension): @@ -54,6 +58,7 @@ class NetBoxAutoSchema(AutoSchema): 4. bulk operations don't have pagination 5. bulk delete should specify input """ + writable_serializers = {} @property def is_bulk_action(self): @@ -100,6 +105,17 @@ class NetBoxAutoSchema(AutoSchema): if self.is_bulk_action: return type(serializer)(many=True) + # handle mapping for Writable serializers - adapted from dansheps original code + # for drf-yasg + if serializer is not None and self.method in WRITABLE_ACTIONS: + writable_class = self.get_writable_class(serializer) + if writable_class is not None: + if hasattr(serializer, "child"): + child_serializer = self.get_writable_class(serializer.child) + serializer = writable_class(context=serializer.context, child=child_serializer) + else: + serializer = writable_class(context=serializer.context) + return serializer def get_response_serializers(self) -> typing.Any: @@ -111,6 +127,51 @@ class NetBoxAutoSchema(AutoSchema): return response_serializers + def get_serializer_ref_name(self, serializer): + # from drf-yasg.utils + """Get serializer's ref_name (or None for ModelSerializer if it is named 'NestedSerializer') + :param serializer: Serializer instance + :return: Serializer's ``ref_name`` or ``None`` for inline serializer + :rtype: str or None + """ + serializer_meta = getattr(serializer, 'Meta', None) + serializer_name = type(serializer).__name__ + if hasattr(serializer_meta, 'ref_name'): + ref_name = serializer_meta.ref_name + elif serializer_name == 'NestedSerializer' and isinstance(serializer, serializers.ModelSerializer): + ref_name = None + else: + ref_name = serializer_name + if ref_name.endswith('Serializer'): + ref_name = ref_name[:-len('Serializer')] + return ref_name + + def get_writable_class(self, serializer): + properties = {} + fields = {} if hasattr(serializer, 'child') else serializer.fields + + for child_name, child in fields.items(): + if isinstance(child, (ChoiceField, WritableNestedSerializer)): + properties[child_name] = None + elif isinstance(child, ManyRelatedField) and isinstance(child.child_relation, SerializedPKRelatedField): + properties[child_name] = None + + if not properties: + return None + + if type(serializer) not in self.writable_serializers: + writable_name = 'Writable' + type(serializer).__name__ + meta_class = getattr(type(serializer), 'Meta', None) + if meta_class: + ref_name = 'Writable' + self.get_serializer_ref_name(serializer) + writable_meta = type('Meta', (meta_class,), {'ref_name': ref_name}) + properties['Meta'] = writable_meta + + self.writable_serializers[type(serializer)] = type(writable_name, (type(serializer),), properties) + + writable_class = self.writable_serializers[type(serializer)] + return writable_class + def get_filter_backends(self): # bulk operations don't have filter params if self.is_bulk_action: diff --git a/netbox/dcim/api/views.py b/netbox/dcim/api/views.py index fd7f26fc0..5d015de0c 100644 --- a/netbox/dcim/api/views.py +++ b/netbox/dcim/api/views.py @@ -618,7 +618,7 @@ class ConnectedDeviceViewSet(ViewSet): required=True, type=OpenApiTypes.STR ) - serializer_class = serializers.DeviceSerializer # for drf-spectacular + serializer_class = serializers.DeviceSerializer def get_view_name(self): return "Connected Device Locator" diff --git a/netbox/ipam/api/views.py b/netbox/ipam/api/views.py index 6337b84c9..5263b049a 100644 --- a/netbox/ipam/api/views.py +++ b/netbox/ipam/api/views.py @@ -209,9 +209,8 @@ def get_results_limit(request): class AvailableASNsView(ObjectValidationMixin, APIView): queryset = ASN.objects.all() - serializer_class = serializers.AvailableASNSerializer # drf-spectacular - @extend_schema(methods=["get"], responses={200: serializers.AvailablePrefixSerializer(many=True)}) + @extend_schema(methods=["get"], responses={200: serializers.AvailableASNSerializer(many=True)}) def get(self, request, pk): asnrange = get_object_or_404(ASNRange.objects.restrict(request.user), pk=pk) limit = get_results_limit(request) @@ -272,10 +271,15 @@ class AvailableASNsView(ObjectValidationMixin, APIView): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def get_serializer_class(self): + if self.request.method == "GET": + return serializers.AvailableASNSerializer + + return serializers.ASNSerializer + class AvailablePrefixesView(ObjectValidationMixin, APIView): queryset = Prefix.objects.all() - serializer_class = serializers.PrefixSerializer # for drf-spectacular @extend_schema(methods=["get"], responses={200: serializers.AvailablePrefixSerializer(many=True)}) def get(self, request, pk): @@ -352,10 +356,15 @@ class AvailablePrefixesView(ObjectValidationMixin, APIView): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def get_serializer_class(self): + if self.request.method == "GET": + return serializers.AvailablePrefixSerializer + + return serializers.PrefixLengthSerializer + class AvailableIPAddressesView(ObjectValidationMixin, APIView): queryset = IPAddress.objects.all() - serializer_class = serializers.IPAddressSerializer # for drf-spectacular def get_parent(self, request, pk): raise NotImplemented() @@ -424,6 +433,12 @@ class AvailableIPAddressesView(ObjectValidationMixin, APIView): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def get_serializer_class(self): + if self.request.method == "GET": + return serializers.AvailableIPSerializer + + return serializers.IPAddressSerializer + class PrefixAvailableIPAddressesView(AvailableIPAddressesView): @@ -439,7 +454,6 @@ class IPRangeAvailableIPAddressesView(AvailableIPAddressesView): class AvailableVLANsView(ObjectValidationMixin, APIView): queryset = VLAN.objects.all() - serializer_class = serializers.VLANSerializer # for drf-spectacular @extend_schema(methods=["get"], responses={200: serializers.AvailableVLANSerializer(many=True)}) def get(self, request, pk): @@ -506,3 +520,9 @@ class AvailableVLANsView(ObjectValidationMixin, APIView): return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + def get_serializer_class(self): + if self.request.method == "GET": + return serializers.AvailableVLANSerializer + + return serializers.VLANSerializer diff --git a/netbox/netbox/api/fields.py b/netbox/netbox/api/fields.py index 347ed55bd..ed70c28ac 100644 --- a/netbox/netbox/api/fields.py +++ b/netbox/netbox/api/fields.py @@ -14,6 +14,7 @@ __all__ = ( ) +@extend_schema_field(OpenApiTypes.STR) class ChoiceField(serializers.Field): """ Represent a ChoiceField as {'value': , 'label': }. Accepts a single value on write. diff --git a/netbox/users/api/views.py b/netbox/users/api/views.py index d26c648ac..04b3ae336 100644 --- a/netbox/users/api/views.py +++ b/netbox/users/api/views.py @@ -69,9 +69,8 @@ class TokenProvisionView(APIView): Non-authenticated REST API endpoint via which a user may create a Token. """ permission_classes = [] - serializer_class = serializers.TokenSerializer # for drf-spectacular - @extend_schema(methods=["post"], responses={201: serializers.TokenSerializer}) + # @extend_schema(methods=["post"], responses={201: serializers.TokenSerializer}) def post(self, request): serializer = serializers.TokenProvisionSerializer(data=request.data) serializer.is_valid() @@ -94,6 +93,9 @@ class TokenProvisionView(APIView): return Response(data, status=HTTP_201_CREATED) + def get_serializer_class(self): + return serializers.TokenSerializer + # # ObjectPermissions