This commit is contained in:
jeremystretch 2022-11-08 13:38:48 -05:00
parent 653acbf62c
commit c2e6853031
7 changed files with 402 additions and 2 deletions

View File

@ -182,3 +182,20 @@ class WebhookHttpMethodChoices(ChoiceSet):
(METHOD_PATCH, 'PATCH'),
(METHOD_DELETE, 'DELETE'),
)
#
# Staging
#
class ChangeActionChoices(ChoiceSet):
ACTION_CREATE = 'create'
ACTION_UPDATE = 'update'
ACTION_DELETE = 'delete'
CHOICES = (
(ACTION_CREATE, 'Created'),
(ACTION_UPDATE, 'Updated'),
(ACTION_DELETE, 'Deleted'),
)

View File

@ -0,0 +1,51 @@
# Generated by Django 4.1.2 on 2022-11-08 16:25
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('contenttypes', '0002_remove_content_type_name'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('extras', '0083_savedfilter'),
]
operations = [
migrations.CreateModel(
name='Branch',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False)),
('created', models.DateTimeField(auto_now_add=True, null=True)),
('last_updated', models.DateTimeField(auto_now=True, null=True)),
('name', models.CharField(max_length=100, unique=True)),
('description', models.CharField(blank=True, max_length=200)),
('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL)),
],
options={
'ordering': ('name',),
},
),
migrations.CreateModel(
name='Change',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False)),
('created', models.DateTimeField(auto_now_add=True, null=True)),
('last_updated', models.DateTimeField(auto_now=True, null=True)),
('action', models.CharField(max_length=20)),
('object_id', models.PositiveBigIntegerField(blank=True, null=True)),
('data', models.JSONField(blank=True, null=True)),
('branch', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='changes', to='extras.branch')),
('object_type', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='+', to='contenttypes.contenttype')),
],
options={
'ordering': ('pk',),
},
),
migrations.AddConstraint(
model_name='change',
constraint=models.UniqueConstraint(fields=('branch', 'object_type', 'object_id'), name='extras_change_unique_branch_object'),
),
]

View File

