mirror of
https://github.com/netbox-community/netbox.git
synced 2026-02-04 06:16:23 -06:00
* Closes #21263: Prefetch related objects after creating/updating objects via REST API * Add comment re: ordering by PK
This commit is contained in:
@@ -170,6 +170,28 @@ class NetBoxModelViewSet(
|
|||||||
|
|
||||||
# Creates
|
# Creates
|
||||||
|
|
||||||
|
def create(self, request, *args, **kwargs):
|
||||||
|
serializer = self.get_serializer(data=request.data)
|
||||||
|
serializer.is_valid(raise_exception=True)
|
||||||
|
bulk_create = getattr(serializer, 'many', False)
|
||||||
|
self.perform_create(serializer)
|
||||||
|
|
||||||
|
# After creating the instance(s), re-initialize the serializer with a queryset
|
||||||
|
# to ensure related objects are prefetched.
|
||||||
|
if bulk_create:
|
||||||
|
instance_pks = [obj.pk for obj in serializer.instance]
|
||||||
|
# Order by PK to ensure that the ordering of objects in the response
|
||||||
|
# matches the ordering of those in the request.
|
||||||
|
qs = self.get_queryset().filter(pk__in=instance_pks).order_by('pk')
|
||||||
|
else:
|
||||||
|
qs = self.get_queryset().get(pk=serializer.instance.pk)
|
||||||
|
|
||||||
|
# Re-serialize the instance(s) with prefetched data
|
||||||
|
serializer = self.get_serializer(qs, many=bulk_create)
|
||||||
|
|
||||||
|
headers = self.get_success_headers(serializer.data)
|
||||||
|
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
|
||||||
|
|
||||||
def perform_create(self, serializer):
|
def perform_create(self, serializer):
|
||||||
model = self.queryset.model
|
model = self.queryset.model
|
||||||
logger = logging.getLogger(f'netbox.api.views.{self.__class__.__name__}')
|
logger = logging.getLogger(f'netbox.api.views.{self.__class__.__name__}')
|
||||||
@@ -186,9 +208,20 @@ class NetBoxModelViewSet(
|
|||||||
# Updates
|
# Updates
|
||||||
|
|
||||||
def update(self, request, *args, **kwargs):
|
def update(self, request, *args, **kwargs):
|
||||||
# Hotwire get_object() to ensure we save a pre-change snapshot
|
partial = kwargs.pop('partial', False)
|
||||||
self.get_object = self.get_object_with_snapshot
|
instance = self.get_object_with_snapshot()
|
||||||
return super().update(request, *args, **kwargs)
|
serializer = self.get_serializer(instance, data=request.data, partial=partial)
|
||||||
|
serializer.is_valid(raise_exception=True)
|
||||||
|
self.perform_update(serializer)
|
||||||
|
|
||||||
|
# After updating the instance, re-initialize the serializer with a queryset
|
||||||
|
# to ensure related objects are prefetched.
|
||||||
|
qs = self.get_queryset().get(pk=serializer.instance.pk)
|
||||||
|
|
||||||
|
# Re-serialize the instance(s) with prefetched data
|
||||||
|
serializer = self.get_serializer(qs)
|
||||||
|
|
||||||
|
return Response(serializer.data)
|
||||||
|
|
||||||
def perform_update(self, serializer):
|
def perform_update(self, serializer):
|
||||||
model = self.queryset.model
|
model = self.queryset.model
|
||||||
|
|||||||
@@ -108,13 +108,17 @@ class BulkUpdateModelMixin:
|
|||||||
obj.pop('id'): obj for obj in request.data
|
obj.pop('id'): obj for obj in request.data
|
||||||
}
|
}
|
||||||
|
|
||||||
data = self.perform_bulk_update(qs, update_data, partial=partial)
|
object_pks = self.perform_bulk_update(qs, update_data, partial=partial)
|
||||||
|
|
||||||
return Response(data, status=status.HTTP_200_OK)
|
# Prefetch related objects for all updated instances
|
||||||
|
qs = self.get_queryset().filter(pk__in=object_pks)
|
||||||
|
serializer = self.get_serializer(qs, many=True)
|
||||||
|
|
||||||
|
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||||
|
|
||||||
def perform_bulk_update(self, objects, update_data, partial):
|
def perform_bulk_update(self, objects, update_data, partial):
|
||||||
|
updated_pks = []
|
||||||
with transaction.atomic(using=router.db_for_write(self.queryset.model)):
|
with transaction.atomic(using=router.db_for_write(self.queryset.model)):
|
||||||
data_list = []
|
|
||||||
for obj in objects:
|
for obj in objects:
|
||||||
data = update_data.get(obj.id)
|
data = update_data.get(obj.id)
|
||||||
if hasattr(obj, 'snapshot'):
|
if hasattr(obj, 'snapshot'):
|
||||||
@@ -122,9 +126,9 @@ class BulkUpdateModelMixin:
|
|||||||
serializer = self.get_serializer(obj, data=data, partial=partial)
|
serializer = self.get_serializer(obj, data=data, partial=partial)
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
self.perform_update(serializer)
|
self.perform_update(serializer)
|
||||||
data_list.append(serializer.data)
|
updated_pks.append(obj.pk)
|
||||||
|
|
||||||
return data_list
|
return updated_pks
|
||||||
|
|
||||||
def bulk_partial_update(self, request, *args, **kwargs):
|
def bulk_partial_update(self, request, *args, **kwargs):
|
||||||
kwargs['partial'] = True
|
kwargs['partial'] = True
|
||||||
|
|||||||
Reference in New Issue
Block a user