Convert checkout() context manager to a class

This commit is contained in:
jeremystretch 2022-11-10 14:00:04 -05:00
parent c2e6853031
commit 5fb96a6b6f
3 changed files with 69 additions and 61 deletions

View File

@ -36,6 +36,9 @@ class Branch(ChangeLoggedModel):
class Meta: class Meta:
ordering = ('name',) ordering = ('name',)
def __str__(self):
return f'{self.name} ({self.pk})'
def merge(self): def merge(self):
logger.info(f'Merging changes in branch {self}') logger.info(f'Merging changes in branch {self}')
with transaction.atomic(): with transaction.atomic():
@ -87,15 +90,15 @@ class Change(ChangeLoggedModel):
if self.action == ChangeActionChoices.ACTION_CREATE: if self.action == ChangeActionChoices.ACTION_CREATE:
instance = deserialize_object(model, self.data, pk=pk) instance = deserialize_object(model, self.data, pk=pk)
logger.info(f'Creating {model} {instance}') logger.info(f'Creating {model._meta.verbose_name} {instance}')
instance.save() instance.save()
if self.action == ChangeActionChoices.ACTION_UPDATE: if self.action == ChangeActionChoices.ACTION_UPDATE:
instance = deserialize_object(model, self.data, pk=pk) instance = deserialize_object(model, self.data, pk=pk)
logger.info(f'Updating {model} {instance}') logger.info(f'Updating {model._meta.verbose_name} {instance}')
instance.save() instance.save()
if self.action == ChangeActionChoices.ACTION_DELETE: if self.action == ChangeActionChoices.ACTION_DELETE:
instance = model.objects.get(pk=self.object_id) instance = model.objects.get(pk=self.object_id)
logger.info(f'Deleting {model} {instance}') logger.info(f'Deleting {model._meta.verbose_name} {instance}')
instance.delete() instance.delete()

View File

