From b57616e9ce5c953f574ee608d519a11537c2b804 Mon Sep 17 00:00:00 2001 From: jeremystretch Date: Fri, 11 Nov 2022 10:39:48 -0500 Subject: [PATCH] Incorporate M2M changes --- netbox/netbox/staging.py | 52 +++++++++++++++++------------ netbox/netbox/tests/test_staging.py | 25 ++++++++++++++ netbox/utilities/utils.py | 11 ++++-- 3 files changed, 64 insertions(+), 24 deletions(-) diff --git a/netbox/netbox/staging.py b/netbox/netbox/staging.py index 1510917e5..11a35d427 100644 --- a/netbox/netbox/staging.py +++ b/netbox/netbox/staging.py @@ -2,7 +2,7 @@ import logging from django.contrib.contenttypes.models import ContentType from django.db import transaction -from django.db.models.signals import pre_delete, post_save +from django.db.models.signals import m2m_changed, pre_delete, post_save from extras.choices import ChangeActionChoices from extras.models import Change @@ -47,6 +47,7 @@ class checkout: # Connect signal handlers logger.debug("Connecting signal handlers") post_save.connect(self.post_save_handler) + m2m_changed.connect(self.post_save_handler) pre_delete.connect(self.pre_delete_handler) def __exit__(self, exc_type, exc_val, exc_tb): @@ -54,6 +55,7 @@ class checkout: # Disconnect signal handlers logger.debug("Disconnecting signal handlers") post_save.disconnect(self.post_save_handler) + m2m_changed.disconnect(self.post_save_handler) pre_delete.disconnect(self.pre_delete_handler) # Roll back the transaction to return the database to its original state @@ -87,10 +89,7 @@ class checkout: for key, change in self.queue.items(): logger.debug(f' {key}: {change}') object_type, pk = key - action, instance = change - data = None - if action in (ChangeActionChoices.ACTION_CREATE, ChangeActionChoices.ACTION_UPDATE): - data = serialize_object(instance) + action, data = change changes.append(Change( branch=self.branch, @@ -107,25 +106,35 @@ class checkout: # Signal handlers # - def post_save_handler(self, sender, instance, created, **kwargs): + def post_save_handler(self, sender, instance, **kwargs): """ Hooks to the post_save signal when a branch is active to queue create and update actions. """ key = self.get_key_for_instance(instance) object_type = instance._meta.verbose_name - if created: - # Creating a new object + # Creating a new object + if kwargs.get('created'): logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})") - self.queue[key] = (ChangeActionChoices.ACTION_CREATE, instance) - elif key in self.queue: - # Object has already been created/updated at least once + data = serialize_object(instance, resolve_tags=False) + self.queue[key] = (ChangeActionChoices.ACTION_CREATE, data) + return + + # Ignore pre_* many-to-many actions + if 'action' in kwargs and kwargs['action'] not in ('post_add', 'post_remove', 'post_clear'): + return + + # Object has already been created/updated in the queue; update its queued representation + if key in self.queue: logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})") - self.queue[key] = (self.queue[key][0], instance) - else: - # Modifying an existing object - logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})") - self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance) + data = serialize_object(instance, resolve_tags=False) + self.queue[key] = (self.queue[key][0], data) + return + + # Modifying an existing object for the first time + logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})") + data = serialize_object(instance, resolve_tags=False) + self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, data) def pre_delete_handler(self, sender, instance, **kwargs): """ @@ -134,11 +143,12 @@ class checkout: key = self.get_key_for_instance(instance) object_type = instance._meta.verbose_name + # Cancel the creation of a new object if key in self.queue and self.queue[key][0] == ChangeActionChoices.ACTION_CREATE: - # Cancel the creation of a new object logger.debug(f"[{self.branch}] Removing staged creation of {object_type} {instance} (PK: {instance.pk})") del self.queue[key] - else: - # Delete an existing object - logger.debug(f"[{self.branch}] Staging deletion of {object_type} {instance} (PK: {instance.pk})") - self.queue[key] = (ChangeActionChoices.ACTION_DELETE, instance) + return + + # Delete an existing object + logger.debug(f"[{self.branch}] Staging deletion of {object_type} {instance} (PK: {instance.pk})") + self.queue[key] = (ChangeActionChoices.ACTION_DELETE, None) diff --git a/netbox/netbox/tests/test_staging.py b/netbox/netbox/tests/test_staging.py index 3e48dfd8d..44f9d9a32 100644 --- a/netbox/netbox/tests/test_staging.py +++ b/netbox/netbox/tests/test_staging.py @@ -3,6 +3,7 @@ from django.test import TransactionTestCase from circuits.models import Provider, Circuit, CircuitType from extras.choices import ChangeActionChoices from extras.models import Branch, Change, Tag +from ipam.models import ASN, RIR from netbox.staging import checkout from utilities.testing import create_tags @@ -12,6 +13,14 @@ class StagingTestCase(TransactionTestCase): def setUp(self): create_tags('Alpha', 'Bravo', 'Charlie') + rir = RIR.objects.create(name='RIR 1', slug='rir-1') + asns = ( + ASN(asn=65001, rir=rir), + ASN(asn=65002, rir=rir), + ASN(asn=65003, rir=rir), + ) + ASN.objects.bulk_create(asns) + providers = ( Provider(name='Provider A', slug='provider-a'), Provider(name='Provider B', slug='provider-b'), @@ -36,14 +45,17 @@ class StagingTestCase(TransactionTestCase): def test_object_creation(self): branch = Branch.objects.create(name='Branch 1') tags = Tag.objects.all() + asns = ASN.objects.all() with checkout(branch): provider = Provider.objects.create(name='Provider D', slug='provider-d') + provider.asns.set(asns) circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first()) circuit.tags.set(tags) # Sanity-checking self.assertEqual(Provider.objects.count(), 4) + self.assertListEqual(list(provider.asns.all()), list(asns)) self.assertEqual(Circuit.objects.count(), 10) self.assertListEqual(list(circuit.tags.all()), list(tags)) @@ -56,6 +68,8 @@ class StagingTestCase(TransactionTestCase): with checkout(branch): self.assertEqual(Provider.objects.count(), 4) self.assertEqual(Circuit.objects.count(), 10) + provider = Provider.objects.get(name='Provider D') + self.assertListEqual(list(provider.asns.all()), list(asns)) circuit = Circuit.objects.get(cid='Circuit D1') self.assertListEqual(list(circuit.tags.all()), list(tags)) @@ -63,6 +77,8 @@ class StagingTestCase(TransactionTestCase): branch.merge() self.assertEqual(Provider.objects.count(), 4) self.assertEqual(Circuit.objects.count(), 10) + provider = Provider.objects.get(name='Provider D') + self.assertListEqual(list(provider.asns.all()), list(asns)) circuit = Circuit.objects.get(cid='Circuit D1') self.assertListEqual(list(circuit.tags.all()), list(tags)) self.assertEqual(Change.objects.count(), 0) @@ -70,11 +86,13 @@ class StagingTestCase(TransactionTestCase): def test_object_modification(self): branch = Branch.objects.create(name='Branch 1') tags = Tag.objects.all() + asns = ASN.objects.all() with checkout(branch): provider = Provider.objects.get(name='Provider A') provider.name = 'Provider X' provider.save() + provider.asns.set(asns) circuit = Circuit.objects.get(cid='Circuit A1') circuit.cid = 'Circuit X' circuit.save() @@ -83,6 +101,7 @@ class StagingTestCase(TransactionTestCase): # Sanity-checking self.assertEqual(Provider.objects.count(), 3) self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X') + self.assertListEqual(list(provider.asns.all()), list(asns)) self.assertEqual(Circuit.objects.count(), 9) self.assertEqual(Circuit.objects.get(pk=circuit.pk).cid, 'Circuit X') self.assertListEqual(list(circuit.tags.all()), list(tags)) @@ -90,6 +109,8 @@ class StagingTestCase(TransactionTestCase): # Verify that changes have been rolled back after exiting the context self.assertEqual(Provider.objects.count(), 3) self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider A') + provider = Provider.objects.get(pk=provider.pk) + self.assertListEqual(list(provider.asns.all()), []) self.assertEqual(Circuit.objects.count(), 9) circuit = Circuit.objects.get(pk=circuit.pk) self.assertEqual(circuit.cid, 'Circuit A1') @@ -100,6 +121,8 @@ class StagingTestCase(TransactionTestCase): with checkout(branch): self.assertEqual(Provider.objects.count(), 3) self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X') + provider = Provider.objects.get(pk=provider.pk) + self.assertListEqual(list(provider.asns.all()), list(asns)) self.assertEqual(Circuit.objects.count(), 9) circuit = Circuit.objects.get(pk=circuit.pk) self.assertEqual(circuit.cid, 'Circuit X') @@ -109,6 +132,8 @@ class StagingTestCase(TransactionTestCase): branch.merge() self.assertEqual(Provider.objects.count(), 3) self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X') + provider = Provider.objects.get(pk=provider.pk) + self.assertListEqual(list(provider.asns.all()), list(asns)) self.assertEqual(Circuit.objects.count(), 9) circuit = Circuit.objects.get(pk=circuit.pk) self.assertEqual(circuit.cid, 'Circuit X') diff --git a/netbox/utilities/utils.py b/netbox/utilities/utils.py index de261945c..a26940ac1 100644 --- a/netbox/utilities/utils.py +++ b/netbox/utilities/utils.py @@ -136,7 +136,7 @@ def count_related(model, field): return Coalesce(subquery, 0) -def serialize_object(obj, extra=None): +def serialize_object(obj, resolve_tags=True, extra=None): """ Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys @@ -155,8 +155,9 @@ def serialize_object(obj, extra=None): if hasattr(obj, 'custom_field_data'): data['custom_fields'] = data.pop('custom_field_data') - # Include any tags. Check for tags cached on the instance; fall back to using the manager. - if is_taggable(obj): + # Resolve any assigned tags to their names. Check for tags cached on the instance; + # fall back to using the manager. + if resolve_tags and is_taggable(obj): tags = getattr(obj, '_tags', None) or obj.tags.all() data['tags'] = sorted([tag.name for tag in tags]) @@ -174,6 +175,10 @@ def serialize_object(obj, extra=None): def deserialize_object(model, fields, pk=None): + """ + Instantiate an object from the given model and field data. Functions as + the complement to serialize_object(). + """ content_type = ContentType.objects.get_for_model(model) if 'custom_fields' in fields: fields['custom_field_data'] = fields.pop('custom_fields')