mirror of
https://github.com/netbox-community/netbox.git
synced 2025-08-13 19:18:16 -06:00
issue_13422: Ignore MPTT fields in serialization if being used by staged changes
This commit is contained in:
parent
85db007ff5
commit
ea3550ff07
@ -116,7 +116,7 @@ class checkout:
|
|||||||
# Creating a new object
|
# Creating a new object
|
||||||
if kwargs.get('created'):
|
if kwargs.get('created'):
|
||||||
logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})")
|
logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})")
|
||||||
data = serialize_object(instance, resolve_tags=False)
|
data = serialize_object(instance, resolve_tags=False, mptt=True)
|
||||||
self.queue[key] = (ChangeActionChoices.ACTION_CREATE, data)
|
self.queue[key] = (ChangeActionChoices.ACTION_CREATE, data)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -127,13 +127,13 @@ class checkout:
|
|||||||
# Object has already been created/updated in the queue; update its queued representation
|
# Object has already been created/updated in the queue; update its queued representation
|
||||||
if key in self.queue:
|
if key in self.queue:
|
||||||
logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})")
|
logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})")
|
||||||
data = serialize_object(instance, resolve_tags=False)
|
data = serialize_object(instance, resolve_tags=False, mptt=True)
|
||||||
self.queue[key] = (self.queue[key][0], data)
|
self.queue[key] = (self.queue[key][0], data)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Modifying an existing object for the first time
|
# Modifying an existing object for the first time
|
||||||
logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
|
logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
|
||||||
data = serialize_object(instance, resolve_tags=False)
|
data = serialize_object(instance, resolve_tags=False, mptt=True)
|
||||||
self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, data)
|
self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, data)
|
||||||
|
|
||||||
def pre_delete_handler(self, sender, instance, **kwargs):
|
def pre_delete_handler(self, sender, instance, **kwargs):
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from django.test import TransactionTestCase
|
from django.test import TransactionTestCase
|
||||||
|
|
||||||
from circuits.models import Provider, Circuit, CircuitType
|
from circuits.models import Provider, Circuit, CircuitType
|
||||||
|
from dcim.models import Location
|
||||||
from extras.choices import ChangeActionChoices
|
from extras.choices import ChangeActionChoices
|
||||||
from extras.models import Branch, StagedChange, Tag
|
from extras.models import Branch, StagedChange, Tag
|
||||||
from ipam.models import ASN, RIR
|
from ipam.models import ASN, RIR
|
||||||
@ -53,21 +54,27 @@ class StagingTestCase(TransactionTestCase):
|
|||||||
circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first())
|
circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first())
|
||||||
circuit.tags.set(tags)
|
circuit.tags.set(tags)
|
||||||
|
|
||||||
|
# Test MPTT Model
|
||||||
|
location = Location.objects.create(name='Location 1', slug='location-1')
|
||||||
|
|
||||||
# Sanity-checking
|
# Sanity-checking
|
||||||
self.assertEqual(Provider.objects.count(), 4)
|
self.assertEqual(Provider.objects.count(), 4)
|
||||||
self.assertListEqual(list(provider.asns.all()), list(asns))
|
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||||
self.assertEqual(Circuit.objects.count(), 10)
|
self.assertEqual(Circuit.objects.count(), 10)
|
||||||
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
||||||
|
self.assertEqual(Location.objects.count(), 1)
|
||||||
|
|
||||||
# Verify that changes have been rolled back after exiting the context
|
# Verify that changes have been rolled back after exiting the context
|
||||||
self.assertEqual(Provider.objects.count(), 3)
|
self.assertEqual(Provider.objects.count(), 3)
|
||||||
self.assertEqual(Circuit.objects.count(), 9)
|
self.assertEqual(Circuit.objects.count(), 9)
|
||||||
self.assertEqual(StagedChange.objects.count(), 5)
|
self.assertEqual(StagedChange.objects.count(), 5)
|
||||||
|
self.assertEqual(Location.objects.count(), 0)
|
||||||
|
|
||||||
# Verify that changes are replayed upon entering the context
|
# Verify that changes are replayed upon entering the context
|
||||||
with checkout(branch):
|
with checkout(branch):
|
||||||
self.assertEqual(Provider.objects.count(), 4)
|
self.assertEqual(Provider.objects.count(), 4)
|
||||||
self.assertEqual(Circuit.objects.count(), 10)
|
self.assertEqual(Circuit.objects.count(), 10)
|
||||||
|
self.assertEqual(Location.objects.count(), 1)
|
||||||
provider = Provider.objects.get(name='Provider D')
|
provider = Provider.objects.get(name='Provider D')
|
||||||
self.assertListEqual(list(provider.asns.all()), list(asns))
|
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||||
circuit = Circuit.objects.get(cid='Circuit D1')
|
circuit = Circuit.objects.get(cid='Circuit D1')
|
||||||
@ -82,6 +89,7 @@ class StagingTestCase(TransactionTestCase):
|
|||||||
circuit = Circuit.objects.get(cid='Circuit D1')
|
circuit = Circuit.objects.get(cid='Circuit D1')
|
||||||
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
||||||
self.assertEqual(StagedChange.objects.count(), 0)
|
self.assertEqual(StagedChange.objects.count(), 0)
|
||||||
|
self.assertEqual(Location.objects.count(), 1)
|
||||||
|
|
||||||
def test_object_modification(self):
|
def test_object_modification(self):
|
||||||
branch = Branch.objects.create(name='Branch 1')
|
branch = Branch.objects.create(name='Branch 1')
|
||||||
|
@ -147,7 +147,7 @@ def count_related(model, field):
|
|||||||
return Coalesce(subquery, 0)
|
return Coalesce(subquery, 0)
|
||||||
|
|
||||||
|
|
||||||
def serialize_object(obj, resolve_tags=True, extra=None, exclude=None):
|
def serialize_object(obj, resolve_tags=True, extra=None, exclude=None, mptt=False):
|
||||||
"""
|
"""
|
||||||
Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
|
Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
|
||||||
change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
|
change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
|
||||||
@ -166,9 +166,10 @@ def serialize_object(obj, resolve_tags=True, extra=None, exclude=None):
|
|||||||
exclude = exclude or []
|
exclude = exclude or []
|
||||||
|
|
||||||
# Exclude any MPTTModel fields
|
# Exclude any MPTTModel fields
|
||||||
if issubclass(obj.__class__, MPTTModel):
|
if not mptt:
|
||||||
for field in ['level', 'lft', 'rght', 'tree_id']:
|
if issubclass(obj.__class__, MPTTModel):
|
||||||
data.pop(field)
|
for field in ['level', 'lft', 'rght', 'tree_id']:
|
||||||
|
data.pop(field)
|
||||||
|
|
||||||
# Include custom_field_data as "custom_fields"
|
# Include custom_field_data as "custom_fields"
|
||||||
if hasattr(obj, 'custom_field_data'):
|
if hasattr(obj, 'custom_field_data'):
|
||||||
|
Loading…
Reference in New Issue
Block a user