diff --git a/netbox/netbox/api/viewsets/__init__.py b/netbox/netbox/api/viewsets/__init__.py index ea2195990..57bceb674 100644 --- a/netbox/netbox/api/viewsets/__init__.py +++ b/netbox/netbox/api/viewsets/__init__.py @@ -170,6 +170,28 @@ class NetBoxModelViewSet( # Creates + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + bulk_create = getattr(serializer, 'many', False) + self.perform_create(serializer) + + # After creating the instance(s), re-initialize the serializer with a queryset + # to ensure related objects are prefetched. + if bulk_create: + instance_pks = [obj.pk for obj in serializer.instance] + # Order by PK to ensure that the ordering of objects in the response + # matches the ordering of those in the request. + qs = self.get_queryset().filter(pk__in=instance_pks).order_by('pk') + else: + qs = self.get_queryset().get(pk=serializer.instance.pk) + + # Re-serialize the instance(s) with prefetched data + serializer = self.get_serializer(qs, many=bulk_create) + + headers = self.get_success_headers(serializer.data) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + def perform_create(self, serializer): model = self.queryset.model logger = logging.getLogger(f'netbox.api.views.{self.__class__.__name__}') @@ -186,9 +208,20 @@ class NetBoxModelViewSet( # Updates def update(self, request, *args, **kwargs): - # Hotwire get_object() to ensure we save a pre-change snapshot - self.get_object = self.get_object_with_snapshot - return super().update(request, *args, **kwargs) + partial = kwargs.pop('partial', False) + instance = self.get_object_with_snapshot() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + + # After updating the instance, re-initialize the serializer with a queryset + # to ensure related objects are prefetched. + qs = self.get_queryset().get(pk=serializer.instance.pk) + + # Re-serialize the instance(s) with prefetched data + serializer = self.get_serializer(qs) + + return Response(serializer.data) def perform_update(self, serializer): model = self.queryset.model diff --git a/netbox/netbox/api/viewsets/mixins.py b/netbox/netbox/api/viewsets/mixins.py index e74488164..7f753240e 100644 --- a/netbox/netbox/api/viewsets/mixins.py +++ b/netbox/netbox/api/viewsets/mixins.py @@ -108,13 +108,17 @@ class BulkUpdateModelMixin: obj.pop('id'): obj for obj in request.data } - data = self.perform_bulk_update(qs, update_data, partial=partial) + object_pks = self.perform_bulk_update(qs, update_data, partial=partial) - return Response(data, status=status.HTTP_200_OK) + # Prefetch related objects for all updated instances + qs = self.get_queryset().filter(pk__in=object_pks) + serializer = self.get_serializer(qs, many=True) + + return Response(serializer.data, status=status.HTTP_200_OK) def perform_bulk_update(self, objects, update_data, partial): + updated_pks = [] with transaction.atomic(using=router.db_for_write(self.queryset.model)): - data_list = [] for obj in objects: data = update_data.get(obj.id) if hasattr(obj, 'snapshot'): @@ -122,9 +126,9 @@ class BulkUpdateModelMixin: serializer = self.get_serializer(obj, data=data, partial=partial) serializer.is_valid(raise_exception=True) self.perform_update(serializer) - data_list.append(serializer.data) + updated_pks.append(obj.pk) - return data_list + return updated_pks def bulk_partial_update(self, request, *args, **kwargs): kwargs['partial'] = True