@ -1,5 +1,4 @@
import logging import logging
from contextlib import contextmanager
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import transaction from django.db import transaction
@ -7,7 +6,6 @@ from django.db.models.signals import pre_delete, post_save
from extras.choices import ChangeActionChoices from extras.choices import ChangeActionChoices
from extras.models import Change from extras.models import Change
from utilities.exceptions import AbortTransaction
from utilities.utils import serialize_object, shallow_compare_dict from utilities.utils import serialize_object, shallow_compare_dict
logger = logging.getLogger('netbox.staging') logger = logging.getLogger('netbox.staging')
@ -28,44 +26,52 @@ def get_key_for_instance(instance):
return object_type, instance.pk return object_type, instance.pk
@contextmanager class checkout:
def checkout(branch):
queue = {} def __init__(self, branch):
self.branch = branch
self.queue = {}
def save_handler(sender, instance, **kwargs): def __enter__(self):
return post_save_handler(sender, instance, branch=branch, queue=queue, **kwargs)
def delete_handler(sender, instance, **kwargs): # Disable autocommit to effect a new transaction
return pre_delete_handler(sender, instance, branch=branch, queue=queue, **kwargs) logger.debug(f"Entering transaction for {self.branch}")
self._autocommit = transaction.get_autocommit()
# Connect signal handlers transaction.set_autocommit(False)
post_save.connect(save_handler)
pre_delete.connect(delete_handler)
try: # Apply any existing Changes assigned to this Branch
with transaction.atomic(): changes = self.branch.changes.all()
yield if changes.exists():
raise AbortTransaction() logger.debug(f"Applying {changes.count()} pre-staged changes...")
for change in changes:
change.apply()
else:
logger.debug("No pre-staged changes found")
# Roll back the transaction # Connect signal handlers
except AbortTransaction: logger.debug("Connecting signal handlers")
pass post_save.connect(self.post_save_handler)
pre_delete.connect(self.pre_delete_handler)
finally: def __exit__(self, exc_type, exc_val, exc_tb):
# Roll back the transaction to return the database to its original state
logger.debug("Rolling back transaction")
transaction.rollback()
logger.debug(f"Restoring autocommit state {self._autocommit}")
transaction.set_autocommit(self._autocommit)
# Disconnect signal handlers # Disconnect signal handlers
post_save.disconnect(save_handler) logger.debug("Disconnecting signal handlers")
pre_delete.disconnect(delete_handler) post_save.disconnect(self.post_save_handler)
pre_delete.disconnect(self.pre_delete_handler)
# Process queued changes # Process queued changes
logger.debug("Processing queued changes:")
for key, change in queue.items():
logger.debug(f' {key}: {change}')
# TODO: Optimize the creation of new Changes
changes = [] changes = []
for key, change in queue.items(): logger.debug(f"Processing {len(self.queue)} queued changes:")
for key, change in self.queue.items():
logger.debug(f' {key}: {change}')
object_type, pk = key object_type, pk = key
action, instance = change action, instance = change
if action in (ChangeActionChoices.ACTION_CREATE, ChangeActionChoices.ACTION_UPDATE): if action in (ChangeActionChoices.ACTION_CREATE, ChangeActionChoices.ACTION_UPDATE):
@ -74,7 +80,7 @@ def checkout(branch):
data = None data = None
change = Change( change = Change(
branch=branch, branch=self.branch,
action=action, action=action,
object_type=object_type, object_type=object_type,
object_id=pk, object_id=pk,
@ -84,30 +90,30 @@ def checkout(branch):
Change.objects.bulk_create(changes) Change.objects.bulk_create(changes)
def post_save_handler(self, sender, instance, created, **kwargs):
key = get_key_for_instance(instance)
object_type = instance._meta.verbose_name
def post_save_handler(sender, instance, branch, queue, created, **kwargs): if created:
key = get_key_for_instance(instance) # Creating a new object
if created: logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance}")
# Creating a new object self.queue[key] = (ChangeActionChoices.ACTION_CREATE, instance)
logger.debug(f"Staging creation of {instance} under branch {branch}") elif key in self.queue:
queue[key] = (ChangeActionChoices.ACTION_CREATE, instance) # Object has already been created/updated at least once
elif key in queue: logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance}")
# Object has already been created/updated at least once self.queue[key] = (self.queue[key][0], instance)
logger.debug(f"Updating staged value for {instance} under branch {branch}") else:
queue[key] = (queue[key][0], instance) # Modifying an existing object
else: logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
# Modifying an existing object self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance)
logger.debug(f"Staging changes to {instance} (PK: {instance.pk}) under branch {branch}")
queue[key] = (ChangeActionChoices.ACTION_UPDATE, instance)
def pre_delete_handler(self, sender, instance, **kwargs):
def pre_delete_handler(sender, instance, branch, queue, **kwargs): key = get_key_for_instance(instance)
key = get_key_for_instance(instance) if key in self.queue and self.queue[key][0] == 'create':
if key in queue and queue[key][0] == 'create': # Cancel the creation of a new object
# Cancel the creation of a new object logger.debug(f"[{self.branch}] Removing staged deletion of {instance} (PK: {instance.pk})")
logger.debug(f"Removing staged deletion of {instance} (PK: {instance.pk}) under branch {branch}") del self.queue[key]
del queue[key] else:
else: # Delete an existing object
# Delete an existing object logger.debug(f"[{self.branch}] Staging deletion of {instance} (PK: {instance.pk})")
logger.debug(f"Staging deletion of {instance} (PK: {instance.pk}) under branch {branch}") self.queue[key] = (ChangeActionChoices.ACTION_DELETE, instance)
queue[key] = (ChangeActionChoices.ACTION_DELETE, instance)

View File

@ -1,14 +1,13 @@
from django.test import TestCase from django.test import TransactionTestCase
from circuits.models import Provider, Circuit, CircuitType from circuits.models import Provider, Circuit, CircuitType
from extras.models import Change, Branch from extras.models import Change, Branch
from netbox.staging import checkout from netbox.staging import checkout
class StagingTestCase(TestCase): class StagingTestCase(TransactionTestCase):
@classmethod def setUp(self):
def setUpTestData(cls):
providers = ( providers = (
Provider(name='Provider A', slug='provider-a'), Provider(name='Provider A', slug='provider-a'),
Provider(name='Provider B', slug='provider-b'), Provider(name='Provider B', slug='provider-b'),