Refactor counter logic to leverage the registry

This commit is contained in:
Jeremy Stretch 2023-07-24 16:03:54 -04:00
parent 073ad5329a
commit 848182c4ea
4 changed files with 100 additions and 93 deletions

View File

@ -1,4 +1,4 @@
from django.apps import AppConfig, apps
from django.apps import AppConfig
from netbox import denormalized
@ -10,8 +10,7 @@ class DCIMConfig(AppConfig):
def ready(self):
from . import signals, search
from .models import CableTermination, Device
from utilities.counter import connect_counters
from utilities.counters import connect_counters
# Register denormalized fields
denormalized.register(CableTermination, '_device', {
@ -27,4 +26,5 @@ class DCIMConfig(AppConfig):
'_site': 'site',
})
connect_counters([Device,])
# Register counters
connect_counters(Device)

View File

@ -1,87 +0,0 @@
from django.apps import apps
from django.db.models import F
from django.db.models.signals import post_delete, post_save
from functools import partial
from netbox.registry import registry
from .fields import CounterCacheField
def post_save_receiver_counter(counter_instance, sender, instance, created, **kwargs):
if created:
counter_instance.adjust_count(counter_instance.parent_id(instance), 1)
return
# not created so check if field has changed
field_name = f"{counter_instance.foreign_key_field.name}_id"
if field_name in instance.tracker:
new_value = getattr(instance, field_name, None)
old_value = instance.tracker.get(field_name)
if (new_value is not None) and (new_value != old_value):
counter_instance.adjust_count(new_value, 1)
counter_instance.adjust_count(old_value, -1)
def post_delete_receiver_counter(counter_instance, sender, instance, **kwargs):
counter_instance.adjust_count(counter_instance.parent_id(instance), -1)
class Counter:
"""
Used with CounterCacheField to add signals to track related model counts.
"""
counter_name = None
foreign_key_field = None
child_model = None
parent_model = None
def __init__(self, counter_name, foreign_key_field):
self.counter_name = counter_name
self.foreign_key_field = foreign_key_field.field
self.child_model = self.foreign_key_field.model
self.parent_model = self.foreign_key_field.related_model
# add the field to be tracked for changes in case of update
change_tracking_fields = registry['counter_fields'][self.child_model]
change_tracking_fields[f"{self.foreign_key_field.name}_id"] = counter_name
self.connect()
def connect(self):
"""
Hook up post_save, post_delete signal handlers to the fk field to change the count
"""
name = f"{self.parent_model._meta.model_name}.{self.child_model._meta.model_name}.{self.foreign_key_field.name}"
counted_name = f"{name}-{self.counter_name}"
post_save_receiver = partial(post_save_receiver_counter, counter_instance=self)
post_save.connect(
post_save_receiver, sender=self.child_model, weak=False, dispatch_uid=f'{counted_name}_post_save'
)
post_delete_receiver = partial(post_delete_receiver_counter, counter_instance=self)
post_delete.connect(
post_delete_receiver,
sender=self.child_model,
weak=False,
dispatch_uid=f'{counted_name}_post_delete',
)
def parent_id(self, child):
return getattr(child, self.foreign_key_field.attname)
def set_counter_field(self, parent_id, value):
return self.parent_model.objects.filter(pk=parent_id).update(**{self.counter_name: value})
def adjust_count(self, parent_id, amount):
return self.set_counter_field(parent_id, F(self.counter_name) + amount)
def connect_counters(models):
for model in models:
fields = model._meta.get_fields()
for field in fields:
if type(field) is CounterCacheField:
to_model = apps.get_model(field.to_model_name)
to_field = getattr(to_model, field.to_field_name)
Counter(field.name, to_field)

View File

@ -0,0 +1,93 @@
from django.apps import apps
from django.db.models import F
from django.db.models.signals import post_delete, post_save
from netbox.registry import registry
from .fields import CounterCacheField
def get_counters_for_model(model):
"""
Return field mappings for all counters registered to the given model.
"""
return registry['counter_fields'][model].items()
def update_counter(model, pk, counter_name, value):
"""
Increment or decrement a counter field on an object identified by its model and primary key (PK). Positive values
will increment; negative values will decrement.
"""
model.objects.filter(pk=pk).update(
**{counter_name: F(counter_name) + value}
)
#
# Signal handlers
#
def post_save_receiver(sender, instance, **kwargs):
"""
Update counter fields on related objects when a TrackingModelMixin subclass is created or modified.
"""
for field_name, counter_name in get_counters_for_model(sender):
parent_model = sender._meta.get_field(field_name).related_model
new_pk = getattr(instance, field_name, None)
old_pk = instance.tracker.get(field_name) if field_name in instance.tracker else None
# Update the counters on the old and/or new parents as needed
if old_pk is not None:
update_counter(parent_model, old_pk, counter_name, -1)
if new_pk is not None:
update_counter(parent_model, new_pk, counter_name, 1)
def post_delete_receiver(sender, instance, **kwargs):
"""
Update counter fields on related objects when a TrackingModelMixin subclass is deleted.
"""
for field_name, counter_name in get_counters_for_model(sender):
parent_model = sender._meta.get_field(field_name).related_model
parent_pk = getattr(instance, field_name, None)
# Decrement the parent's counter by one
if parent_pk is not None:
update_counter(parent_model, parent_pk, counter_name, -1)
#
# Registration
#
def connect_counters(*models):
"""
Register counter fields and connect post_save & post_delete signal handlers for the affected models.
"""
for model in models:
# Find all CounterCacheFields on the model
counter_fields = [
field for field in model._meta.get_fields() if type(field) is CounterCacheField
]
for field in counter_fields:
to_model = apps.get_model(field.to_model_name)
# Register the counter in the registry
change_tracking_fields = registry['counter_fields'][to_model]
change_tracking_fields[f"{field.to_field_name}_id"] = field.name
# Connect the post_save and post_delete handlers
post_save.connect(
post_save_receiver,
sender=to_model,
weak=False,
dispatch_uid=f'{model._meta.label}.{field.name}'
)
post_delete.connect(
post_delete_receiver,
sender=to_model,
weak=False,
dispatch_uid=f'{model._meta.label}.{field.name}'
)

View File

@ -7,6 +7,7 @@ class VirtualizationConfig(AppConfig):
def ready(self):
from . import search
from .models import VirtualMachine
from utilities.counter import connect_counters
from utilities.counters import connect_counters
connect_counters([VirtualMachine,])
# Register counters
connect_counters(VirtualMachine)