Closes #21263: Prefetch related objects after creating/updating objects via REST API (#21329)

* Closes #21263: Prefetch related objects after creating/updating objects via REST API

* Add comment re: ordering by PK
This commit is contained in:
Jeremy Stretch
2026-01-30 14:13:05 -05:00
committed by GitHub
parent bec5ecf6a9
commit ad29cb2d66
2 changed files with 45 additions and 8 deletions
+36 -3
View File
@@ -170,6 +170,28 @@ class NetBoxModelViewSet(
# Creates # Creates
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
bulk_create = getattr(serializer, 'many', False)
self.perform_create(serializer)
# After creating the instance(s), re-initialize the serializer with a queryset
# to ensure related objects are prefetched.
if bulk_create:
instance_pks = [obj.pk for obj in serializer.instance]
# Order by PK to ensure that the ordering of objects in the response
# matches the ordering of those in the request.
qs = self.get_queryset().filter(pk__in=instance_pks).order_by('pk')
else:
qs = self.get_queryset().get(pk=serializer.instance.pk)
# Re-serialize the instance(s) with prefetched data
serializer = self.get_serializer(qs, many=bulk_create)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def perform_create(self, serializer): def perform_create(self, serializer):
model = self.queryset.model model = self.queryset.model
logger = logging.getLogger(f'netbox.api.views.{self.__class__.__name__}') logger = logging.getLogger(f'netbox.api.views.{self.__class__.__name__}')
@@ -186,9 +208,20 @@ class NetBoxModelViewSet(
# Updates # Updates
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
# Hotwire get_object() to ensure we save a pre-change snapshot partial = kwargs.pop('partial', False)
self.get_object = self.get_object_with_snapshot instance = self.get_object_with_snapshot()
return super().update(request, *args, **kwargs) serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
# After updating the instance, re-initialize the serializer with a queryset
# to ensure related objects are prefetched.
qs = self.get_queryset().get(pk=serializer.instance.pk)
# Re-serialize the instance(s) with prefetched data
serializer = self.get_serializer(qs)
return Response(serializer.data)
def perform_update(self, serializer): def perform_update(self, serializer):
model = self.queryset.model model = self.queryset.model
+9 -5
View File
@@ -108,13 +108,17 @@ class BulkUpdateModelMixin:
obj.pop('id'): obj for obj in request.data obj.pop('id'): obj for obj in request.data
} }
data = self.perform_bulk_update(qs, update_data, partial=partial) object_pks = self.perform_bulk_update(qs, update_data, partial=partial)
return Response(data, status=status.HTTP_200_OK) # Prefetch related objects for all updated instances
qs = self.get_queryset().filter(pk__in=object_pks)
serializer = self.get_serializer(qs, many=True)
return Response(serializer.data, status=status.HTTP_200_OK)
def perform_bulk_update(self, objects, update_data, partial): def perform_bulk_update(self, objects, update_data, partial):
updated_pks = []
with transaction.atomic(using=router.db_for_write(self.queryset.model)): with transaction.atomic(using=router.db_for_write(self.queryset.model)):
data_list = []
for obj in objects: for obj in objects:
data = update_data.get(obj.id) data = update_data.get(obj.id)
if hasattr(obj, 'snapshot'): if hasattr(obj, 'snapshot'):
@@ -122,9 +126,9 @@ class BulkUpdateModelMixin:
serializer = self.get_serializer(obj, data=data, partial=partial) serializer = self.get_serializer(obj, data=data, partial=partial)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
self.perform_update(serializer) self.perform_update(serializer)
data_list.append(serializer.data) updated_pks.append(obj.pk)
return data_list return updated_pks
def bulk_partial_update(self, request, *args, **kwargs): def bulk_partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True kwargs['partial'] = True