diff --git a/netbox/extras/models/staging.py b/netbox/extras/models/staging.py index b5564f428..8a5e2f40f 100644 --- a/netbox/extras/models/staging.py +++ b/netbox/extras/models/staging.py @@ -36,6 +36,9 @@ class Branch(ChangeLoggedModel): class Meta: ordering = ('name',) + def __str__(self): + return f'{self.name} ({self.pk})' + def merge(self): logger.info(f'Merging changes in branch {self}') with transaction.atomic(): @@ -87,15 +90,15 @@ class Change(ChangeLoggedModel): if self.action == ChangeActionChoices.ACTION_CREATE: instance = deserialize_object(model, self.data, pk=pk) - logger.info(f'Creating {model} {instance}') + logger.info(f'Creating {model._meta.verbose_name} {instance}') instance.save() if self.action == ChangeActionChoices.ACTION_UPDATE: instance = deserialize_object(model, self.data, pk=pk) - logger.info(f'Updating {model} {instance}') + logger.info(f'Updating {model._meta.verbose_name} {instance}') instance.save() if self.action == ChangeActionChoices.ACTION_DELETE: instance = model.objects.get(pk=self.object_id) - logger.info(f'Deleting {model} {instance}') + logger.info(f'Deleting {model._meta.verbose_name} {instance}') instance.delete() diff --git a/netbox/netbox/staging.py b/netbox/netbox/staging.py index 782a7fb40..4cd03cd06 100644 --- a/netbox/netbox/staging.py +++ b/netbox/netbox/staging.py @@ -1,5 +1,4 @@ import logging -from contextlib import contextmanager from django.contrib.contenttypes.models import ContentType from django.db import transaction @@ -7,7 +6,6 @@ from django.db.models.signals import pre_delete, post_save from extras.choices import ChangeActionChoices from extras.models import Change -from utilities.exceptions import AbortTransaction from utilities.utils import serialize_object, shallow_compare_dict logger = logging.getLogger('netbox.staging') @@ -28,44 +26,52 @@ def get_key_for_instance(instance): return object_type, instance.pk -@contextmanager -def checkout(branch): +class checkout: - queue = {} + def __init__(self, branch): + self.branch = branch + self.queue = {} - def save_handler(sender, instance, **kwargs): - return post_save_handler(sender, instance, branch=branch, queue=queue, **kwargs) + def __enter__(self): - def delete_handler(sender, instance, **kwargs): - return pre_delete_handler(sender, instance, branch=branch, queue=queue, **kwargs) + # Disable autocommit to effect a new transaction + logger.debug(f"Entering transaction for {self.branch}") + self._autocommit = transaction.get_autocommit() - # Connect signal handlers - post_save.connect(save_handler) - pre_delete.connect(delete_handler) + transaction.set_autocommit(False) - try: - with transaction.atomic(): - yield - raise AbortTransaction() + # Apply any existing Changes assigned to this Branch + changes = self.branch.changes.all() + if changes.exists(): + logger.debug(f"Applying {changes.count()} pre-staged changes...") + for change in changes: + change.apply() + else: + logger.debug("No pre-staged changes found") - # Roll back the transaction - except AbortTransaction: - pass + # Connect signal handlers + logger.debug("Connecting signal handlers") + post_save.connect(self.post_save_handler) + pre_delete.connect(self.pre_delete_handler) - finally: + def __exit__(self, exc_type, exc_val, exc_tb): + + # Roll back the transaction to return the database to its original state + logger.debug("Rolling back transaction") + transaction.rollback() + logger.debug(f"Restoring autocommit state {self._autocommit}") + transaction.set_autocommit(self._autocommit) # Disconnect signal handlers - post_save.disconnect(save_handler) - pre_delete.disconnect(delete_handler) + logger.debug("Disconnecting signal handlers") + post_save.disconnect(self.post_save_handler) + pre_delete.disconnect(self.pre_delete_handler) # Process queued changes - logger.debug("Processing queued changes:") - for key, change in queue.items(): - logger.debug(f' {key}: {change}') - - # TODO: Optimize the creation of new Changes changes = [] - for key, change in queue.items(): + logger.debug(f"Processing {len(self.queue)} queued changes:") + for key, change in self.queue.items(): + logger.debug(f' {key}: {change}') object_type, pk = key action, instance = change if action in (ChangeActionChoices.ACTION_CREATE, ChangeActionChoices.ACTION_UPDATE): @@ -74,7 +80,7 @@ def checkout(branch): data = None change = Change( - branch=branch, + branch=self.branch, action=action, object_type=object_type, object_id=pk, @@ -84,30 +90,30 @@ def checkout(branch): Change.objects.bulk_create(changes) + def post_save_handler(self, sender, instance, created, **kwargs): + key = get_key_for_instance(instance) + object_type = instance._meta.verbose_name -def post_save_handler(sender, instance, branch, queue, created, **kwargs): - key = get_key_for_instance(instance) - if created: - # Creating a new object - logger.debug(f"Staging creation of {instance} under branch {branch}") - queue[key] = (ChangeActionChoices.ACTION_CREATE, instance) - elif key in queue: - # Object has already been created/updated at least once - logger.debug(f"Updating staged value for {instance} under branch {branch}") - queue[key] = (queue[key][0], instance) - else: - # Modifying an existing object - logger.debug(f"Staging changes to {instance} (PK: {instance.pk}) under branch {branch}") - queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance) + if created: + # Creating a new object + logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance}") + self.queue[key] = (ChangeActionChoices.ACTION_CREATE, instance) + elif key in self.queue: + # Object has already been created/updated at least once + logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance}") + self.queue[key] = (self.queue[key][0], instance) + else: + # Modifying an existing object + logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})") + self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance) - -def pre_delete_handler(sender, instance, branch, queue, **kwargs): - key = get_key_for_instance(instance) - if key in queue and queue[key][0] == 'create': - # Cancel the creation of a new object - logger.debug(f"Removing staged deletion of {instance} (PK: {instance.pk}) under branch {branch}") - del queue[key] - else: - # Delete an existing object - logger.debug(f"Staging deletion of {instance} (PK: {instance.pk}) under branch {branch}") - queue[key] = (ChangeActionChoices.ACTION_DELETE, instance) + def pre_delete_handler(self, sender, instance, **kwargs): + key = get_key_for_instance(instance) + if key in self.queue and self.queue[key][0] == 'create': + # Cancel the creation of a new object + logger.debug(f"[{self.branch}] Removing staged deletion of {instance} (PK: {instance.pk})") + del self.queue[key] + else: + # Delete an existing object + logger.debug(f"[{self.branch}] Staging deletion of {instance} (PK: {instance.pk})") + self.queue[key] = (ChangeActionChoices.ACTION_DELETE, instance) diff --git a/netbox/netbox/tests/test_staging.py b/netbox/netbox/tests/test_staging.py index eb2217ed8..cb18f66c9 100644 --- a/netbox/netbox/tests/test_staging.py +++ b/netbox/netbox/tests/test_staging.py @@ -1,14 +1,13 @@ -from django.test import TestCase +from django.test import TransactionTestCase from circuits.models import Provider, Circuit, CircuitType from extras.models import Change, Branch from netbox.staging import checkout -class StagingTestCase(TestCase): +class StagingTestCase(TransactionTestCase): - @classmethod - def setUpTestData(cls): + def setUp(self): providers = ( Provider(name='Provider A', slug='provider-a'), Provider(name='Provider B', slug='provider-b'),