issue_13422: Ignore MPTT fields in serialization if being used by staged changes

This commit is contained in:
Alex Gittings 2024-04-24 11:10:44 +01:00
parent 85db007ff5
commit ea3550ff07
3 changed files with 16 additions and 7 deletions

View File

@ -116,7 +116,7 @@ class checkout:
# Creating a new object # Creating a new object
if kwargs.get('created'): if kwargs.get('created'):
logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})") logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})")
data = serialize_object(instance, resolve_tags=False) data = serialize_object(instance, resolve_tags=False, mptt=True)
self.queue[key] = (ChangeActionChoices.ACTION_CREATE, data) self.queue[key] = (ChangeActionChoices.ACTION_CREATE, data)
return return
@ -127,13 +127,13 @@ class checkout:
# Object has already been created/updated in the queue; update its queued representation # Object has already been created/updated in the queue; update its queued representation
if key in self.queue: if key in self.queue:
logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})") logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})")
data = serialize_object(instance, resolve_tags=False) data = serialize_object(instance, resolve_tags=False, mptt=True)
self.queue[key] = (self.queue[key][0], data) self.queue[key] = (self.queue[key][0], data)
return return
# Modifying an existing object for the first time # Modifying an existing object for the first time
logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})") logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
data = serialize_object(instance, resolve_tags=False) data = serialize_object(instance, resolve_tags=False, mptt=True)
self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, data) self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, data)
def pre_delete_handler(self, sender, instance, **kwargs): def pre_delete_handler(self, sender, instance, **kwargs):

View File

@ -1,6 +1,7 @@
from django.test import TransactionTestCase from django.test import TransactionTestCase
from circuits.models import Provider, Circuit, CircuitType from circuits.models import Provider, Circuit, CircuitType
from dcim.models import Location
from extras.choices import ChangeActionChoices from extras.choices import ChangeActionChoices
from extras.models import Branch, StagedChange, Tag from extras.models import Branch, StagedChange, Tag
from ipam.models import ASN, RIR from ipam.models import ASN, RIR
@ -53,21 +54,27 @@ class StagingTestCase(TransactionTestCase):
circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first()) circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first())
circuit.tags.set(tags) circuit.tags.set(tags)
# Test MPTT Model
location = Location.objects.create(name='Location 1', slug='location-1')
# Sanity-checking # Sanity-checking
self.assertEqual(Provider.objects.count(), 4) self.assertEqual(Provider.objects.count(), 4)
self.assertListEqual(list(provider.asns.all()), list(asns)) self.assertListEqual(list(provider.asns.all()), list(asns))
self.assertEqual(Circuit.objects.count(), 10) self.assertEqual(Circuit.objects.count(), 10)
self.assertListEqual(list(circuit.tags.all()), list(tags)) self.assertListEqual(list(circuit.tags.all()), list(tags))
self.assertEqual(Location.objects.count(), 1)
# Verify that changes have been rolled back after exiting the context # Verify that changes have been rolled back after exiting the context
self.assertEqual(Provider.objects.count(), 3) self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Circuit.objects.count(), 9) self.assertEqual(Circuit.objects.count(), 9)
self.assertEqual(StagedChange.objects.count(), 5) self.assertEqual(StagedChange.objects.count(), 5)
self.assertEqual(Location.objects.count(), 0)
# Verify that changes are replayed upon entering the context # Verify that changes are replayed upon entering the context
with checkout(branch): with checkout(branch):
self.assertEqual(Provider.objects.count(), 4) self.assertEqual(Provider.objects.count(), 4)
self.assertEqual(Circuit.objects.count(), 10) self.assertEqual(Circuit.objects.count(), 10)
self.assertEqual(Location.objects.count(), 1)
provider = Provider.objects.get(name='Provider D') provider = Provider.objects.get(name='Provider D')
self.assertListEqual(list(provider.asns.all()), list(asns)) self.assertListEqual(list(provider.asns.all()), list(asns))
circuit = Circuit.objects.get(cid='Circuit D1') circuit = Circuit.objects.get(cid='Circuit D1')
@ -82,6 +89,7 @@ class StagingTestCase(TransactionTestCase):
circuit = Circuit.objects.get(cid='Circuit D1') circuit = Circuit.objects.get(cid='Circuit D1')
self.assertListEqual(list(circuit.tags.all()), list(tags)) self.assertListEqual(list(circuit.tags.all()), list(tags))
self.assertEqual(StagedChange.objects.count(), 0) self.assertEqual(StagedChange.objects.count(), 0)
self.assertEqual(Location.objects.count(), 1)
def test_object_modification(self): def test_object_modification(self):
branch = Branch.objects.create(name='Branch 1') branch = Branch.objects.create(name='Branch 1')

View File

@ -147,7 +147,7 @@ def count_related(model, field):
return Coalesce(subquery, 0) return Coalesce(subquery, 0)
def serialize_object(obj, resolve_tags=True, extra=None, exclude=None): def serialize_object(obj, resolve_tags=True, extra=None, exclude=None, mptt=False):
""" """
Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
@ -166,9 +166,10 @@ def serialize_object(obj, resolve_tags=True, extra=None, exclude=None):
exclude = exclude or [] exclude = exclude or []
# Exclude any MPTTModel fields # Exclude any MPTTModel fields
if issubclass(obj.__class__, MPTTModel): if not mptt:
for field in ['level', 'lft', 'rght', 'tree_id']: if issubclass(obj.__class__, MPTTModel):
data.pop(field) for field in ['level', 'lft', 'rght', 'tree_id']:
data.pop(field)
# Include custom_field_data as "custom_fields" # Include custom_field_data as "custom_fields"
if hasattr(obj, 'custom_field_data'): if hasattr(obj, 'custom_field_data'):