diff --git a/netbox/netbox/staging.py b/netbox/netbox/staging.py index ec38dcadc..3c32a78f1 100644 --- a/netbox/netbox/staging.py +++ b/netbox/netbox/staging.py @@ -116,7 +116,7 @@ class checkout: # Creating a new object if kwargs.get('created'): logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})") - data = serialize_object(instance, resolve_tags=False) + data = serialize_object(instance, resolve_tags=False, mptt=True) self.queue[key] = (ChangeActionChoices.ACTION_CREATE, data) return @@ -127,13 +127,13 @@ class checkout: # Object has already been created/updated in the queue; update its queued representation if key in self.queue: logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})") - data = serialize_object(instance, resolve_tags=False) + data = serialize_object(instance, resolve_tags=False, mptt=True) self.queue[key] = (self.queue[key][0], data) return # Modifying an existing object for the first time logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})") - data = serialize_object(instance, resolve_tags=False) + data = serialize_object(instance, resolve_tags=False, mptt=True) self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, data) def pre_delete_handler(self, sender, instance, **kwargs): diff --git a/netbox/netbox/tests/test_staging.py b/netbox/netbox/tests/test_staging.py index ed3a69f10..d655a8660 100644 --- a/netbox/netbox/tests/test_staging.py +++ b/netbox/netbox/tests/test_staging.py @@ -1,6 +1,7 @@ from django.test import TransactionTestCase from circuits.models import Provider, Circuit, CircuitType +from dcim.models import Location from extras.choices import ChangeActionChoices from extras.models import Branch, StagedChange, Tag from ipam.models import ASN, RIR @@ -53,21 +54,27 @@ class StagingTestCase(TransactionTestCase): circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first()) circuit.tags.set(tags) + # Test MPTT Model + location = Location.objects.create(name='Location 1', slug='location-1') + # Sanity-checking self.assertEqual(Provider.objects.count(), 4) self.assertListEqual(list(provider.asns.all()), list(asns)) self.assertEqual(Circuit.objects.count(), 10) self.assertListEqual(list(circuit.tags.all()), list(tags)) + self.assertEqual(Location.objects.count(), 1) # Verify that changes have been rolled back after exiting the context self.assertEqual(Provider.objects.count(), 3) self.assertEqual(Circuit.objects.count(), 9) self.assertEqual(StagedChange.objects.count(), 5) + self.assertEqual(Location.objects.count(), 0) # Verify that changes are replayed upon entering the context with checkout(branch): self.assertEqual(Provider.objects.count(), 4) self.assertEqual(Circuit.objects.count(), 10) + self.assertEqual(Location.objects.count(), 1) provider = Provider.objects.get(name='Provider D') self.assertListEqual(list(provider.asns.all()), list(asns)) circuit = Circuit.objects.get(cid='Circuit D1') @@ -82,6 +89,7 @@ class StagingTestCase(TransactionTestCase): circuit = Circuit.objects.get(cid='Circuit D1') self.assertListEqual(list(circuit.tags.all()), list(tags)) self.assertEqual(StagedChange.objects.count(), 0) + self.assertEqual(Location.objects.count(), 1) def test_object_modification(self): branch = Branch.objects.create(name='Branch 1') diff --git a/netbox/utilities/utils.py b/netbox/utilities/utils.py index 4a6db9093..eeec3abf5 100644 --- a/netbox/utilities/utils.py +++ b/netbox/utilities/utils.py @@ -147,7 +147,7 @@ def count_related(model, field): return Coalesce(subquery, 0) -def serialize_object(obj, resolve_tags=True, extra=None, exclude=None): +def serialize_object(obj, resolve_tags=True, extra=None, exclude=None, mptt=False): """ Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys @@ -166,9 +166,10 @@ def serialize_object(obj, resolve_tags=True, extra=None, exclude=None): exclude = exclude or [] # Exclude any MPTTModel fields - if issubclass(obj.__class__, MPTTModel): - for field in ['level', 'lft', 'rght', 'tree_id']: - data.pop(field) + if not mptt: + if issubclass(obj.__class__, MPTTModel): + for field in ['level', 'lft', 'rght', 'tree_id']: + data.pop(field) # Include custom_field_data as "custom_fields" if hasattr(obj, 'custom_field_data'):