Incorporate M2M changes

This commit is contained in:
jeremystretch 2022-11-11 10:39:48 -05:00
parent f4cec1e6c8
commit b57616e9ce
3 changed files with 64 additions and 24 deletions

View File

@ -2,7 +2,7 @@ import logging
from django.contrib.contenttypes.models import ContentType
from django.db import transaction
from django.db.models.signals import pre_delete, post_save
from django.db.models.signals import m2m_changed, pre_delete, post_save
from extras.choices import ChangeActionChoices
from extras.models import Change
@ -47,6 +47,7 @@ class checkout:
# Connect signal handlers
logger.debug("Connecting signal handlers")
post_save.connect(self.post_save_handler)
m2m_changed.connect(self.post_save_handler)
pre_delete.connect(self.pre_delete_handler)
def __exit__(self, exc_type, exc_val, exc_tb):
@ -54,6 +55,7 @@ class checkout:
# Disconnect signal handlers
logger.debug("Disconnecting signal handlers")
post_save.disconnect(self.post_save_handler)
m2m_changed.disconnect(self.post_save_handler)
pre_delete.disconnect(self.pre_delete_handler)
# Roll back the transaction to return the database to its original state
@ -87,10 +89,7 @@ class checkout:
for key, change in self.queue.items():
logger.debug(f' {key}: {change}')
object_type, pk = key
action, instance = change
data = None
if action in (ChangeActionChoices.ACTION_CREATE, ChangeActionChoices.ACTION_UPDATE):
data = serialize_object(instance)
action, data = change
changes.append(Change(
branch=self.branch,
@ -107,25 +106,35 @@ class checkout:
# Signal handlers
#
def post_save_handler(self, sender, instance, created, **kwargs):
def post_save_handler(self, sender, instance, **kwargs):
"""
Hooks to the post_save signal when a branch is active to queue create and update actions.
"""
key = self.get_key_for_instance(instance)
object_type = instance._meta.verbose_name
if created:
# Creating a new object
# Creating a new object
if kwargs.get('created'):
logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})")
self.queue[key] = (ChangeActionChoices.ACTION_CREATE, instance)
elif key in self.queue:
# Object has already been created/updated at least once
data = serialize_object(instance, resolve_tags=False)
self.queue[key] = (ChangeActionChoices.ACTION_CREATE, data)
return
# Ignore pre_* many-to-many actions
if 'action' in kwargs and kwargs['action'] not in ('post_add', 'post_remove', 'post_clear'):
return
# Object has already been created/updated in the queue; update its queued representation
if key in self.queue:
logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})")
self.queue[key] = (self.queue[key][0], instance)
else:
# Modifying an existing object
logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance)
data = serialize_object(instance, resolve_tags=False)
self.queue[key] = (self.queue[key][0], data)
return
# Modifying an existing object for the first time
logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
data = serialize_object(instance, resolve_tags=False)
self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, data)
def pre_delete_handler(self, sender, instance, **kwargs):
"""
@ -134,11 +143,12 @@ class checkout:
key = self.get_key_for_instance(instance)
object_type = instance._meta.verbose_name
# Cancel the creation of a new object
if key in self.queue and self.queue[key][0] == ChangeActionChoices.ACTION_CREATE:
# Cancel the creation of a new object
logger.debug(f"[{self.branch}] Removing staged creation of {object_type} {instance} (PK: {instance.pk})")
del self.queue[key]
else:
# Delete an existing object
logger.debug(f"[{self.branch}] Staging deletion of {object_type} {instance} (PK: {instance.pk})")
self.queue[key] = (ChangeActionChoices.ACTION_DELETE, instance)
return
# Delete an existing object
logger.debug(f"[{self.branch}] Staging deletion of {object_type} {instance} (PK: {instance.pk})")
self.queue[key] = (ChangeActionChoices.ACTION_DELETE, None)

View File

@ -3,6 +3,7 @@ from django.test import TransactionTestCase
from circuits.models import Provider, Circuit, CircuitType
from extras.choices import ChangeActionChoices
from extras.models import Branch, Change, Tag
from ipam.models import ASN, RIR
from netbox.staging import checkout
from utilities.testing import create_tags
@ -12,6 +13,14 @@ class StagingTestCase(TransactionTestCase):
def setUp(self):
create_tags('Alpha', 'Bravo', 'Charlie')
rir = RIR.objects.create(name='RIR 1', slug='rir-1')
asns = (
ASN(asn=65001, rir=rir),
ASN(asn=65002, rir=rir),
ASN(asn=65003, rir=rir),
)
ASN.objects.bulk_create(asns)
providers = (
Provider(name='Provider A', slug='provider-a'),
Provider(name='Provider B', slug='provider-b'),
@ -36,14 +45,17 @@ class StagingTestCase(TransactionTestCase):
def test_object_creation(self):
branch = Branch.objects.create(name='Branch 1')
tags = Tag.objects.all()
asns = ASN.objects.all()
with checkout(branch):
provider = Provider.objects.create(name='Provider D', slug='provider-d')
provider.asns.set(asns)
circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first())
circuit.tags.set(tags)
# Sanity-checking
self.assertEqual(Provider.objects.count(), 4)
self.assertListEqual(list(provider.asns.all()), list(asns))
self.assertEqual(Circuit.objects.count(), 10)
self.assertListEqual(list(circuit.tags.all()), list(tags))
@ -56,6 +68,8 @@ class StagingTestCase(TransactionTestCase):
with checkout(branch):
self.assertEqual(Provider.objects.count(), 4)
self.assertEqual(Circuit.objects.count(), 10)
provider = Provider.objects.get(name='Provider D')
self.assertListEqual(list(provider.asns.all()), list(asns))
circuit = Circuit.objects.get(cid='Circuit D1')
self.assertListEqual(list(circuit.tags.all()), list(tags))
@ -63,6 +77,8 @@ class StagingTestCase(TransactionTestCase):
branch.merge()
self.assertEqual(Provider.objects.count(), 4)
self.assertEqual(Circuit.objects.count(), 10)
provider = Provider.objects.get(name='Provider D')
self.assertListEqual(list(provider.asns.all()), list(asns))
circuit = Circuit.objects.get(cid='Circuit D1')
self.assertListEqual(list(circuit.tags.all()), list(tags))
self.assertEqual(Change.objects.count(), 0)
@ -70,11 +86,13 @@ class StagingTestCase(TransactionTestCase):
def test_object_modification(self):
branch = Branch.objects.create(name='Branch 1')
tags = Tag.objects.all()
asns = ASN.objects.all()
with checkout(branch):
provider = Provider.objects.get(name='Provider A')
provider.name = 'Provider X'
provider.save()
provider.asns.set(asns)
circuit = Circuit.objects.get(cid='Circuit A1')
circuit.cid = 'Circuit X'
circuit.save()
@ -83,6 +101,7 @@ class StagingTestCase(TransactionTestCase):
# Sanity-checking
self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
self.assertListEqual(list(provider.asns.all()), list(asns))
self.assertEqual(Circuit.objects.count(), 9)
self.assertEqual(Circuit.objects.get(pk=circuit.pk).cid, 'Circuit X')
self.assertListEqual(list(circuit.tags.all()), list(tags))
@ -90,6 +109,8 @@ class StagingTestCase(TransactionTestCase):
# Verify that changes have been rolled back after exiting the context
self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider A')
provider = Provider.objects.get(pk=provider.pk)
self.assertListEqual(list(provider.asns.all()), [])
self.assertEqual(Circuit.objects.count(), 9)
circuit = Circuit.objects.get(pk=circuit.pk)
self.assertEqual(circuit.cid, 'Circuit A1')
@ -100,6 +121,8 @@ class StagingTestCase(TransactionTestCase):
with checkout(branch):
self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
provider = Provider.objects.get(pk=provider.pk)
self.assertListEqual(list(provider.asns.all()), list(asns))
self.assertEqual(Circuit.objects.count(), 9)
circuit = Circuit.objects.get(pk=circuit.pk)
self.assertEqual(circuit.cid, 'Circuit X')
@ -109,6 +132,8 @@ class StagingTestCase(TransactionTestCase):
branch.merge()
self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
provider = Provider.objects.get(pk=provider.pk)
self.assertListEqual(list(provider.asns.all()), list(asns))
self.assertEqual(Circuit.objects.count(), 9)
circuit = Circuit.objects.get(pk=circuit.pk)
self.assertEqual(circuit.cid, 'Circuit X')

View File

@ -136,7 +136,7 @@ def count_related(model, field):
return Coalesce(subquery, 0)
def serialize_object(obj, extra=None):
def serialize_object(obj, resolve_tags=True, extra=None):
"""
Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
@ -155,8 +155,9 @@ def serialize_object(obj, extra=None):
if hasattr(obj, 'custom_field_data'):
data['custom_fields'] = data.pop('custom_field_data')
# Include any tags. Check for tags cached on the instance; fall back to using the manager.
if is_taggable(obj):
# Resolve any assigned tags to their names. Check for tags cached on the instance;
# fall back to using the manager.
if resolve_tags and is_taggable(obj):
tags = getattr(obj, '_tags', None) or obj.tags.all()
data['tags'] = sorted([tag.name for tag in tags])
@ -174,6 +175,10 @@ def serialize_object(obj, extra=None):
def deserialize_object(model, fields, pk=None):
"""
Instantiate an object from the given model and field data. Functions as
the complement to serialize_object().
"""
content_type = ContentType.objects.get_for_model(model)
if 'custom_fields' in fields:
fields['custom_field_data'] = fields.pop('custom_fields')