Convert checkout() context manager to a class

This commit is contained in:
jeremystretch 2022-11-10 14:00:04 -05:00
parent c2e6853031
commit 5fb96a6b6f
3 changed files with 69 additions and 61 deletions

View File

@ -36,6 +36,9 @@ class Branch(ChangeLoggedModel):
class Meta: class Meta:
ordering = ('name',) ordering = ('name',)
def __str__(self):
return f'{self.name} ({self.pk})'
def merge(self): def merge(self):
logger.info(f'Merging changes in branch {self}') logger.info(f'Merging changes in branch {self}')
with transaction.atomic(): with transaction.atomic():
@ -87,15 +90,15 @@ class Change(ChangeLoggedModel):
if self.action == ChangeActionChoices.ACTION_CREATE: if self.action == ChangeActionChoices.ACTION_CREATE:
instance = deserialize_object(model, self.data, pk=pk) instance = deserialize_object(model, self.data, pk=pk)
logger.info(f'Creating {model} {instance}') logger.info(f'Creating {model._meta.verbose_name} {instance}')
instance.save() instance.save()
if self.action == ChangeActionChoices.ACTION_UPDATE: if self.action == ChangeActionChoices.ACTION_UPDATE:
instance = deserialize_object(model, self.data, pk=pk) instance = deserialize_object(model, self.data, pk=pk)
logger.info(f'Updating {model} {instance}') logger.info(f'Updating {model._meta.verbose_name} {instance}')
instance.save() instance.save()
if self.action == ChangeActionChoices.ACTION_DELETE: if self.action == ChangeActionChoices.ACTION_DELETE:
instance = model.objects.get(pk=self.object_id) instance = model.objects.get(pk=self.object_id)
logger.info(f'Deleting {model} {instance}') logger.info(f'Deleting {model._meta.verbose_name} {instance}')
instance.delete() instance.delete()

View File