@ -3,10 +3,13 @@ from .configcontexts import ConfigContext, ConfigContextModel
from .customfields import CustomField
from .models import *
from .search import *
from .staging import *
from .tags import Tag, TaggedItem
__all__ = (
'CachedValue',
'Change',
'Branch',
'ConfigContext',
'ConfigContextModel',
'ConfigRevision',

View File

@ -0,0 +1,101 @@
import logging
from django.contrib.auth import get_user_model
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models, transaction
from extras.choices import ChangeActionChoices
from netbox.models import ChangeLoggedModel
from utilities.utils import deserialize_object
__all__ = (
'Branch',
'Change',
)
logger = logging.getLogger('netbox.staging')
class Branch(ChangeLoggedModel):
name = models.CharField(
max_length=100,
unique=True
)
description = models.CharField(
max_length=200,
blank=True
)
user = models.ForeignKey(
to=get_user_model(),
on_delete=models.SET_NULL,
blank=True,
null=True
)
class Meta:
ordering = ('name',)
def merge(self):
logger.info(f'Merging changes in branch {self}')
with transaction.atomic():
for change in self.changes.all():
change.apply()
self.changes.all().delete()
class Change(ChangeLoggedModel):
branch = models.ForeignKey(
to=Branch,
on_delete=models.CASCADE,
related_name='changes'
)
action = models.CharField(
max_length=20,
choices=ChangeActionChoices
)
object_type = models.ForeignKey(
to=ContentType,
on_delete=models.CASCADE,
related_name='+'
)
object_id = models.PositiveBigIntegerField(
blank=True,
null=True
)
object = GenericForeignKey(
ct_field='object_type',
fk_field='object_id'
)
data = models.JSONField(
blank=True,
null=True
)
class Meta:
ordering = ('pk',)
constraints = (
models.UniqueConstraint(
fields=('branch', 'object_type', 'object_id'),
name='extras_change_unique_branch_object'
),
)
def apply(self):
model = self.object_type.model_class()
pk = self.object_id
if self.action == ChangeActionChoices.ACTION_CREATE:
instance = deserialize_object(model, self.data, pk=pk)
logger.info(f'Creating {model} {instance}')
instance.save()
if self.action == ChangeActionChoices.ACTION_UPDATE:
instance = deserialize_object(model, self.data, pk=pk)
logger.info(f'Updating {model} {instance}')
instance.save()
if self.action == ChangeActionChoices.ACTION_DELETE:
instance = model.objects.get(pk=self.object_id)
logger.info(f'Deleting {model} {instance}')
instance.delete()

113
netbox/netbox/staging.py Normal file
View File

@ -0,0 +1,113 @@
import logging
from contextlib import contextmanager
from django.contrib.contenttypes.models import ContentType
from django.db import transaction
from django.db.models.signals import pre_delete, post_save
from extras.choices import ChangeActionChoices
from extras.models import Change
from utilities.exceptions import AbortTransaction
from utilities.utils import serialize_object, shallow_compare_dict
logger = logging.getLogger('netbox.staging')
def get_changed_fields(instance):
model = instance._meta.model
original = model.objects.get(pk=instance.pk)
return shallow_compare_dict(
serialize_object(original),
serialize_object(instance),
exclude=('last_updated',)
)
def get_key_for_instance(instance):
object_type = ContentType.objects.get_for_model(instance)
return object_type, instance.pk
@contextmanager
def checkout(branch):
queue = {}
def save_handler(sender, instance, **kwargs):
return post_save_handler(sender, instance, branch=branch, queue=queue, **kwargs)
def delete_handler(sender, instance, **kwargs):
return pre_delete_handler(sender, instance, branch=branch, queue=queue, **kwargs)
# Connect signal handlers
post_save.connect(save_handler)
pre_delete.connect(delete_handler)
try:
with transaction.atomic():
yield
raise AbortTransaction()
# Roll back the transaction
except AbortTransaction:
pass
finally:
# Disconnect signal handlers
post_save.disconnect(save_handler)
pre_delete.disconnect(delete_handler)
# Process queued changes
logger.debug("Processing queued changes:")
for key, change in queue.items():
logger.debug(f' {key}: {change}')
# TODO: Optimize the creation of new Changes
changes = []
for key, change in queue.items():
object_type, pk = key
action, instance = change
if action in (ChangeActionChoices.ACTION_CREATE, ChangeActionChoices.ACTION_UPDATE):
data = serialize_object(instance)
else:
data = None
change = Change(
branch=branch,
action=action,
object_type=object_type,
object_id=pk,
data=data
)
changes.append(change)
Change.objects.bulk_create(changes)
def post_save_handler(sender, instance, branch, queue, created, **kwargs):
key = get_key_for_instance(instance)
if created:
# Creating a new object
logger.debug(f"Staging creation of {instance} under branch {branch}")
queue[key] = (ChangeActionChoices.ACTION_CREATE, instance)
elif key in queue:
# Object has already been created/updated at least once
logger.debug(f"Updating staged value for {instance} under branch {branch}")
queue[key] = (queue[key][0], instance)
else:
# Modifying an existing object
logger.debug(f"Staging changes to {instance} (PK: {instance.pk}) under branch {branch}")
queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance)
def pre_delete_handler(sender, instance, branch, queue, **kwargs):
key = get_key_for_instance(instance)
if key in queue and queue[key][0] == 'create':
# Cancel the creation of a new object
logger.debug(f"Removing staged deletion of {instance} (PK: {instance.pk}) under branch {branch}")
del queue[key]
else:
# Delete an existing object
logger.debug(f"Staging deletion of {instance} (PK: {instance.pk}) under branch {branch}")
queue[key] = (ChangeActionChoices.ACTION_DELETE, instance)

View File

