From c2e6853031d373cb6bf74c6883c0a4b45993f5f8 Mon Sep 17 00:00:00 2001 From: jeremystretch Date: Tue, 8 Nov 2022 13:38:48 -0500 Subject: [PATCH] WIP --- netbox/extras/choices.py | 17 ++++ netbox/extras/migrations/0084_staging.py | 51 ++++++++++ netbox/extras/models/__init__.py | 3 + netbox/extras/models/staging.py | 101 ++++++++++++++++++++ netbox/netbox/staging.py | 113 +++++++++++++++++++++++ netbox/netbox/tests/test_staging.py | 100 ++++++++++++++++++++ netbox/utilities/utils.py | 19 +++- 7 files changed, 402 insertions(+), 2 deletions(-) create mode 100644 netbox/extras/migrations/0084_staging.py create mode 100644 netbox/extras/models/staging.py create mode 100644 netbox/netbox/staging.py create mode 100644 netbox/netbox/tests/test_staging.py diff --git a/netbox/extras/choices.py b/netbox/extras/choices.py index ee806f094..73b5648aa 100644 --- a/netbox/extras/choices.py +++ b/netbox/extras/choices.py @@ -182,3 +182,20 @@ class WebhookHttpMethodChoices(ChoiceSet): (METHOD_PATCH, 'PATCH'), (METHOD_DELETE, 'DELETE'), ) + + +# +# Staging +# + +class ChangeActionChoices(ChoiceSet): + + ACTION_CREATE = 'create' + ACTION_UPDATE = 'update' + ACTION_DELETE = 'delete' + + CHOICES = ( + (ACTION_CREATE, 'Created'), + (ACTION_UPDATE, 'Updated'), + (ACTION_DELETE, 'Deleted'), + ) diff --git a/netbox/extras/migrations/0084_staging.py b/netbox/extras/migrations/0084_staging.py new file mode 100644 index 000000000..84ced1e93 --- /dev/null +++ b/netbox/extras/migrations/0084_staging.py @@ -0,0 +1,51 @@ +# Generated by Django 4.1.2 on 2022-11-08 16:25 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('contenttypes', '0002_remove_content_type_name'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('extras', '0083_savedfilter'), + ] + + operations = [ + migrations.CreateModel( + name='Branch', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False)), + ('created', models.DateTimeField(auto_now_add=True, null=True)), + ('last_updated', models.DateTimeField(auto_now=True, null=True)), + ('name', models.CharField(max_length=100, unique=True)), + ('description', models.CharField(blank=True, max_length=200)), + ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL)), + ], + options={ + 'ordering': ('name',), + }, + ), + migrations.CreateModel( + name='Change', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False)), + ('created', models.DateTimeField(auto_now_add=True, null=True)), + ('last_updated', models.DateTimeField(auto_now=True, null=True)), + ('action', models.CharField(max_length=20)), + ('object_id', models.PositiveBigIntegerField(blank=True, null=True)), + ('data', models.JSONField(blank=True, null=True)), + ('branch', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='changes', to='extras.branch')), + ('object_type', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='+', to='contenttypes.contenttype')), + ], + options={ + 'ordering': ('pk',), + }, + ), + migrations.AddConstraint( + model_name='change', + constraint=models.UniqueConstraint(fields=('branch', 'object_type', 'object_id'), name='extras_change_unique_branch_object'), + ), + ] diff --git a/netbox/extras/models/__init__.py b/netbox/extras/models/__init__.py index 6d2bf288c..e992822f4 100644 --- a/netbox/extras/models/__init__.py +++ b/netbox/extras/models/__init__.py @@ -3,10 +3,13 @@ from .configcontexts import ConfigContext, ConfigContextModel from .customfields import CustomField from .models import * from .search import * +from .staging import * from .tags import Tag, TaggedItem __all__ = ( 'CachedValue', + 'Change', + 'Branch', 'ConfigContext', 'ConfigContextModel', 'ConfigRevision', diff --git a/netbox/extras/models/staging.py b/netbox/extras/models/staging.py new file mode 100644 index 000000000..b5564f428 --- /dev/null +++ b/netbox/extras/models/staging.py @@ -0,0 +1,101 @@ +import logging + +from django.contrib.auth import get_user_model +from django.contrib.contenttypes.fields import GenericForeignKey +from django.contrib.contenttypes.models import ContentType +from django.db import models, transaction + +from extras.choices import ChangeActionChoices +from netbox.models import ChangeLoggedModel +from utilities.utils import deserialize_object + +__all__ = ( + 'Branch', + 'Change', +) + +logger = logging.getLogger('netbox.staging') + + +class Branch(ChangeLoggedModel): + name = models.CharField( + max_length=100, + unique=True + ) + description = models.CharField( + max_length=200, + blank=True + ) + user = models.ForeignKey( + to=get_user_model(), + on_delete=models.SET_NULL, + blank=True, + null=True + ) + + class Meta: + ordering = ('name',) + + def merge(self): + logger.info(f'Merging changes in branch {self}') + with transaction.atomic(): + for change in self.changes.all(): + change.apply() + self.changes.all().delete() + + +class Change(ChangeLoggedModel): + branch = models.ForeignKey( + to=Branch, + on_delete=models.CASCADE, + related_name='changes' + ) + action = models.CharField( + max_length=20, + choices=ChangeActionChoices + ) + object_type = models.ForeignKey( + to=ContentType, + on_delete=models.CASCADE, + related_name='+' + ) + object_id = models.PositiveBigIntegerField( + blank=True, + null=True + ) + object = GenericForeignKey( + ct_field='object_type', + fk_field='object_id' + ) + data = models.JSONField( + blank=True, + null=True + ) + + class Meta: + ordering = ('pk',) + constraints = ( + models.UniqueConstraint( + fields=('branch', 'object_type', 'object_id'), + name='extras_change_unique_branch_object' + ), + ) + + def apply(self): + model = self.object_type.model_class() + pk = self.object_id + + if self.action == ChangeActionChoices.ACTION_CREATE: + instance = deserialize_object(model, self.data, pk=pk) + logger.info(f'Creating {model} {instance}') + instance.save() + + if self.action == ChangeActionChoices.ACTION_UPDATE: + instance = deserialize_object(model, self.data, pk=pk) + logger.info(f'Updating {model} {instance}') + instance.save() + + if self.action == ChangeActionChoices.ACTION_DELETE: + instance = model.objects.get(pk=self.object_id) + logger.info(f'Deleting {model} {instance}') + instance.delete() diff --git a/netbox/netbox/staging.py b/netbox/netbox/staging.py new file mode 100644 index 000000000..782a7fb40 --- /dev/null +++ b/netbox/netbox/staging.py @@ -0,0 +1,113 @@ +import logging +from contextlib import contextmanager + +from django.contrib.contenttypes.models import ContentType +from django.db import transaction +from django.db.models.signals import pre_delete, post_save + +from extras.choices import ChangeActionChoices +from extras.models import Change +from utilities.exceptions import AbortTransaction +from utilities.utils import serialize_object, shallow_compare_dict + +logger = logging.getLogger('netbox.staging') + + +def get_changed_fields(instance): + model = instance._meta.model + original = model.objects.get(pk=instance.pk) + return shallow_compare_dict( + serialize_object(original), + serialize_object(instance), + exclude=('last_updated',) + ) + + +def get_key_for_instance(instance): + object_type = ContentType.objects.get_for_model(instance) + return object_type, instance.pk + + +@contextmanager +def checkout(branch): + + queue = {} + + def save_handler(sender, instance, **kwargs): + return post_save_handler(sender, instance, branch=branch, queue=queue, **kwargs) + + def delete_handler(sender, instance, **kwargs): + return pre_delete_handler(sender, instance, branch=branch, queue=queue, **kwargs) + + # Connect signal handlers + post_save.connect(save_handler) + pre_delete.connect(delete_handler) + + try: + with transaction.atomic(): + yield + raise AbortTransaction() + + # Roll back the transaction + except AbortTransaction: + pass + + finally: + + # Disconnect signal handlers + post_save.disconnect(save_handler) + pre_delete.disconnect(delete_handler) + + # Process queued changes + logger.debug("Processing queued changes:") + for key, change in queue.items(): + logger.debug(f' {key}: {change}') + + # TODO: Optimize the creation of new Changes + changes = [] + for key, change in queue.items(): + object_type, pk = key + action, instance = change + if action in (ChangeActionChoices.ACTION_CREATE, ChangeActionChoices.ACTION_UPDATE): + data = serialize_object(instance) + else: + data = None + + change = Change( + branch=branch, + action=action, + object_type=object_type, + object_id=pk, + data=data + ) + changes.append(change) + + Change.objects.bulk_create(changes) + + +def post_save_handler(sender, instance, branch, queue, created, **kwargs): + key = get_key_for_instance(instance) + if created: + # Creating a new object + logger.debug(f"Staging creation of {instance} under branch {branch}") + queue[key] = (ChangeActionChoices.ACTION_CREATE, instance) + elif key in queue: + # Object has already been created/updated at least once + logger.debug(f"Updating staged value for {instance} under branch {branch}") + queue[key] = (queue[key][0], instance) + else: + # Modifying an existing object + logger.debug(f"Staging changes to {instance} (PK: {instance.pk}) under branch {branch}") + queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance) + + +def pre_delete_handler(sender, instance, branch, queue, **kwargs): + key = get_key_for_instance(instance) + if key in queue and queue[key][0] == 'create': + # Cancel the creation of a new object + logger.debug(f"Removing staged deletion of {instance} (PK: {instance.pk}) under branch {branch}") + del queue[key] + else: + # Delete an existing object + logger.debug(f"Staging deletion of {instance} (PK: {instance.pk}) under branch {branch}") + queue[key] = (ChangeActionChoices.ACTION_DELETE, instance) diff --git a/netbox/netbox/tests/test_staging.py b/netbox/netbox/tests/test_staging.py new file mode 100644 index 000000000..eb2217ed8 --- /dev/null +++ b/netbox/netbox/tests/test_staging.py @@ -0,0 +1,100 @@ +from django.test import TestCase + +from circuits.models import Provider, Circuit, CircuitType +from extras.models import Change, Branch +from netbox.staging import checkout + + +class StagingTestCase(TestCase): + + @classmethod + def setUpTestData(cls): + providers = ( + Provider(name='Provider A', slug='provider-a'), + Provider(name='Provider B', slug='provider-b'), + Provider(name='Provider C', slug='provider-c'), + ) + Provider.objects.bulk_create(providers) + + circuit_type = CircuitType.objects.create(name='Circuit Type 1', slug='circuit-type-1') + + Circuit.objects.bulk_create(( + Circuit(provider=providers[0], cid='Circuit A1', type=circuit_type), + Circuit(provider=providers[0], cid='Circuit A2', type=circuit_type), + Circuit(provider=providers[0], cid='Circuit A3', type=circuit_type), + Circuit(provider=providers[1], cid='Circuit B1', type=circuit_type), + Circuit(provider=providers[1], cid='Circuit B2', type=circuit_type), + Circuit(provider=providers[1], cid='Circuit B3', type=circuit_type), + Circuit(provider=providers[2], cid='Circuit C1', type=circuit_type), + Circuit(provider=providers[2], cid='Circuit C2', type=circuit_type), + Circuit(provider=providers[2], cid='Circuit C3', type=circuit_type), + )) + + def test_object_creation(self): + branch = Branch.objects.create(name='Branch 1') + + with checkout(branch): + provider = Provider.objects.create(name='Provider D', slug='provider-d') + Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first()) + + self.assertEqual(Provider.objects.count(), 4) + self.assertEqual(Circuit.objects.count(), 10) + + self.assertEqual(Provider.objects.count(), 3) + self.assertEqual(Circuit.objects.count(), 9) + self.assertEqual(Change.objects.count(), 2) + + def test_object_modification(self): + branch = Branch.objects.create(name='Branch 1') + + with checkout(branch): + provider = Provider.objects.get(name='Provider A') + provider.name = 'Provider X' + provider.save() + circuit = Circuit.objects.get(cid='Circuit A1') + circuit.cid = 'Circuit X' + circuit.save() + + self.assertEqual(Provider.objects.count(), 3) + self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X') + self.assertEqual(Circuit.objects.count(), 9) + self.assertEqual(Circuit.objects.get(pk=circuit.pk).cid, 'Circuit X') + + self.assertEqual(Provider.objects.count(), 3) + self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider A') + self.assertEqual(Circuit.objects.count(), 9) + self.assertEqual(Circuit.objects.get(pk=circuit.pk).cid, 'Circuit A1') + self.assertEqual(Change.objects.count(), 2) + + def test_object_deletion(self): + branch = Branch.objects.create(name='Branch 1') + + with checkout(branch): + provider = Provider.objects.get(name='Provider A') + provider.circuits.all().delete() + provider.delete() + + self.assertEqual(Provider.objects.count(), 2) + self.assertEqual(Circuit.objects.count(), 6) + + self.assertEqual(Provider.objects.count(), 3) + self.assertEqual(Circuit.objects.count(), 9) + self.assertEqual(Change.objects.count(), 4) + + def test_create_update_delete_clean(self): + branch = Branch.objects.create(name='Branch 1') + + with checkout(branch): + + # Create a new object + provider = Provider.objects.create(name='Provider D', slug='provider-d') + provider.save() + + # Update it + provider.comments = 'Another change' + provider.save() + + # Delete it + provider.delete() + + self.assertFalse(Change.objects.exists()) diff --git a/netbox/utilities/utils.py b/netbox/utilities/utils.py index a5bccfbea..de261945c 100644 --- a/netbox/utilities/utils.py +++ b/netbox/utilities/utils.py @@ -6,7 +6,8 @@ from decimal import Decimal from itertools import count, groupby import bleach -from django.core.serializers import serialize +from django.contrib.contenttypes.models import ContentType +from django.core import serializers from django.db.models import Count, OuterRef, Subquery from django.db.models.functions import Coalesce from django.http import QueryDict @@ -142,7 +143,7 @@ def serialize_object(obj, extra=None): can be provided to exclude them from the returned dictionary. Private fields (prefaced with an underscore) are implicitly excluded. """ - json_str = serialize('json', [obj]) + json_str = serializers.serialize('json', [obj]) data = json.loads(json_str)[0]['fields'] # Exclude any MPTTModel fields @@ -172,6 +173,20 @@ def serialize_object(obj, extra=None): return data +def deserialize_object(model, fields, pk=None): + content_type = ContentType.objects.get_for_model(model) + if 'custom_fields' in fields: + fields['custom_field_data'] = fields.pop('custom_fields') + data = { + 'model': '.'.join(content_type.natural_key()), + 'pk': pk, + 'fields': fields, + } + instance = list(serializers.deserialize('python', [data]))[0] + + return instance + + def dict_to_filter_params(d, prefix=''): """ Translate a dictionary of attributes to a nested set of parameters suitable for QuerySet filtering. For example: