diff --git a/netbox/dcim/apps.py b/netbox/dcim/apps.py index 2323f55af..cc4c65f93 100644 --- a/netbox/dcim/apps.py +++ b/netbox/dcim/apps.py @@ -1,4 +1,4 @@ -from django.apps import AppConfig, apps +from django.apps import AppConfig from netbox import denormalized @@ -10,8 +10,7 @@ class DCIMConfig(AppConfig): def ready(self): from . import signals, search from .models import CableTermination, Device - - from utilities.counter import connect_counters + from utilities.counters import connect_counters # Register denormalized fields denormalized.register(CableTermination, '_device', { @@ -27,4 +26,5 @@ class DCIMConfig(AppConfig): '_site': 'site', }) - connect_counters([Device,]) + # Register counters + connect_counters(Device) diff --git a/netbox/utilities/counter.py b/netbox/utilities/counter.py deleted file mode 100644 index 47b47feec..000000000 --- a/netbox/utilities/counter.py +++ /dev/null @@ -1,87 +0,0 @@ -from django.apps import apps -from django.db.models import F -from django.db.models.signals import post_delete, post_save -from functools import partial - -from netbox.registry import registry -from .fields import CounterCacheField - - -def post_save_receiver_counter(counter_instance, sender, instance, created, **kwargs): - if created: - counter_instance.adjust_count(counter_instance.parent_id(instance), 1) - return - - # not created so check if field has changed - field_name = f"{counter_instance.foreign_key_field.name}_id" - if field_name in instance.tracker: - new_value = getattr(instance, field_name, None) - old_value = instance.tracker.get(field_name) - if (new_value is not None) and (new_value != old_value): - counter_instance.adjust_count(new_value, 1) - counter_instance.adjust_count(old_value, -1) - - -def post_delete_receiver_counter(counter_instance, sender, instance, **kwargs): - counter_instance.adjust_count(counter_instance.parent_id(instance), -1) - - -class Counter: - """ - Used with CounterCacheField to add signals to track related model counts. - """ - counter_name = None - foreign_key_field = None - child_model = None - parent_model = None - - def __init__(self, counter_name, foreign_key_field): - self.counter_name = counter_name - self.foreign_key_field = foreign_key_field.field - self.child_model = self.foreign_key_field.model - self.parent_model = self.foreign_key_field.related_model - - # add the field to be tracked for changes in case of update - change_tracking_fields = registry['counter_fields'][self.child_model] - change_tracking_fields[f"{self.foreign_key_field.name}_id"] = counter_name - - self.connect() - - def connect(self): - """ - Hook up post_save, post_delete signal handlers to the fk field to change the count - """ - name = f"{self.parent_model._meta.model_name}.{self.child_model._meta.model_name}.{self.foreign_key_field.name}" - counted_name = f"{name}-{self.counter_name}" - - post_save_receiver = partial(post_save_receiver_counter, counter_instance=self) - post_save.connect( - post_save_receiver, sender=self.child_model, weak=False, dispatch_uid=f'{counted_name}_post_save' - ) - - post_delete_receiver = partial(post_delete_receiver_counter, counter_instance=self) - post_delete.connect( - post_delete_receiver, - sender=self.child_model, - weak=False, - dispatch_uid=f'{counted_name}_post_delete', - ) - - def parent_id(self, child): - return getattr(child, self.foreign_key_field.attname) - - def set_counter_field(self, parent_id, value): - return self.parent_model.objects.filter(pk=parent_id).update(**{self.counter_name: value}) - - def adjust_count(self, parent_id, amount): - return self.set_counter_field(parent_id, F(self.counter_name) + amount) - - -def connect_counters(models): - for model in models: - fields = model._meta.get_fields() - for field in fields: - if type(field) is CounterCacheField: - to_model = apps.get_model(field.to_model_name) - to_field = getattr(to_model, field.to_field_name) - Counter(field.name, to_field) diff --git a/netbox/utilities/counters.py b/netbox/utilities/counters.py new file mode 100644 index 000000000..ee6865ca2 --- /dev/null +++ b/netbox/utilities/counters.py @@ -0,0 +1,93 @@ +from django.apps import apps +from django.db.models import F +from django.db.models.signals import post_delete, post_save + +from netbox.registry import registry +from .fields import CounterCacheField + + +def get_counters_for_model(model): + """ + Return field mappings for all counters registered to the given model. + """ + return registry['counter_fields'][model].items() + + +def update_counter(model, pk, counter_name, value): + """ + Increment or decrement a counter field on an object identified by its model and primary key (PK). Positive values + will increment; negative values will decrement. + """ + model.objects.filter(pk=pk).update( + **{counter_name: F(counter_name) + value} + ) + + +# +# Signal handlers +# + +def post_save_receiver(sender, instance, **kwargs): + """ + Update counter fields on related objects when a TrackingModelMixin subclass is created or modified. + """ + for field_name, counter_name in get_counters_for_model(sender): + parent_model = sender._meta.get_field(field_name).related_model + new_pk = getattr(instance, field_name, None) + old_pk = instance.tracker.get(field_name) if field_name in instance.tracker else None + + # Update the counters on the old and/or new parents as needed + if old_pk is not None: + update_counter(parent_model, old_pk, counter_name, -1) + if new_pk is not None: + update_counter(parent_model, new_pk, counter_name, 1) + + +def post_delete_receiver(sender, instance, **kwargs): + """ + Update counter fields on related objects when a TrackingModelMixin subclass is deleted. + """ + for field_name, counter_name in get_counters_for_model(sender): + parent_model = sender._meta.get_field(field_name).related_model + parent_pk = getattr(instance, field_name, None) + + # Decrement the parent's counter by one + if parent_pk is not None: + update_counter(parent_model, parent_pk, counter_name, -1) + + +# +# Registration +# + +def connect_counters(*models): + """ + Register counter fields and connect post_save & post_delete signal handlers for the affected models. + """ + for model in models: + + # Find all CounterCacheFields on the model + counter_fields = [ + field for field in model._meta.get_fields() if type(field) is CounterCacheField + ] + + for field in counter_fields: + to_model = apps.get_model(field.to_model_name) + + # Register the counter in the registry + change_tracking_fields = registry['counter_fields'][to_model] + change_tracking_fields[f"{field.to_field_name}_id"] = field.name + + # Connect the post_save and post_delete handlers + post_save.connect( + post_save_receiver, + sender=to_model, + weak=False, + dispatch_uid=f'{model._meta.label}.{field.name}' + ) + post_delete.connect( + post_delete_receiver, + sender=to_model, + weak=False, + dispatch_uid=f'{model._meta.label}.{field.name}' + ) diff --git a/netbox/virtualization/apps.py b/netbox/virtualization/apps.py index 6dae005c7..8db943ea1 100644 --- a/netbox/virtualization/apps.py +++ b/netbox/virtualization/apps.py @@ -7,6 +7,7 @@ class VirtualizationConfig(AppConfig): def ready(self): from . import search from .models import VirtualMachine - from utilities.counter import connect_counters + from utilities.counters import connect_counters - connect_counters([VirtualMachine,]) + # Register counters + connect_counters(VirtualMachine)