@ -0,0 +1,100 @@
from django.test import TestCase
from circuits.models import Provider, Circuit, CircuitType
from extras.models import Change, Branch
from netbox.staging import checkout
class StagingTestCase(TestCase):
@classmethod
def setUpTestData(cls):
providers = (
Provider(name='Provider A', slug='provider-a'),
Provider(name='Provider B', slug='provider-b'),
Provider(name='Provider C', slug='provider-c'),
)
Provider.objects.bulk_create(providers)
circuit_type = CircuitType.objects.create(name='Circuit Type 1', slug='circuit-type-1')
Circuit.objects.bulk_create((
Circuit(provider=providers[0], cid='Circuit A1', type=circuit_type),
Circuit(provider=providers[0], cid='Circuit A2', type=circuit_type),
Circuit(provider=providers[0], cid='Circuit A3', type=circuit_type),
Circuit(provider=providers[1], cid='Circuit B1', type=circuit_type),
Circuit(provider=providers[1], cid='Circuit B2', type=circuit_type),
Circuit(provider=providers[1], cid='Circuit B3', type=circuit_type),
Circuit(provider=providers[2], cid='Circuit C1', type=circuit_type),
Circuit(provider=providers[2], cid='Circuit C2', type=circuit_type),
Circuit(provider=providers[2], cid='Circuit C3', type=circuit_type),
))
def test_object_creation(self):
branch = Branch.objects.create(name='Branch 1')
with checkout(branch):
provider = Provider.objects.create(name='Provider D', slug='provider-d')
Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first())
self.assertEqual(Provider.objects.count(), 4)
self.assertEqual(Circuit.objects.count(), 10)
self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Circuit.objects.count(), 9)
self.assertEqual(Change.objects.count(), 2)
def test_object_modification(self):
branch = Branch.objects.create(name='Branch 1')
with checkout(branch):
provider = Provider.objects.get(name='Provider A')
provider.name = 'Provider X'
provider.save()
circuit = Circuit.objects.get(cid='Circuit A1')
circuit.cid = 'Circuit X'
circuit.save()
self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
self.assertEqual(Circuit.objects.count(), 9)
self.assertEqual(Circuit.objects.get(pk=circuit.pk).cid, 'Circuit X')
self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider A')
self.assertEqual(Circuit.objects.count(), 9)
self.assertEqual(Circuit.objects.get(pk=circuit.pk).cid, 'Circuit A1')
self.assertEqual(Change.objects.count(), 2)
def test_object_deletion(self):
branch = Branch.objects.create(name='Branch 1')
with checkout(branch):
provider = Provider.objects.get(name='Provider A')
provider.circuits.all().delete()
provider.delete()
self.assertEqual(Provider.objects.count(), 2)
self.assertEqual(Circuit.objects.count(), 6)
self.assertEqual(Provider.objects.count(), 3)
self.assertEqual(Circuit.objects.count(), 9)
self.assertEqual(Change.objects.count(), 4)
def test_create_update_delete_clean(self):
branch = Branch.objects.create(name='Branch 1')
with checkout(branch):
# Create a new object
provider = Provider.objects.create(name='Provider D', slug='provider-d')
provider.save()
# Update it
provider.comments = 'Another change'
provider.save()
# Delete it
provider.delete()
self.assertFalse(Change.objects.exists())

View File

@ -6,7 +6,8 @@ from decimal import Decimal
from itertools import count, groupby
import bleach
from django.core.serializers import serialize
from django.contrib.contenttypes.models import ContentType
from django.core import serializers
from django.db.models import Count, OuterRef, Subquery
from django.db.models.functions import Coalesce
from django.http import QueryDict
@ -142,7 +143,7 @@ def serialize_object(obj, extra=None):
can be provided to exclude them from the returned dictionary. Private fields (prefaced with an underscore) are
implicitly excluded.
"""
json_str = serialize('json', [obj])
json_str = serializers.serialize('json', [obj])
data = json.loads(json_str)[0]['fields']
# Exclude any MPTTModel fields
@ -172,6 +173,20 @@ def serialize_object(obj, extra=None):
return data
def deserialize_object(model, fields, pk=None):
content_type = ContentType.objects.get_for_model(model)
if 'custom_fields' in fields:
fields['custom_field_data'] = fields.pop('custom_fields')
data = {
'model': '.'.join(content_type.natural_key()),
'pk': pk,
'fields': fields,
}
instance = list(serializers.deserialize('python', [data]))[0]
return instance
def dict_to_filter_params(d, prefix=''):
"""
Translate a dictionary of attributes to a nested set of parameters suitable for QuerySet filtering. For example: