Incorporate M2M changes

This commit is contained in:
jeremystretch 2022-11-11 10:39:48 -05:00
parent f4cec1e6c8
commit b57616e9ce
3 changed files with 64 additions and 24 deletions

View File

@ -2,7 +2,7 @@ import logging
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import transaction from django.db import transaction
from django.db.models.signals import pre_delete, post_save from django.db.models.signals import m2m_changed, pre_delete, post_save
from extras.choices import ChangeActionChoices from extras.choices import ChangeActionChoices
from extras.models import Change from extras.models import Change
@ -47,6 +47,7 @@ class checkout:
# Connect signal handlers # Connect signal handlers
logger.debug("Connecting signal handlers") logger.debug("Connecting signal handlers")
post_save.connect(self.post_save_handler) post_save.connect(self.post_save_handler)
m2m_changed.connect(self.post_save_handler)
pre_delete.connect(self.pre_delete_handler) pre_delete.connect(self.pre_delete_handler)
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
@ -54,6 +55,7 @@ class checkout:
# Disconnect signal handlers # Disconnect signal handlers
logger.debug("Disconnecting signal handlers") logger.debug("Disconnecting signal handlers")
post_save.disconnect(self.post_save_handler) post_save.disconnect(self.post_save_handler)
m2m_changed.disconnect(self.post_save_handler)
pre_delete.disconnect(self.pre_delete_handler) pre_delete.disconnect(self.pre_delete_handler)
# Roll back the transaction to return the database to its original state # Roll back the transaction to return the database to its original state
@ -87,10 +89,7 @@ class checkout:
for key, change in self.queue.items(): for key, change in self.queue.items():
logger.debug(f' {key}: {change}') logger.debug(f' {key}: {change}')
object_type, pk = key object_type, pk = key
action, instance = change action, data = change
data = None
if action in (ChangeActionChoices.ACTION_CREATE, ChangeActionChoices.ACTION_UPDATE):
data = serialize_object(instance)
changes.append(Change( changes.append(Change(
branch=self.branch, branch=self.branch,
@ -107,25 +106,35 @@ class checkout:
# Signal handlers # Signal handlers
# #
def post_save_handler(self, sender, instance, created, **kwargs): def post_save_handler(self, sender, instance, **kwargs):
""" """
Hooks to the post_save signal when a branch is active to queue create and update actions. Hooks to the post_save signal when a branch is active to queue create and update actions.
""" """
key = self.get_key_for_instance(instance) key = self.get_key_for_instance(instance)
object_type = instance._meta.verbose_name object_type = instance._meta.verbose_name
if created: # Creating a new object
# Creating a new object if kwargs.get('created'):
logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})") logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})")
self.queue[key] = (ChangeActionChoices.ACTION_CREATE, instance) data = serialize_object(instance, resolve_tags=False)
elif key in self.queue: self.queue[key] = (ChangeActionChoices.ACTION_CREATE, data)
# Object has already been created/updated at least once return
# Ignore pre_* many-to-many actions
if 'action' in kwargs and kwargs['action'] not in ('post_add', 'post_remove', 'post_clear'):
return
# Object has already been created/updated in the queue; update its queued representation
if key in self.queue:
logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})") logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})")
self.queue[key] = (self.queue[key][0], instance) data = serialize_object(instance, resolve_tags=False)
else: self.queue[key] = (self.queue[key][0], data)
# Modifying an existing object return
logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance) # Modifying an existing object for the first time
logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
data = serialize_object(instance, resolve_tags=False)
self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, data)
def pre_delete_handler(self, sender, instance, **kwargs): def pre_delete_handler(self, sender, instance, **kwargs):
""" """
@ -134,11 +143,12 @@ class checkout:
key = self.get_key_for_instance(instance) key = self.get_key_for_instance(instance)
object_type = instance._meta.verbose_name object_type = instance._meta.verbose_name
# Cancel the creation of a new object
if key in self.queue and self.queue[key][0] == ChangeActionChoices.ACTION_CREATE: if key in self.queue and self.queue[key][0] == ChangeActionChoices.ACTION_CREATE:
# Cancel the creation of a new object
logger.debug(f"[{self.branch}] Removing staged creation of {object_type} {instance} (PK: {instance.pk})") logger.debug(f"[{self.branch}] Removing staged creation of {object_type} {instance} (PK: {instance.pk})")
del self.queue[key] del self.queue[key]
else: return
# Delete an existing object
logger.debug(f"[{self.branch}] Staging deletion of {object_type} {instance} (PK: {instance.pk})") # Delete an existing object
self.queue[key] = (ChangeActionChoices.ACTION_DELETE, instance) logger.debug(f"[{self.branch}] Staging deletion of {object_type} {instance} (PK: {instance.pk})")
self.queue[key] = (ChangeActionChoices.ACTION_DELETE, None)

