issue_13422: Ignore MPTT fields in serialization if being used by staged changes

This commit is contained in:
Alex Gittings 2024-04-24 11:10:44 +01:00
parent 85db007ff5
commit ea3550ff07
3 changed files with 16 additions and 7 deletions

View File

@ -116,7 +116,7 @@ class checkout:
# Creating a new object
if kwargs.get('created'):
logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})")
data = serialize_object(instance, resolve_tags=False)
data = serialize_object(instance, resolve_tags=False, mptt=True)
self.queue[key] = (ChangeActionChoices.ACTION_CREATE, data)
return
@ -127,13 +127,13 @@ class checkout:
# Object has already been created/updated in the queue; update its queued representation
if key in self.queue:
logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})")
data = serialize_object(instance, resolve_tags=False)
data = serialize_object(instance, resolve_tags=False, mptt=True)
self.queue[key] = (self.queue[key][0], data)
return
# Modifying an existing object for the first time
logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
data = serialize_object(instance, resolve_tags=False)
data = serialize_object(instance, resolve_tags=False, mptt=True)
self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, data)
def pre_delete_handler(self, sender, instance, **kwargs):

View File

@ -1,6 +1,7 @@
from django.test import TransactionTestCase
from circuits.models import Provider, Circuit, CircuitType
from dcim.models import Location
from extras.choices import ChangeActionChoices
from extras.models import Branch, StagedChange, Tag
from ipam.models import ASN, RIR
@ -53,21 +54,27 @@ class StagingTestCase(TransactionTestCase):
circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first())
circuit.tags.set(tags)
# Test MPTT Model
location = Location.objects.create(name='Location 1', slug='location-1')
# Sanity-checking
self.assertEqual(Provider.objects.count(), 4)
self.assertListEqual(list(provider.asns.all()), list(asns))
self.assertEqual(Circuit.objects.count(), 10)
self.assertListEqual(list(circuit.tags.all()), list(tags))
self.assertEqual(Location.objects.count(), 1)
# Verify that changes have been rolled back after exiting the context
self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Circuit.objects.count(), 9)
self.assertEqual(StagedChange.objects.count(), 5)
self.assertEqual(Location.objects.count(), 0)
# Verify that changes are replayed upon entering the context
with checkout(branch):
self.assertEqual(Provider.objects.count(), 4)
self.assertEqual(Circuit.objects.count(), 10)
self.assertEqual(Location.objects.count(), 1)
provider = Provider.objects.get(name='Provider D')
self.assertListEqual(list(provider.asns.all()), list(asns))
circuit = Circuit.objects.get(cid='Circuit D1')
@ -82,6 +89,7 @@ class StagingTestCase(TransactionTestCase):
circuit = Circuit.objects.get(cid='Circuit D1')
self.assertListEqual(list(circuit.tags.all()), list(tags))
self.assertEqual(StagedChange.objects.count(), 0)
self.assertEqual(Location.objects.count(), 1)
def test_object_modification(self):
branch = Branch.objects.create(name='Branch 1')

View File

@ -147,7 +147,7 @@ def count_related(model, field):
return Coalesce(subquery, 0)
def serialize_object(obj, resolve_tags=True, extra=None, exclude=None):
def serialize_object(obj, resolve_tags=True, extra=None, exclude=None, mptt=False):
"""
Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
@ -166,9 +166,10 @@ def serialize_object(obj, resolve_tags=True, extra=None, exclude=None):
exclude = exclude or []
# Exclude any MPTTModel fields
if issubclass(obj.__class__, MPTTModel):
for field in ['level', 'lft', 'rght', 'tree_id']:
data.pop(field)
if not mptt:
if issubclass(obj.__class__, MPTTModel):
for field in ['level', 'lft', 'rght', 'tree_id']:
data.pop(field)
# Include custom_field_data as "custom_fields"
if hasattr(obj, 'custom_field_data'):