mirror of
https://github.com/netbox-community/netbox.git
synced 2025-08-26 01:06:11 -06:00
Incorporate M2M changes
This commit is contained in:
parent
f4cec1e6c8
commit
b57616e9ce
@ -2,7 +2,7 @@ import logging
|
|||||||
|
|
||||||
from django.contrib.contenttypes.models import ContentType
|
from django.contrib.contenttypes.models import ContentType
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.models.signals import pre_delete, post_save
|
from django.db.models.signals import m2m_changed, pre_delete, post_save
|
||||||
|
|
||||||
from extras.choices import ChangeActionChoices
|
from extras.choices import ChangeActionChoices
|
||||||
from extras.models import Change
|
from extras.models import Change
|
||||||
@ -47,6 +47,7 @@ class checkout:
|
|||||||
# Connect signal handlers
|
# Connect signal handlers
|
||||||
logger.debug("Connecting signal handlers")
|
logger.debug("Connecting signal handlers")
|
||||||
post_save.connect(self.post_save_handler)
|
post_save.connect(self.post_save_handler)
|
||||||
|
m2m_changed.connect(self.post_save_handler)
|
||||||
pre_delete.connect(self.pre_delete_handler)
|
pre_delete.connect(self.pre_delete_handler)
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
@ -54,6 +55,7 @@ class checkout:
|
|||||||
# Disconnect signal handlers
|
# Disconnect signal handlers
|
||||||
logger.debug("Disconnecting signal handlers")
|
logger.debug("Disconnecting signal handlers")
|
||||||
post_save.disconnect(self.post_save_handler)
|
post_save.disconnect(self.post_save_handler)
|
||||||
|
m2m_changed.disconnect(self.post_save_handler)
|
||||||
pre_delete.disconnect(self.pre_delete_handler)
|
pre_delete.disconnect(self.pre_delete_handler)
|
||||||
|
|
||||||
# Roll back the transaction to return the database to its original state
|
# Roll back the transaction to return the database to its original state
|
||||||
@ -87,10 +89,7 @@ class checkout:
|
|||||||
for key, change in self.queue.items():
|
for key, change in self.queue.items():
|
||||||
logger.debug(f' {key}: {change}')
|
logger.debug(f' {key}: {change}')
|
||||||
object_type, pk = key
|
object_type, pk = key
|
||||||
action, instance = change
|
action, data = change
|
||||||
data = None
|
|
||||||
if action in (ChangeActionChoices.ACTION_CREATE, ChangeActionChoices.ACTION_UPDATE):
|
|
||||||
data = serialize_object(instance)
|
|
||||||
|
|
||||||
changes.append(Change(
|
changes.append(Change(
|
||||||
branch=self.branch,
|
branch=self.branch,
|
||||||
@ -107,25 +106,35 @@ class checkout:
|
|||||||
# Signal handlers
|
# Signal handlers
|
||||||
#
|
#
|
||||||
|
|
||||||
def post_save_handler(self, sender, instance, created, **kwargs):
|
def post_save_handler(self, sender, instance, **kwargs):
|
||||||
"""
|
"""
|
||||||
Hooks to the post_save signal when a branch is active to queue create and update actions.
|
Hooks to the post_save signal when a branch is active to queue create and update actions.
|
||||||
"""
|
"""
|
||||||
key = self.get_key_for_instance(instance)
|
key = self.get_key_for_instance(instance)
|
||||||
object_type = instance._meta.verbose_name
|
object_type = instance._meta.verbose_name
|
||||||
|
|
||||||
if created:
|
# Creating a new object
|
||||||
# Creating a new object
|
if kwargs.get('created'):
|
||||||
logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})")
|
logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})")
|
||||||
self.queue[key] = (ChangeActionChoices.ACTION_CREATE, instance)
|
data = serialize_object(instance, resolve_tags=False)
|
||||||
elif key in self.queue:
|
self.queue[key] = (ChangeActionChoices.ACTION_CREATE, data)
|
||||||
# Object has already been created/updated at least once
|
return
|
||||||
|
|
||||||
|
# Ignore pre_* many-to-many actions
|
||||||
|
if 'action' in kwargs and kwargs['action'] not in ('post_add', 'post_remove', 'post_clear'):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Object has already been created/updated in the queue; update its queued representation
|
||||||
|
if key in self.queue:
|
||||||
logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})")
|
logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})")
|
||||||
self.queue[key] = (self.queue[key][0], instance)
|
data = serialize_object(instance, resolve_tags=False)
|
||||||
else:
|
self.queue[key] = (self.queue[key][0], data)
|
||||||
# Modifying an existing object
|
return
|
||||||
logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
|
|
||||||
self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance)
|
# Modifying an existing object for the first time
|
||||||
|
logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
|
||||||
|
data = serialize_object(instance, resolve_tags=False)
|
||||||
|
self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, data)
|
||||||
|
|
||||||
def pre_delete_handler(self, sender, instance, **kwargs):
|
def pre_delete_handler(self, sender, instance, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -134,11 +143,12 @@ class checkout:
|
|||||||
key = self.get_key_for_instance(instance)
|
key = self.get_key_for_instance(instance)
|
||||||
object_type = instance._meta.verbose_name
|
object_type = instance._meta.verbose_name
|
||||||
|
|
||||||
|
# Cancel the creation of a new object
|
||||||
if key in self.queue and self.queue[key][0] == ChangeActionChoices.ACTION_CREATE:
|
if key in self.queue and self.queue[key][0] == ChangeActionChoices.ACTION_CREATE:
|
||||||
# Cancel the creation of a new object
|
|
||||||
logger.debug(f"[{self.branch}] Removing staged creation of {object_type} {instance} (PK: {instance.pk})")
|
logger.debug(f"[{self.branch}] Removing staged creation of {object_type} {instance} (PK: {instance.pk})")
|
||||||
del self.queue[key]
|
del self.queue[key]
|
||||||
else:
|
return
|
||||||
# Delete an existing object
|
|
||||||
logger.debug(f"[{self.branch}] Staging deletion of {object_type} {instance} (PK: {instance.pk})")
|
# Delete an existing object
|
||||||
self.queue[key] = (ChangeActionChoices.ACTION_DELETE, instance)
|
logger.debug(f"[{self.branch}] Staging deletion of {object_type} {instance} (PK: {instance.pk})")
|
||||||
|
self.queue[key] = (ChangeActionChoices.ACTION_DELETE, None)
|
||||||
|
@ -3,6 +3,7 @@ from django.test import TransactionTestCase
|
|||||||
from circuits.models import Provider, Circuit, CircuitType
|
from circuits.models import Provider, Circuit, CircuitType
|
||||||
from extras.choices import ChangeActionChoices
|
from extras.choices import ChangeActionChoices
|
||||||
from extras.models import Branch, Change, Tag
|
from extras.models import Branch, Change, Tag
|
||||||
|
from ipam.models import ASN, RIR
|
||||||
from netbox.staging import checkout
|
from netbox.staging import checkout
|
||||||
from utilities.testing import create_tags
|
from utilities.testing import create_tags
|
||||||
|
|
||||||
@ -12,6 +13,14 @@ class StagingTestCase(TransactionTestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
create_tags('Alpha', 'Bravo', 'Charlie')
|
create_tags('Alpha', 'Bravo', 'Charlie')
|
||||||
|
|
||||||
|
rir = RIR.objects.create(name='RIR 1', slug='rir-1')
|
||||||
|
asns = (
|
||||||
|
ASN(asn=65001, rir=rir),
|
||||||
|
ASN(asn=65002, rir=rir),
|
||||||
|
ASN(asn=65003, rir=rir),
|
||||||
|
)
|
||||||
|
ASN.objects.bulk_create(asns)
|
||||||
|
|
||||||
providers = (
|
providers = (
|
||||||
Provider(name='Provider A', slug='provider-a'),
|
Provider(name='Provider A', slug='provider-a'),
|
||||||
Provider(name='Provider B', slug='provider-b'),
|
Provider(name='Provider B', slug='provider-b'),
|
||||||
@ -36,14 +45,17 @@ class StagingTestCase(TransactionTestCase):
|
|||||||
def test_object_creation(self):
|
def test_object_creation(self):
|
||||||
branch = Branch.objects.create(name='Branch 1')
|
branch = Branch.objects.create(name='Branch 1')
|
||||||
tags = Tag.objects.all()
|
tags = Tag.objects.all()
|
||||||
|
asns = ASN.objects.all()
|
||||||
|
|
||||||
with checkout(branch):
|
with checkout(branch):
|
||||||
provider = Provider.objects.create(name='Provider D', slug='provider-d')
|
provider = Provider.objects.create(name='Provider D', slug='provider-d')
|
||||||
|
provider.asns.set(asns)
|
||||||
circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first())
|
circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first())
|
||||||
circuit.tags.set(tags)
|
circuit.tags.set(tags)
|
||||||
|
|
||||||
# Sanity-checking
|
# Sanity-checking
|
||||||
self.assertEqual(Provider.objects.count(), 4)
|
self.assertEqual(Provider.objects.count(), 4)
|
||||||
|
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||||
self.assertEqual(Circuit.objects.count(), 10)
|
self.assertEqual(Circuit.objects.count(), 10)
|
||||||
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
||||||
|
|
||||||
@ -56,6 +68,8 @@ class StagingTestCase(TransactionTestCase):
|
|||||||
with checkout(branch):
|
with checkout(branch):
|
||||||
self.assertEqual(Provider.objects.count(), 4)
|
self.assertEqual(Provider.objects.count(), 4)
|
||||||
self.assertEqual(Circuit.objects.count(), 10)
|
self.assertEqual(Circuit.objects.count(), 10)
|
||||||
|
provider = Provider.objects.get(name='Provider D')
|
||||||
|
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||||
circuit = Circuit.objects.get(cid='Circuit D1')
|
circuit = Circuit.objects.get(cid='Circuit D1')
|
||||||
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
||||||
|
|
||||||
@ -63,6 +77,8 @@ class StagingTestCase(TransactionTestCase):
|
|||||||
branch.merge()
|
branch.merge()
|
||||||
self.assertEqual(Provider.objects.count(), 4)
|
self.assertEqual(Provider.objects.count(), 4)
|
||||||
self.assertEqual(Circuit.objects.count(), 10)
|
self.assertEqual(Circuit.objects.count(), 10)
|
||||||
|
provider = Provider.objects.get(name='Provider D')
|
||||||
|
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||||
circuit = Circuit.objects.get(cid='Circuit D1')
|
circuit = Circuit.objects.get(cid='Circuit D1')
|
||||||
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
||||||
self.assertEqual(Change.objects.count(), 0)
|
self.assertEqual(Change.objects.count(), 0)
|
||||||
@ -70,11 +86,13 @@ class StagingTestCase(TransactionTestCase):
|
|||||||
def test_object_modification(self):
|
def test_object_modification(self):
|
||||||
branch = Branch.objects.create(name='Branch 1')
|
branch = Branch.objects.create(name='Branch 1')
|
||||||
tags = Tag.objects.all()
|
tags = Tag.objects.all()
|
||||||
|
asns = ASN.objects.all()
|
||||||
|
|
||||||
with checkout(branch):
|
with checkout(branch):
|
||||||
provider = Provider.objects.get(name='Provider A')
|
provider = Provider.objects.get(name='Provider A')
|
||||||
provider.name = 'Provider X'
|
provider.name = 'Provider X'
|
||||||
provider.save()
|
provider.save()
|
||||||
|
provider.asns.set(asns)
|
||||||
circuit = Circuit.objects.get(cid='Circuit A1')
|
circuit = Circuit.objects.get(cid='Circuit A1')
|
||||||
circuit.cid = 'Circuit X'
|
circuit.cid = 'Circuit X'
|
||||||
circuit.save()
|
circuit.save()
|
||||||
@ -83,6 +101,7 @@ class StagingTestCase(TransactionTestCase):
|
|||||||
# Sanity-checking
|
# Sanity-checking
|
||||||
self.assertEqual(Provider.objects.count(), 3)
|
self.assertEqual(Provider.objects.count(), 3)
|
||||||
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
|
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
|
||||||
|
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||||
self.assertEqual(Circuit.objects.count(), 9)
|
self.assertEqual(Circuit.objects.count(), 9)
|
||||||
self.assertEqual(Circuit.objects.get(pk=circuit.pk).cid, 'Circuit X')
|
self.assertEqual(Circuit.objects.get(pk=circuit.pk).cid, 'Circuit X')
|
||||||
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
self.assertListEqual(list(circuit.tags.all()), list(tags))
|
||||||
@ -90,6 +109,8 @@ class StagingTestCase(TransactionTestCase):
|
|||||||
# Verify that changes have been rolled back after exiting the context
|
# Verify that changes have been rolled back after exiting the context
|
||||||
self.assertEqual(Provider.objects.count(), 3)
|
self.assertEqual(Provider.objects.count(), 3)
|
||||||
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider A')
|
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider A')
|
||||||
|
provider = Provider.objects.get(pk=provider.pk)
|
||||||
|
self.assertListEqual(list(provider.asns.all()), [])
|
||||||
self.assertEqual(Circuit.objects.count(), 9)
|
self.assertEqual(Circuit.objects.count(), 9)
|
||||||
circuit = Circuit.objects.get(pk=circuit.pk)
|
circuit = Circuit.objects.get(pk=circuit.pk)
|
||||||
self.assertEqual(circuit.cid, 'Circuit A1')
|
self.assertEqual(circuit.cid, 'Circuit A1')
|
||||||
@ -100,6 +121,8 @@ class StagingTestCase(TransactionTestCase):
|
|||||||
with checkout(branch):
|
with checkout(branch):
|
||||||
self.assertEqual(Provider.objects.count(), 3)
|
self.assertEqual(Provider.objects.count(), 3)
|
||||||
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
|
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
|
||||||
|
provider = Provider.objects.get(pk=provider.pk)
|
||||||
|
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||||
self.assertEqual(Circuit.objects.count(), 9)
|
self.assertEqual(Circuit.objects.count(), 9)
|
||||||
circuit = Circuit.objects.get(pk=circuit.pk)
|
circuit = Circuit.objects.get(pk=circuit.pk)
|
||||||
self.assertEqual(circuit.cid, 'Circuit X')
|
self.assertEqual(circuit.cid, 'Circuit X')
|
||||||
@ -109,6 +132,8 @@ class StagingTestCase(TransactionTestCase):
|
|||||||
branch.merge()
|
branch.merge()
|
||||||
self.assertEqual(Provider.objects.count(), 3)
|
self.assertEqual(Provider.objects.count(), 3)
|
||||||
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
|
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
|
||||||
|
provider = Provider.objects.get(pk=provider.pk)
|
||||||
|
self.assertListEqual(list(provider.asns.all()), list(asns))
|
||||||
self.assertEqual(Circuit.objects.count(), 9)
|
self.assertEqual(Circuit.objects.count(), 9)
|
||||||
circuit = Circuit.objects.get(pk=circuit.pk)
|
circuit = Circuit.objects.get(pk=circuit.pk)
|
||||||
self.assertEqual(circuit.cid, 'Circuit X')
|
self.assertEqual(circuit.cid, 'Circuit X')
|
||||||
|
@ -136,7 +136,7 @@ def count_related(model, field):
|
|||||||
return Coalesce(subquery, 0)
|
return Coalesce(subquery, 0)
|
||||||
|
|
||||||
|
|
||||||
def serialize_object(obj, extra=None):
|
def serialize_object(obj, resolve_tags=True, extra=None):
|
||||||
"""
|
"""
|
||||||
Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
|
Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
|
||||||
change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
|
change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
|
||||||
@ -155,8 +155,9 @@ def serialize_object(obj, extra=None):
|
|||||||
if hasattr(obj, 'custom_field_data'):
|
if hasattr(obj, 'custom_field_data'):
|
||||||
data['custom_fields'] = data.pop('custom_field_data')
|
data['custom_fields'] = data.pop('custom_field_data')
|
||||||
|
|
||||||
# Include any tags. Check for tags cached on the instance; fall back to using the manager.
|
# Resolve any assigned tags to their names. Check for tags cached on the instance;
|
||||||
if is_taggable(obj):
|
# fall back to using the manager.
|
||||||
|
if resolve_tags and is_taggable(obj):
|
||||||
tags = getattr(obj, '_tags', None) or obj.tags.all()
|
tags = getattr(obj, '_tags', None) or obj.tags.all()
|
||||||
data['tags'] = sorted([tag.name for tag in tags])
|
data['tags'] = sorted([tag.name for tag in tags])
|
||||||
|
|
||||||
@ -174,6 +175,10 @@ def serialize_object(obj, extra=None):
|
|||||||
|
|
||||||
|
|
||||||
def deserialize_object(model, fields, pk=None):
|
def deserialize_object(model, fields, pk=None):
|
||||||
|
"""
|
||||||
|
Instantiate an object from the given model and field data. Functions as
|
||||||
|
the complement to serialize_object().
|
||||||
|
"""
|
||||||
content_type = ContentType.objects.get_for_model(model)
|
content_type = ContentType.objects.get_for_model(model)
|
||||||
if 'custom_fields' in fields:
|
if 'custom_fields' in fields:
|
||||||
fields['custom_field_data'] = fields.pop('custom_fields')
|
fields['custom_field_data'] = fields.pop('custom_fields')
|
||||||
|
Loading…
Reference in New Issue
Block a user