View File

@ -3,6 +3,7 @@ from django.test import TransactionTestCase
from circuits.models import Provider, Circuit, CircuitType from circuits.models import Provider, Circuit, CircuitType
from extras.choices import ChangeActionChoices from extras.choices import ChangeActionChoices
from extras.models import Branch, Change, Tag from extras.models import Branch, Change, Tag
from ipam.models import ASN, RIR
from netbox.staging import checkout from netbox.staging import checkout
from utilities.testing import create_tags from utilities.testing import create_tags
@ -12,6 +13,14 @@ class StagingTestCase(TransactionTestCase):
def setUp(self): def setUp(self):
create_tags('Alpha', 'Bravo', 'Charlie') create_tags('Alpha', 'Bravo', 'Charlie')
rir = RIR.objects.create(name='RIR 1', slug='rir-1')
asns = (
ASN(asn=65001, rir=rir),
ASN(asn=65002, rir=rir),
ASN(asn=65003, rir=rir),
)
ASN.objects.bulk_create(asns)
providers = ( providers = (
Provider(name='Provider A', slug='provider-a'), Provider(name='Provider A', slug='provider-a'),
Provider(name='Provider B', slug='provider-b'), Provider(name='Provider B', slug='provider-b'),
@ -36,14 +45,17 @@ class StagingTestCase(TransactionTestCase):
def test_object_creation(self): def test_object_creation(self):
branch = Branch.objects.create(name='Branch 1') branch = Branch.objects.create(name='Branch 1')
tags = Tag.objects.all() tags = Tag.objects.all()
asns = ASN.objects.all()
with checkout(branch): with checkout(branch):
provider = Provider.objects.create(name='Provider D', slug='provider-d') provider = Provider.objects.create(name='Provider D', slug='provider-d')
provider.asns.set(asns)
circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first()) circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first())
circuit.tags.set(tags) circuit.tags.set(tags)
# Sanity-checking # Sanity-checking
self.assertEqual(Provider.objects.count(), 4) self.assertEqual(Provider.objects.count(), 4)
self.assertListEqual(list(provider.asns.all()), list(asns))
self.assertEqual(Circuit.objects.count(), 10) self.assertEqual(Circuit.objects.count(), 10)
self.assertListEqual(list(circuit.tags.all()), list(tags)) self.assertListEqual(list(circuit.tags.all()), list(tags))
@ -56,6 +68,8 @@ class StagingTestCase(TransactionTestCase):
with checkout(branch): with checkout(branch):
self.assertEqual(Provider.objects.count(), 4) self.assertEqual(Provider.objects.count(), 4)
self.assertEqual(Circuit.objects.count(), 10) self.assertEqual(Circuit.objects.count(), 10)
provider = Provider.objects.get(name='Provider D')
self.assertListEqual(list(provider.asns.all()), list(asns))
circuit = Circuit.objects.get(cid='Circuit D1') circuit = Circuit.objects.get(cid='Circuit D1')
self.assertListEqual(list(circuit.tags.all()), list(tags)) self.assertListEqual(list(circuit.tags.all()), list(tags))
@ -63,6 +77,8 @@ class StagingTestCase(TransactionTestCase):
branch.merge() branch.merge()
self.assertEqual(Provider.objects.count(), 4) self.assertEqual(Provider.objects.count(), 4)
self.assertEqual(Circuit.objects.count(), 10) self.assertEqual(Circuit.objects.count(), 10)
provider = Provider.objects.get(name='Provider D')
self.assertListEqual(list(provider.asns.all()), list(asns))
circuit = Circuit.objects.get(cid='Circuit D1') circuit = Circuit.objects.get(cid='Circuit D1')
self.assertListEqual(list(circuit.tags.all()), list(tags)) self.assertListEqual(list(circuit.tags.all()), list(tags))
self.assertEqual(Change.objects.count(), 0) self.assertEqual(Change.objects.count(), 0)
@ -70,11 +86,13 @@ class StagingTestCase(TransactionTestCase):
def test_object_modification(self): def test_object_modification(self):
branch = Branch.objects.create(name='Branch 1') branch = Branch.objects.create(name='Branch 1')
tags = Tag.objects.all() tags = Tag.objects.all()
asns = ASN.objects.all()
with checkout(branch): with checkout(branch):
provider = Provider.objects.get(name='Provider A') provider = Provider.objects.get(name='Provider A')
provider.name = 'Provider X' provider.name = 'Provider X'
provider.save() provider.save()
provider.asns.set(asns)
circuit = Circuit.objects.get(cid='Circuit A1') circuit = Circuit.objects.get(cid='Circuit A1')
circuit.cid = 'Circuit X' circuit.cid = 'Circuit X'
circuit.save() circuit.save()
@ -83,6 +101,7 @@ class StagingTestCase(TransactionTestCase):
# Sanity-checking # Sanity-checking
self.assertEqual(Provider.objects.count(), 3) self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X') self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
self.assertListEqual(list(provider.asns.all()), list(asns))
self.assertEqual(Circuit.objects.count(), 9) self.assertEqual(Circuit.objects.count(), 9)
self.assertEqual(Circuit.objects.get(pk=circuit.pk).cid, 'Circuit X') self.assertEqual(Circuit.objects.get(pk=circuit.pk).cid, 'Circuit X')
self.assertListEqual(list(circuit.tags.all()), list(tags)) self.assertListEqual(list(circuit.tags.all()), list(tags))
@ -90,6 +109,8 @@ class StagingTestCase(TransactionTestCase):
# Verify that changes have been rolled back after exiting the context # Verify that changes have been rolled back after exiting the context
self.assertEqual(Provider.objects.count(), 3) self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider A') self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider A')
provider = Provider.objects.get(pk=provider.pk)
self.assertListEqual(list(provider.asns.all()), [])
self.assertEqual(Circuit.objects.count(), 9) self.assertEqual(Circuit.objects.count(), 9)
circuit = Circuit.objects.get(pk=circuit.pk) circuit = Circuit.objects.get(pk=circuit.pk)
self.assertEqual(circuit.cid, 'Circuit A1') self.assertEqual(circuit.cid, 'Circuit A1')
@ -100,6 +121,8 @@ class StagingTestCase(TransactionTestCase):
with checkout(branch): with checkout(branch):
self.assertEqual(Provider.objects.count(), 3) self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X') self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
provider = Provider.objects.get(pk=provider.pk)
self.assertListEqual(list(provider.asns.all()), list(asns))
self.assertEqual(Circuit.objects.count(), 9) self.assertEqual(Circuit.objects.count(), 9)
circuit = Circuit.objects.get(pk=circuit.pk) circuit = Circuit.objects.get(pk=circuit.pk)
self.assertEqual(circuit.cid, 'Circuit X') self.assertEqual(circuit.cid, 'Circuit X')
@ -109,6 +132,8 @@ class StagingTestCase(TransactionTestCase):
branch.merge() branch.merge()
self.assertEqual(Provider.objects.count(), 3) self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X') self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
provider = Provider.objects.get(pk=provider.pk)
self.assertListEqual(list(provider.asns.all()), list(asns))
self.assertEqual(Circuit.objects.count(), 9) self.assertEqual(Circuit.objects.count(), 9)
circuit = Circuit.objects.get(pk=circuit.pk) circuit = Circuit.objects.get(pk=circuit.pk)
self.assertEqual(circuit.cid, 'Circuit X') self.assertEqual(circuit.cid, 'Circuit X')

