From 296e62579f1febb76cb93bf624663bc7962b7521 Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Mon, 24 Jul 2023 11:38:49 -0400 Subject: [PATCH] Enable dynamic resolution of counter field mappings for the management command --- netbox/netbox/registry.py | 2 +- netbox/utilities/counter.py | 2 +- .../commands/calculate_cached_counts.py | 72 +++++++++++-------- 3 files changed, 46 insertions(+), 30 deletions(-) diff --git a/netbox/netbox/registry.py b/netbox/netbox/registry.py index 736fca69a..21a869001 100644 --- a/netbox/netbox/registry.py +++ b/netbox/netbox/registry.py @@ -21,7 +21,7 @@ class Registry(dict): # Initialize the global registry registry = Registry({ - 'counter_fields': collections.defaultdict(set), + 'counter_fields': collections.defaultdict(dict), 'data_backends': dict(), 'denormalized_fields': collections.defaultdict(list), 'model_features': dict(), diff --git a/netbox/utilities/counter.py b/netbox/utilities/counter.py index 9f96a6c04..8cb555c71 100644 --- a/netbox/utilities/counter.py +++ b/netbox/utilities/counter.py @@ -43,7 +43,7 @@ class Counter: # add the field to be tracked for changes in case of update change_tracking_fields = registry['counter_fields'][self.child_model] - change_tracking_fields.add(f"{self.foreign_key_field.name}_id") + change_tracking_fields[f"{self.foreign_key_field.name}_id"] = counter_name self.connect() diff --git a/netbox/utilities/management/commands/calculate_cached_counts.py b/netbox/utilities/management/commands/calculate_cached_counts.py index d6f3bcbf2..62354797c 100644 --- a/netbox/utilities/management/commands/calculate_cached_counts.py +++ b/netbox/utilities/management/commands/calculate_cached_counts.py @@ -1,36 +1,52 @@ +from collections import defaultdict + from django.core.management.base import BaseCommand -from django.core.management.color import no_style +from django.db.models import Count, OuterRef, Subquery -from dcim.models import Device -from virtualization.models import VirtualMachine - - -def recalculate_device_counts(): - for device in Device.objects.all(): - device._console_port_count = device.consoleports.count() - device._console_server_port_count = device.consoleserverports.count() - device._interface_count = device.interfaces.count() - device._front_port_count = device.frontports.count() - device._rear_port_count = device.rearports.count() - device._device_bay_count = device.devicebays.count() - device._inventory_item_count = device.inventoryitems.count() - device._power_port_count = device.powerports.count() - device._power_outlet_count = device.poweroutlets.count() - device.save() - - -def recalculate_virtual_machine_counts(): - for vm in VirtualMachine.objects.all(): - vm._interface_count = vm.interfaces.count() - vm.save() +from netbox.registry import registry class Command(BaseCommand): - help = "Recalculate cached counts" + help = "Force a recalculation of all cached counter fields" + + @staticmethod + def collect_models(): + """ + Query the registry to find all models which have one or more counter fields. Return a mapping of counter fields + to related query names for each model. + """ + models = defaultdict(dict) + + for model, field_mappings in registry['counter_fields'].items(): + for field_name, counter_name in field_mappings.items(): + fk_field = model._meta.get_field(field_name) # Interface.device + parent_model = fk_field.related_model # Device + related_query_name = fk_field.related_query_name() # 'interfaces' + models[parent_model][counter_name] = related_query_name + + return models + + def update_counts(self, model, field_name, related_query): + """ + Perform a bulk update for the given model and counter field. For example, + + update_counts(Device, '_interface_count', 'interfaces') + + will effectively set + + Device.objects.update(_interface_count=Count('interfaces')) + """ + self.stdout.write(f'Updating {model.__name__} {field_name}...') + subquery = Subquery( + model.objects.filter(pk=OuterRef('pk')).annotate(_count=Count(related_query)).values('_count') + ) + return model.objects.update(**{ + field_name: subquery + }) def handle(self, *model_names, **options): - self.stdout.write('Recalculating device counts...') - recalculate_device_counts() - self.stdout.write('Recalculating virtual machine counts...') - recalculate_virtual_machine_counts() + for model, mappings in self.collect_models().items(): + for field_name, related_query in mappings.items(): + self.update_counts(model, field_name, related_query) + self.stdout.write(self.style.SUCCESS('Finished.'))