@ -1,5 +1,4 @@
import logging import logging
from contextlib import contextmanager
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import transaction from django.db import transaction
@ -7,7 +6,6 @@ from django.db.models.signals import pre_delete, post_save
from extras.choices import ChangeActionChoices from extras.choices import ChangeActionChoices
from extras.models import Change from extras.models import Change
from utilities.exceptions import AbortTransaction
from utilities.utils import serialize_object, shallow_compare_dict from utilities.utils import serialize_object, shallow_compare_dict
logger = logging.getLogger('netbox.staging') logger = logging.getLogger('netbox.staging')
@ -28,44 +26,52 @@ def get_key_for_instance(instance):
return object_type, instance.pk return object_type, instance.pk
@contextmanager class checkout:
def checkout(branch):
queue = {} def __init__(self, branch):
self.branch = branch
self.queue = {}
def save_handler(sender, instance, **kwargs): def __enter__(self):
return post_save_handler(sender, instance, branch=branch, queue=queue, **kwargs)
def delete_handler(sender, instance, **kwargs): # Disable autocommit to effect a new transaction
return pre_delete_handler(sender, instance, branch=branch, queue=queue, **kwargs) logger.debug(f"Entering transaction for {self.branch}")
self._autocommit = transaction.get_autocommit()
transaction.set_autocommit(False)
# Apply any existing Changes assigned to this Branch
changes = self.branch.changes.all()
if changes.exists():
logger.debug(f"Applying {changes.count()} pre-staged changes...")
for change in changes:
change.apply()
else:
logger.debug("No pre-staged changes found")
# Connect signal handlers # Connect signal handlers
post_save.connect(save_handler) logger.debug("Connecting signal handlers")
pre_delete.connect(delete_handler) post_save.connect(self.post_save_handler)
pre_delete.connect(self.pre_delete_handler)
try: def __exit__(self, exc_type, exc_val, exc_tb):
with transaction.atomic():
yield
raise AbortTransaction()
# Roll back the transaction # Roll back the transaction to return the database to its original state
except AbortTransaction: logger.debug("Rolling back transaction")
pass transaction.rollback()
logger.debug(f"Restoring autocommit state {self._autocommit}")
finally: transaction.set_autocommit(self._autocommit)
# Disconnect signal handlers # Disconnect signal handlers
post_save.disconnect(save_handler) logger.debug("Disconnecting signal handlers")
pre_delete.disconnect(delete_handler) post_save.disconnect(self.post_save_handler)
pre_delete.disconnect(self.pre_delete_handler)
# Process queued changes # Process queued changes
logger.debug("Processing queued changes:")
for key, change in queue.items():
logger.debug(f' {key}: {change}')
# TODO: Optimize the creation of new Changes
changes = [] changes = []
for key, change in queue.items(): logger.debug(f"Processing {len(self.queue)} queued changes:")
for key, change in self.queue.items():
logger.debug(f' {key}: {change}')
object_type, pk = key object_type, pk = key
action, instance = change action, instance = change
if action in (ChangeActionChoices.ACTION_CREATE, ChangeActionChoices.ACTION_UPDATE): if action in (ChangeActionChoices.ACTION_CREATE, ChangeActionChoices.ACTION_UPDATE):
@ -74,7 +80,7 @@ def checkout(branch):
data = None data = None
change = Change( change = Change(
branch=branch, branch=self.branch,
action=action, action=action,
object_type=object_type, object_type=object_type,
object_id=pk, object_id=pk,
@ -84,30 +90,30 @@ def checkout(branch):
Change.objects.bulk_create(changes) Change.objects.bulk_create(changes)
def post_save_handler(self, sender, instance, created, **kwargs):
def post_save_handler(sender, instance, branch, queue, created, **kwargs):
key = get_key_for_instance(instance) key = get_key_for_instance(instance)
object_type = instance._meta.verbose_name
if created: if created:
# Creating a new object # Creating a new object
logger.debug(f"Staging creation of {instance} under branch {branch}") logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance}")
queue[key] = (ChangeActionChoices.ACTION_CREATE, instance) self.queue[key] = (ChangeActionChoices.ACTION_CREATE, instance)
elif key in queue: elif key in self.queue:
# Object has already been created/updated at least once # Object has already been created/updated at least once
logger.debug(f"Updating staged value for {instance} under branch {branch}") logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance}")
queue[key] = (queue[key][0], instance) self.queue[key] = (self.queue[key][0], instance)
else: else:
# Modifying an existing object # Modifying an existing object
logger.debug(f"Staging changes to {instance} (PK: {instance.pk}) under branch {branch}") logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance) self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance)
def pre_delete_handler(self, sender, instance, **kwargs):
def pre_delete_handler(sender, instance, branch, queue, **kwargs):
key = get_key_for_instance(instance) key = get_key_for_instance(instance)
if key in queue and queue[key][0] == 'create': if key in self.queue and self.queue[key][0] == 'create':
# Cancel the creation of a new object # Cancel the creation of a new object
logger.debug(f"Removing staged deletion of {instance} (PK: {instance.pk}) under branch {branch}") logger.debug(f"[{self.branch}] Removing staged deletion of {instance} (PK: {instance.pk})")
del queue[key] del self.queue[key]
else: else:
# Delete an existing object # Delete an existing object
logger.debug(f"Staging deletion of {instance} (PK: {instance.pk}) under branch {branch}") logger.debug(f"[{self.branch}] Staging deletion of {instance} (PK: {instance.pk})")
queue[key] = (ChangeActionChoices.ACTION_DELETE, instance) self.queue[key] = (ChangeActionChoices.ACTION_DELETE, instance)

View File

@ -1,14 +1,13 @@
from django.test import TestCase from django.test import TransactionTestCase
from circuits.models import Provider, Circuit, CircuitType from circuits.models import Provider, Circuit, CircuitType
from extras.models import Change, Branch from extras.models import Change, Branch
from netbox.staging import checkout from netbox.staging import checkout
class StagingTestCase(TestCase): class StagingTestCase(TransactionTestCase):
@classmethod def setUp(self):
def setUpTestData(cls):
providers = ( providers = (
Provider(name='Provider A', slug='provider-a'), Provider(name='Provider A', slug='provider-a'),
Provider(name='Provider B', slug='provider-b'), Provider(name='Provider B', slug='provider-b'),