View File

@ -136,7 +136,7 @@ def count_related(model, field):
return Coalesce(subquery, 0) return Coalesce(subquery, 0)
def serialize_object(obj, extra=None): def serialize_object(obj, resolve_tags=True, extra=None):
""" """
Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
@ -155,8 +155,9 @@ def serialize_object(obj, extra=None):
if hasattr(obj, 'custom_field_data'): if hasattr(obj, 'custom_field_data'):
data['custom_fields'] = data.pop('custom_field_data') data['custom_fields'] = data.pop('custom_field_data')
# Include any tags. Check for tags cached on the instance; fall back to using the manager. # Resolve any assigned tags to their names. Check for tags cached on the instance;
if is_taggable(obj): # fall back to using the manager.
if resolve_tags and is_taggable(obj):
tags = getattr(obj, '_tags', None) or obj.tags.all() tags = getattr(obj, '_tags', None) or obj.tags.all()
data['tags'] = sorted([tag.name for tag in tags]) data['tags'] = sorted([tag.name for tag in tags])
@ -174,6 +175,10 @@ def serialize_object(obj, extra=None):
def deserialize_object(model, fields, pk=None): def deserialize_object(model, fields, pk=None):
"""
Instantiate an object from the given model and field data. Functions as
the complement to serialize_object().
"""
content_type = ContentType.objects.get_for_model(model) content_type = ContentType.objects.get_for_model(model)
if 'custom_fields' in fields: if 'custom_fields' in fields:
fields['custom_field_data'] = fields.pop('custom_fields') fields['custom_field_data'] = fields.pop('custom_fields')