mirror of
https://github.com/netbox-community/netbox.git
synced 2025-08-25 08:46:10 -06:00
Incorporate M2M changes
This commit is contained in:
parent
f4cec1e6c8
commit
b57616e9ce
@ -2,7 +2,7 @@ import logging
|
||||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import transaction
|
||||
from django.db.models.signals import pre_delete, post_save
|
||||
from django.db.models.signals import m2m_changed, pre_delete, post_save
|
||||
|
||||
from extras.choices import ChangeActionChoices
|
||||
from extras.models import Change
|
||||
@ -47,6 +47,7 @@ class checkout:
|
||||
# Connect signal handlers
|
||||
logger.debug("Connecting signal handlers")
|
||||
post_save.connect(self.post_save_handler)
|
||||
m2m_changed.connect(self.post_save_handler)
|
||||
pre_delete.connect(self.pre_delete_handler)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
@ -54,6 +55,7 @@ class checkout:
|
||||
# Disconnect signal handlers
|
||||
logger.debug("Disconnecting signal handlers")
|
||||
post_save.disconnect(self.post_save_handler)
|
||||
m2m_changed.disconnect(self.post_save_handler)
|
||||
pre_delete.disconnect(self.pre_delete_handler)
|
||||
|
||||
# Roll back the transaction to return the database to its original state
|
||||
@ -87,10 +89,7 @@ class checkout:
|
||||
for key, change in self.queue.items():
|
||||
logger.debug(f' {key}: {change}')
|
||||
object_type, pk = key
|
||||
action, instance = change
|
||||
data = None
|
||||
if action in (ChangeActionChoices.ACTION_CREATE, ChangeActionChoices.ACTION_UPDATE):
|
||||
data = serialize_object(instance)
|
||||
action, data = change
|
||||
|
||||
changes.append(Change(
|
||||
branch=self.branch,
|
||||
@ -107,25 +106,35 @@ class checkout:
|
||||
# Signal handlers
|
||||
#
|
||||
|
||||
def post_save_handler(self, sender, instance, created, **kwargs):
|
||||
def post_save_handler(self, sender, instance, **kwargs):
|
||||
"""
|
||||
Hooks to the post_save signal when a branch is active to queue create and update actions.
|
||||
"""
|
||||
key = self.get_key_for_instance(instance)
|
||||
object_type = instance._meta.verbose_name
|
||||
|
||||
if created:
|
||||
# Creating a new object
|
||||
# Creating a new object
|
||||
if kwargs.get('created'):
|
||||
logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})")
|
||||
self.queue[key] = (ChangeActionChoices.ACTION_CREATE, instance)
|
||||
elif key in self.queue:
|
||||
# Object has already been created/updated at least once
|
||||
data = serialize_object(instance, resolve_tags=False)
|
||||
self.queue[key] = (ChangeActionChoices.ACTION_CREATE, data)
|
||||
return
|
||||
|
||||
# Ignore pre_* many-to-many actions
|
||||
if 'action' in kwargs and kwargs['action'] not in ('post_add', 'post_remove', 'post_clear'):
|
||||
return
|
||||
|
||||
# Object has already been created/updated in the queue; update its queued representation
|
||||
if key in self.queue:
|
||||
logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})")
|
||||
self.queue[key] = (self.queue[key][0], instance)
|
||||
else:
|
||||
# Modifying an existing object
|
||||
logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
|
||||
self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance)
|
||||
data = serialize_object(instance, resolve_tags=False)
|
||||
self.queue[key] = (self.queue[key][0], data)
|
||||
return
|
||||
|
||||
# Modifying an existing object for the first time
|
||||
logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
|
||||
data = serialize_object(instance, resolve_tags=False)
|
||||
self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, data)
|
||||
|
||||
def pre_delete_handler(self, sender, instance, **kwargs):
|
||||
"""
|
||||
@ -134,11 +143,12 @@ class checkout:
|
||||
key = self.get_key_for_instance(instance)
|
||||
object_type = instance._meta.verbose_name
|
||||
|
||||
# Cancel the creation of a new object
|
||||
if key in self.queue and self.queue[key][0] == ChangeActionChoices.ACTION_CREATE:
|
||||
# Cancel the creation of a new object
|
||||
logger.debug(f"[{self.branch}] Removing staged creation of {object_type} {instance} (PK: {instance.pk})")
|
||||
del self.queue[key]
|
||||
else:
|
||||
# Delete an existing object
|
||||
logger.debug(f"[{self.branch}] Staging deletion of {object_type} {instance} (PK: {instance.pk})")
|
||||
self.queue[key] = (ChangeActionChoices.ACTION_DELETE, instance)
|
||||
return
|
||||
|
||||
# Delete an existing object
|
||||
logger.debug(f"[{self.branch}] Staging deletion of {object_type} {instance} (PK: {instance.pk})")
|
||||
self.queue[key] = (ChangeActionChoices.ACTION_DELETE, None)
|
||||
|
@ -3,6 +3,7 @@ from django.test import TransactionTestCase
|
||||
from circuits.models import Provider, Circuit, CircuitType
|
||||
from extras.choices import ChangeActionChoices
|
||||
from extras.models import Branch, Change, Tag
|
||||
from ipam.models import ASN, RIR
|
||||
from netbox.staging import checkout
|
||||
from utilities.testing import create_tags
|
||||
|
||||
@ -12,6 +13,14 @@ class StagingTestCase(TransactionTestCase):
|
||||
def setUp(self):
|
||||
create_tags('Alpha', 'Bravo', 'Charlie')
|
||||
|
||||
rir = RIR.objects.create(name='RIR 1', slug='rir-1')
|
||||
asns = (
|
||||
ASN(asn=65001, rir=rir),
|
||||
ASN(asn=65002, rir=rir),
|
||||
ASN(asn=65003, rir=rir),
|
||||
)
|
||||
ASN.objects.bulk_create(asns)
|
||||
|
||||
providers = (
|
||||
Provider(name='Provider A', slug='provider-a'),
|
||||
Provider(name='Provider B', slug='provider-b'),
|
||||
@ -36,14 +45,17 @@ class StagingTestCase(TransactionTestCase):
|
||||
def test_object_creation(self):
|
||||
branch = Branch.objects.create(name='Branch 1')
|
||||
tags = Tag.objects.all()
|
||||
asns = ASN.objects.all()
|
||||
|
||||
with checkout(branch):
|
||||
provider = Provider.objects.create(name='Provider D', slug='provider-d')
|
||||
provider.asns.set(asns)
|
||||
circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first())
|
||||
circuit.tags.set(tags)
|
||||
|
||||
# Sanity-checking
|
||||
self.assertEqual(Provider.objects.count(), 4)
|
||||
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||
self.assertEqual(Circuit.objects.count(), 10)
|
||||
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
||||
|
||||
@ -56,6 +68,8 @@ class StagingTestCase(TransactionTestCase):
|
||||
with checkout(branch):
|
||||
self.assertEqual(Provider.objects.count(), 4)
|
||||
self.assertEqual(Circuit.objects.count(), 10)
|
||||
provider = Provider.objects.get(name='Provider D')
|
||||
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||
circuit = Circuit.objects.get(cid='Circuit D1')
|
||||
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
||||
|
||||
@ -63,6 +77,8 @@ class StagingTestCase(TransactionTestCase):
|
||||
branch.merge()
|
||||
self.assertEqual(Provider.objects.count(), 4)
|
||||
self.assertEqual(Circuit.objects.count(), 10)
|
||||
provider = Provider.objects.get(name='Provider D')
|
||||
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||
circuit = Circuit.objects.get(cid='Circuit D1')
|
||||
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
||||
self.assertEqual(Change.objects.count(), 0)
|
||||
@ -70,11 +86,13 @@ class StagingTestCase(TransactionTestCase):
|
||||
def test_object_modification(self):
|
||||
branch = Branch.objects.create(name='Branch 1')
|
||||
tags = Tag.objects.all()
|
||||
asns = ASN.objects.all()
|
||||
|
||||
with checkout(branch):
|
||||
provider = Provider.objects.get(name='Provider A')
|
||||
provider.name = 'Provider X'
|
||||
provider.save()
|
||||
provider.asns.set(asns)
|
||||
circuit = Circuit.objects.get(cid='Circuit A1')
|
||||
circuit.cid = 'Circuit X'
|
||||
circuit.save()
|
||||
@ -83,6 +101,7 @@ class StagingTestCase(TransactionTestCase):
|
||||
# Sanity-checking
|
||||
self.assertEqual(Provider.objects.count(), 3)
|
||||
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
|
||||
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||
self.assertEqual(Circuit.objects.count(), 9)
|
||||
self.assertEqual(Circuit.objects.get(pk=circuit.pk).cid, 'Circuit X')
|
||||
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
||||
@ -90,6 +109,8 @@ class StagingTestCase(TransactionTestCase):
|
||||
# Verify that changes have been rolled back after exiting the context
|
||||
self.assertEqual(Provider.objects.count(), 3)
|
||||
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider A')
|
||||
provider = Provider.objects.get(pk=provider.pk)
|
||||
self.assertListEqual(list(provider.asns.all()), [])
|
||||
self.assertEqual(Circuit.objects.count(), 9)
|
||||
circuit = Circuit.objects.get(pk=circuit.pk)
|
||||
self.assertEqual(circuit.cid, 'Circuit A1')
|
||||
@ -100,6 +121,8 @@ class StagingTestCase(TransactionTestCase):
|
||||
with checkout(branch):
|
||||
self.assertEqual(Provider.objects.count(), 3)
|
||||
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
|
||||
provider = Provider.objects.get(pk=provider.pk)
|
||||
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||
self.assertEqual(Circuit.objects.count(), 9)
|
||||
circuit = Circuit.objects.get(pk=circuit.pk)
|
||||
self.assertEqual(circuit.cid, 'Circuit X')
|
||||
@ -109,6 +132,8 @@ class StagingTestCase(TransactionTestCase):
|
||||
branch.merge()
|
||||
self.assertEqual(Provider.objects.count(), 3)
|
||||
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
|
||||
provider = Provider.objects.get(pk=provider.pk)
|
||||
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||
self.assertEqual(Circuit.objects.count(), 9)
|
||||
circuit = Circuit.objects.get(pk=circuit.pk)
|
||||
self.assertEqual(circuit.cid, 'Circuit X')
|
||||
|
@ -136,7 +136,7 @@ def count_related(model, field):
|
||||
return Coalesce(subquery, 0)
|
||||
|
||||
|
||||
def serialize_object(obj, extra=None):
|
||||
def serialize_object(obj, resolve_tags=True, extra=None):
|
||||
"""
|
||||
Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
|
||||
change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
|
||||
@ -155,8 +155,9 @@ def serialize_object(obj, extra=None):
|
||||
if hasattr(obj, 'custom_field_data'):
|
||||
data['custom_fields'] = data.pop('custom_field_data')
|
||||
|
||||
# Include any tags. Check for tags cached on the instance; fall back to using the manager.
|
||||
if is_taggable(obj):
|
||||
# Resolve any assigned tags to their names. Check for tags cached on the instance;
|
||||
# fall back to using the manager.
|
||||
if resolve_tags and is_taggable(obj):
|
||||
tags = getattr(obj, '_tags', None) or obj.tags.all()
|
||||
data['tags'] = sorted([tag.name for tag in tags])
|
||||
|
||||
@ -174,6 +175,10 @@ def serialize_object(obj, extra=None):
|
||||
|
||||
|
||||
def deserialize_object(model, fields, pk=None):
|
||||
"""
|
||||
Instantiate an object from the given model and field data. Functions as
|
||||
the complement to serialize_object().
|
||||
"""
|
||||
content_type = ContentType.objects.get_for_model(model)
|
||||
if 'custom_fields' in fields:
|
||||
fields['custom_field_data'] = fields.pop('custom_fields')
|
||||
|
Loading…
Reference in New Issue
Block a user