6347 Cache the number of each component type assigned to devices/VMs (#12632)

---------

Co-authored-by: Jeremy Stretch <jstretch@netboxlabs.com>
This commit is contained in:
Arthur Hanson 2023-07-25 20:39:05 +07:00 committed by GitHub
parent a4acb50edd
commit 149a496011
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 623 additions and 35 deletions

View File

@ -8,6 +8,10 @@ The registry can be inspected by importing `registry` from `extras.registry`.
## Stores ## Stores
### `counter_fields`
A dictionary mapping of models to foreign keys with which cached counter fields are associated.
### `data_backends` ### `data_backends`
A dictionary mapping data backend types to their respective classes. These are used to interact with [remote data sources](../models/core/datasource.md). A dictionary mapping data backend types to their respective classes. These are used to interact with [remote data sources](../models/core/datasource.md).

View File

@ -669,14 +669,28 @@ class DeviceSerializer(NetBoxModelSerializer):
vc_position = serializers.IntegerField(allow_null=True, max_value=255, min_value=0, default=None) vc_position = serializers.IntegerField(allow_null=True, max_value=255, min_value=0, default=None)
config_template = NestedConfigTemplateSerializer(required=False, allow_null=True, default=None) config_template = NestedConfigTemplateSerializer(required=False, allow_null=True, default=None)
# Counter fields
console_port_count = serializers.IntegerField(read_only=True)
console_server_port_count = serializers.IntegerField(read_only=True)
power_port_count = serializers.IntegerField(read_only=True)
power_outlet_count = serializers.IntegerField(read_only=True)
interface_count = serializers.IntegerField(read_only=True)
front_port_count = serializers.IntegerField(read_only=True)
rear_port_count = serializers.IntegerField(read_only=True)
device_bay_count = serializers.IntegerField(read_only=True)
module_bay_count = serializers.IntegerField(read_only=True)
inventory_item_count = serializers.IntegerField(read_only=True)
class Meta: class Meta:
model = Device model = Device
fields = [ fields = [
'id', 'url', 'display', 'name', 'device_type', 'device_role', 'tenant', 'platform', 'serial', 'asset_tag', 'id', 'url', 'display', 'name', 'device_type', 'device_role', 'tenant', 'platform', 'serial', 'asset_tag',
'site', 'location', 'rack', 'position', 'face', 'latitude', 'longitude', 'parent_device', 'status', 'airflow', 'site', 'location', 'rack', 'position', 'face', 'latitude', 'longitude', 'parent_device', 'status',
'primary_ip', 'primary_ip4', 'primary_ip6', 'cluster', 'virtual_chassis', 'vc_position', 'vc_priority', 'airflow', 'primary_ip', 'primary_ip4', 'primary_ip6', 'cluster', 'virtual_chassis', 'vc_position',
'description', 'comments', 'config_template', 'local_context_data', 'tags', 'custom_fields', 'created', 'vc_priority', 'description', 'comments', 'config_template', 'local_context_data', 'tags', 'custom_fields',
'last_updated', 'created', 'last_updated', 'console_port_count', 'console_server_port_count', 'power_port_count',
'power_outlet_count', 'interface_count', 'front_port_count', 'rear_port_count', 'device_bay_count',
'module_bay_count', 'inventory_item_count',
] ]
@extend_schema_field(NestedDeviceSerializer) @extend_schema_field(NestedDeviceSerializer)
@ -700,7 +714,9 @@ class DeviceWithConfigContextSerializer(DeviceSerializer):
'site', 'location', 'rack', 'position', 'face', 'parent_device', 'status', 'airflow', 'primary_ip', 'site', 'location', 'rack', 'position', 'face', 'parent_device', 'status', 'airflow', 'primary_ip',
'primary_ip4', 'primary_ip6', 'cluster', 'virtual_chassis', 'vc_position', 'vc_priority', 'description', 'primary_ip4', 'primary_ip6', 'cluster', 'virtual_chassis', 'vc_position', 'vc_priority', 'description',
'comments', 'local_context_data', 'tags', 'custom_fields', 'config_context', 'config_template', 'comments', 'local_context_data', 'tags', 'custom_fields', 'config_context', 'config_template',
'created', 'last_updated', 'created', 'last_updated', 'console_port_count', 'console_server_port_count', 'power_port_count',
'power_outlet_count', 'interface_count', 'front_port_count', 'rear_port_count', 'device_bay_count',
'module_bay_count', 'inventory_item_count',
] ]
@extend_schema_field(serializers.JSONField(allow_null=True)) @extend_schema_field(serializers.JSONField(allow_null=True))

View File

@ -9,7 +9,8 @@ class DCIMConfig(AppConfig):
def ready(self): def ready(self):
from . import signals, search from . import signals, search
from .models import CableTermination from .models import CableTermination, Device
from utilities.counters import connect_counters
# Register denormalized fields # Register denormalized fields
denormalized.register(CableTermination, '_device', { denormalized.register(CableTermination, '_device', {
@ -24,3 +25,6 @@ class DCIMConfig(AppConfig):
denormalized.register(CableTermination, '_location', { denormalized.register(CableTermination, '_location', {
'_site': 'site', '_site': 'site',
}) })
# Register counters
connect_counters(Device)

View File

@ -0,0 +1,100 @@
from django.db import migrations
from django.db.models import Count
import utilities.fields
def recalculate_device_counts(apps, schema_editor):
Device = apps.get_model("dcim", "Device")
devices = list(Device.objects.all().annotate(
_console_port_count=Count('consoleports', distinct=True),
_console_server_port_count=Count('consoleserverports', distinct=True),
_power_port_count=Count('powerports', distinct=True),
_power_outlet_count=Count('poweroutlets', distinct=True),
_interface_count=Count('interfaces', distinct=True),
_front_port_count=Count('frontports', distinct=True),
_rear_port_count=Count('rearports', distinct=True),
_device_bay_count=Count('devicebays', distinct=True),
_module_bay_count=Count('modulebays', distinct=True),
_inventory_item_count=Count('inventoryitems', distinct=True),
))
for device in devices:
device.console_port_count = device._console_port_count
device.console_server_port_count = device._console_server_port_count
device.power_port_count = device._power_port_count
device.power_outlet_count = device._power_outlet_count
device.interface_count = device._interface_count
device.front_port_count = device._front_port_count
device.rear_port_count = device._rear_port_count
device.device_bay_count = device._device_bay_count
device.module_bay_count = device._module_bay_count
device.inventory_item_count = device._inventory_item_count
Device.objects.bulk_update(devices, [
'console_port_count', 'console_server_port_count', 'power_port_count', 'power_outlet_count', 'interface_count',
'front_port_count', 'rear_port_count', 'device_bay_count', 'module_bay_count', 'inventory_item_count',
])
class Migration(migrations.Migration):
dependencies = [
('dcim', '0174_rack_starting_unit'),
]
operations = [
migrations.AddField(
model_name='device',
name='console_port_count',
field=utilities.fields.CounterCacheField(default=0, to_field='device', to_model='dcim.ConsolePort'),
),
migrations.AddField(
model_name='device',
name='console_server_port_count',
field=utilities.fields.CounterCacheField(default=0, to_field='device', to_model='dcim.ConsoleServerPort'),
),
migrations.AddField(
model_name='device',
name='power_port_count',
field=utilities.fields.CounterCacheField(default=0, to_field='device', to_model='dcim.PowerPort'),
),
migrations.AddField(
model_name='device',
name='power_outlet_count',
field=utilities.fields.CounterCacheField(default=0, to_field='device', to_model='dcim.PowerOutlet'),
),
migrations.AddField(
model_name='device',
name='interface_count',
field=utilities.fields.CounterCacheField(default=0, to_field='device', to_model='dcim.Interface'),
),
migrations.AddField(
model_name='device',
name='front_port_count',
field=utilities.fields.CounterCacheField(default=0, to_field='device', to_model='dcim.FrontPort'),
),
migrations.AddField(
model_name='device',
name='rear_port_count',
field=utilities.fields.CounterCacheField(default=0, to_field='device', to_model='dcim.RearPort'),
),
migrations.AddField(
model_name='device',
name='device_bay_count',
field=utilities.fields.CounterCacheField(default=0, to_field='device', to_model='dcim.DeviceBay'),
),
migrations.AddField(
model_name='device',
name='module_bay_count',
field=utilities.fields.CounterCacheField(default=0, to_field='device', to_model='dcim.ModuleBay'),
),
migrations.AddField(
model_name='device',
name='inventory_item_count',
field=utilities.fields.CounterCacheField(default=0, to_field='device', to_model='dcim.InventoryItem'),
),
migrations.RunPython(
recalculate_device_counts,
reverse_code=migrations.RunPython.noop
),
]

View File

@ -19,6 +19,7 @@ from utilities.fields import ColorField, NaturalOrderingField
from utilities.mptt import TreeManager from utilities.mptt import TreeManager
from utilities.ordering import naturalize_interface from utilities.ordering import naturalize_interface
from utilities.query_functions import CollateAsChar from utilities.query_functions import CollateAsChar
from utilities.tracking import TrackingModelMixin
from wireless.choices import * from wireless.choices import *
from wireless.utils import get_channel_attr from wireless.utils import get_channel_attr
@ -269,7 +270,7 @@ class PathEndpoint(models.Model):
# Console components # Console components
# #
class ConsolePort(ModularComponentModel, CabledObjectModel, PathEndpoint): class ConsolePort(ModularComponentModel, CabledObjectModel, PathEndpoint, TrackingModelMixin):
""" """
A physical console port within a Device. ConsolePorts connect to ConsoleServerPorts. A physical console port within a Device. ConsolePorts connect to ConsoleServerPorts.
""" """
@ -292,7 +293,7 @@ class ConsolePort(ModularComponentModel, CabledObjectModel, PathEndpoint):
return reverse('dcim:consoleport', kwargs={'pk': self.pk}) return reverse('dcim:consoleport', kwargs={'pk': self.pk})
class ConsoleServerPort(ModularComponentModel, CabledObjectModel, PathEndpoint): class ConsoleServerPort(ModularComponentModel, CabledObjectModel, PathEndpoint, TrackingModelMixin):
""" """
A physical port within a Device (typically a designated console server) which provides access to ConsolePorts. A physical port within a Device (typically a designated console server) which provides access to ConsolePorts.
""" """
@ -319,7 +320,7 @@ class ConsoleServerPort(ModularComponentModel, CabledObjectModel, PathEndpoint):
# Power components # Power components
# #
class PowerPort(ModularComponentModel, CabledObjectModel, PathEndpoint): class PowerPort(ModularComponentModel, CabledObjectModel, PathEndpoint, TrackingModelMixin):
""" """
A physical power supply (intake) port within a Device. PowerPorts connect to PowerOutlets. A physical power supply (intake) port within a Device. PowerPorts connect to PowerOutlets.
""" """
@ -428,7 +429,7 @@ class PowerPort(ModularComponentModel, CabledObjectModel, PathEndpoint):
} }
class PowerOutlet(ModularComponentModel, CabledObjectModel, PathEndpoint): class PowerOutlet(ModularComponentModel, CabledObjectModel, PathEndpoint, TrackingModelMixin):
""" """
A physical power outlet (output) within a Device which provides power to a PowerPort. A physical power outlet (output) within a Device which provides power to a PowerPort.
""" """
@ -537,7 +538,7 @@ class BaseInterface(models.Model):
return self.fhrp_group_assignments.count() return self.fhrp_group_assignments.count()
class Interface(ModularComponentModel, BaseInterface, CabledObjectModel, PathEndpoint): class Interface(ModularComponentModel, BaseInterface, CabledObjectModel, PathEndpoint, TrackingModelMixin):
""" """
A network interface within a Device. A physical Interface can connect to exactly one other Interface. A network interface within a Device. A physical Interface can connect to exactly one other Interface.
""" """
@ -888,7 +889,7 @@ class Interface(ModularComponentModel, BaseInterface, CabledObjectModel, PathEnd
# Pass-through ports # Pass-through ports
# #
class FrontPort(ModularComponentModel, CabledObjectModel): class FrontPort(ModularComponentModel, CabledObjectModel, TrackingModelMixin):
""" """
A pass-through port on the front of a Device. A pass-through port on the front of a Device.
""" """
@ -949,7 +950,7 @@ class FrontPort(ModularComponentModel, CabledObjectModel):
}) })
class RearPort(ModularComponentModel, CabledObjectModel): class RearPort(ModularComponentModel, CabledObjectModel, TrackingModelMixin):
""" """
A pass-through port on the rear of a Device. A pass-through port on the rear of a Device.
""" """
@ -990,7 +991,7 @@ class RearPort(ModularComponentModel, CabledObjectModel):
# Bays # Bays
# #
class ModuleBay(ComponentModel): class ModuleBay(ComponentModel, TrackingModelMixin):
""" """
An empty space within a Device which can house a child device An empty space within a Device which can house a child device
""" """
@ -1006,7 +1007,7 @@ class ModuleBay(ComponentModel):
return reverse('dcim:modulebay', kwargs={'pk': self.pk}) return reverse('dcim:modulebay', kwargs={'pk': self.pk})
class DeviceBay(ComponentModel): class DeviceBay(ComponentModel, TrackingModelMixin):
""" """
An empty space within a Device which can house a child device An empty space within a Device which can house a child device
""" """
@ -1064,7 +1065,7 @@ class InventoryItemRole(OrganizationalModel):
return reverse('dcim:inventoryitemrole', args=[self.pk]) return reverse('dcim:inventoryitemrole', args=[self.pk])
class InventoryItem(MPTTModel, ComponentModel): class InventoryItem(MPTTModel, ComponentModel, TrackingModelMixin):
""" """
An InventoryItem represents a serialized piece of hardware within a Device, such as a line card or power supply. An InventoryItem represents a serialized piece of hardware within a Device, such as a line card or power supply.
InventoryItems are used only for inventory purposes. InventoryItems are used only for inventory purposes.

View File

@ -21,7 +21,7 @@ from extras.querysets import ConfigContextModelQuerySet
from netbox.config import ConfigItem from netbox.config import ConfigItem
from netbox.models import OrganizationalModel, PrimaryModel from netbox.models import OrganizationalModel, PrimaryModel
from utilities.choices import ColorChoices from utilities.choices import ColorChoices
from utilities.fields import ColorField, NaturalOrderingField from utilities.fields import ColorField, CounterCacheField, NaturalOrderingField
from .device_components import * from .device_components import *
from .mixins import WeightMixin from .mixins import WeightMixin
@ -639,6 +639,48 @@ class Device(PrimaryModel, ConfigContextModel):
help_text=_("GPS coordinate in decimal format (xx.yyyyyy)") help_text=_("GPS coordinate in decimal format (xx.yyyyyy)")
) )
# Counter fields
console_port_count = CounterCacheField(
to_model='dcim.ConsolePort',
to_field='device'
)
console_server_port_count = CounterCacheField(
to_model='dcim.ConsoleServerPort',
to_field='device'
)
power_port_count = CounterCacheField(
to_model='dcim.PowerPort',
to_field='device'
)
power_outlet_count = CounterCacheField(
to_model='dcim.PowerOutlet',
to_field='device'
)
interface_count = CounterCacheField(
to_model='dcim.Interface',
to_field='device'
)
front_port_count = CounterCacheField(
to_model='dcim.FrontPort',
to_field='device'
)
rear_port_count = CounterCacheField(
to_model='dcim.RearPort',
to_field='device'
)
device_bay_count = CounterCacheField(
to_model='dcim.DeviceBay',
to_field='device'
)
module_bay_count = CounterCacheField(
to_model='dcim.ModuleBay',
to_field='device'
)
inventory_item_count = CounterCacheField(
to_model='dcim.InventoryItem',
to_field='device'
)
# Generic relations # Generic relations
contacts = GenericRelation( contacts = GenericRelation(
to='tenancy.ContactAssignment' to='tenancy.ContactAssignment'

View File

@ -1,10 +1,10 @@
import django_tables2 as tables import django_tables2 as tables
from dcim import models
from django_tables2.utils import Accessor from django_tables2.utils import Accessor
from tenancy.tables import ContactsColumnMixin, TenancyColumnsMixin from django.utils.translation import gettext as _
from dcim import models
from netbox.tables import NetBoxTable, columns from netbox.tables import NetBoxTable, columns
from tenancy.tables import ContactsColumnMixin, TenancyColumnsMixin
from .template_code import * from .template_code import *
__all__ = ( __all__ = (
@ -230,6 +230,36 @@ class DeviceTable(TenancyColumnsMixin, ContactsColumnMixin, NetBoxTable):
tags = columns.TagColumn( tags = columns.TagColumn(
url_name='dcim:device_list' url_name='dcim:device_list'
) )
console_port_count = tables.Column(
verbose_name=_('Console ports')
)
console_server_port_count = tables.Column(
verbose_name=_('Console server ports')
)
power_port_count = tables.Column(
verbose_name=_('Power ports')
)
power_outlet_count = tables.Column(
verbose_name=_('Power outlets')
)
interface_count = tables.Column(
verbose_name=_('Interfaces')
)
front_port_count = tables.Column(
verbose_name=_('Front ports')
)
rear_port_count = tables.Column(
verbose_name=_('Rear ports')
)
device_bay_count = tables.Column(
verbose_name=_('Device bays')
)
module_bay_count = tables.Column(
verbose_name=_('Module bays')
)
inventory_item_count = tables.Column(
verbose_name=_('Inventory items')
)
class Meta(NetBoxTable.Meta): class Meta(NetBoxTable.Meta):
model = models.Device model = models.Device

View File

@ -1876,7 +1876,7 @@ class DeviceConsolePortsView(DeviceComponentsView):
template_name = 'dcim/device/consoleports.html', template_name = 'dcim/device/consoleports.html',
tab = ViewTab( tab = ViewTab(
label=_('Console Ports'), label=_('Console Ports'),
badge=lambda obj: obj.consoleports.count(), badge=lambda obj: obj.console_port_count,
permission='dcim.view_consoleport', permission='dcim.view_consoleport',
weight=550, weight=550,
hide_if_empty=True hide_if_empty=True
@ -1891,7 +1891,7 @@ class DeviceConsoleServerPortsView(DeviceComponentsView):
template_name = 'dcim/device/consoleserverports.html' template_name = 'dcim/device/consoleserverports.html'
tab = ViewTab( tab = ViewTab(
label=_('Console Server Ports'), label=_('Console Server Ports'),
badge=lambda obj: obj.consoleserverports.count(), badge=lambda obj: obj.console_server_port_count,
permission='dcim.view_consoleserverport', permission='dcim.view_consoleserverport',
weight=560, weight=560,
hide_if_empty=True hide_if_empty=True
@ -1906,7 +1906,7 @@ class DevicePowerPortsView(DeviceComponentsView):
template_name = 'dcim/device/powerports.html' template_name = 'dcim/device/powerports.html'
tab = ViewTab( tab = ViewTab(
label=_('Power Ports'), label=_('Power Ports'),
badge=lambda obj: obj.powerports.count(), badge=lambda obj: obj.power_port_count,
permission='dcim.view_powerport', permission='dcim.view_powerport',
weight=570, weight=570,
hide_if_empty=True hide_if_empty=True
@ -1921,7 +1921,7 @@ class DevicePowerOutletsView(DeviceComponentsView):
template_name = 'dcim/device/poweroutlets.html' template_name = 'dcim/device/poweroutlets.html'
tab = ViewTab( tab = ViewTab(
label=_('Power Outlets'), label=_('Power Outlets'),
badge=lambda obj: obj.poweroutlets.count(), badge=lambda obj: obj.power_outlet_count,
permission='dcim.view_poweroutlet', permission='dcim.view_poweroutlet',
weight=580, weight=580,
hide_if_empty=True hide_if_empty=True
@ -1957,7 +1957,7 @@ class DeviceFrontPortsView(DeviceComponentsView):
template_name = 'dcim/device/frontports.html' template_name = 'dcim/device/frontports.html'
tab = ViewTab( tab = ViewTab(
label=_('Front Ports'), label=_('Front Ports'),
badge=lambda obj: obj.frontports.count(), badge=lambda obj: obj.front_port_count,
permission='dcim.view_frontport', permission='dcim.view_frontport',
weight=530, weight=530,
hide_if_empty=True hide_if_empty=True
@ -1972,7 +1972,7 @@ class DeviceRearPortsView(DeviceComponentsView):
template_name = 'dcim/device/rearports.html' template_name = 'dcim/device/rearports.html'
tab = ViewTab( tab = ViewTab(
label=_('Rear Ports'), label=_('Rear Ports'),
badge=lambda obj: obj.rearports.count(), badge=lambda obj: obj.rear_port_count,
permission='dcim.view_rearport', permission='dcim.view_rearport',
weight=540, weight=540,
hide_if_empty=True hide_if_empty=True
@ -1987,7 +1987,7 @@ class DeviceModuleBaysView(DeviceComponentsView):
template_name = 'dcim/device/modulebays.html' template_name = 'dcim/device/modulebays.html'
tab = ViewTab( tab = ViewTab(
label=_('Module Bays'), label=_('Module Bays'),
badge=lambda obj: obj.modulebays.count(), badge=lambda obj: obj.module_bay_count,
permission='dcim.view_modulebay', permission='dcim.view_modulebay',
weight=510, weight=510,
hide_if_empty=True hide_if_empty=True
@ -2002,7 +2002,7 @@ class DeviceDeviceBaysView(DeviceComponentsView):
template_name = 'dcim/device/devicebays.html' template_name = 'dcim/device/devicebays.html'
tab = ViewTab( tab = ViewTab(
label=_('Device Bays'), label=_('Device Bays'),
badge=lambda obj: obj.devicebays.count(), badge=lambda obj: obj.device_bay_count,
permission='dcim.view_devicebay', permission='dcim.view_devicebay',
weight=500, weight=500,
hide_if_empty=True hide_if_empty=True
@ -2017,7 +2017,7 @@ class DeviceInventoryView(DeviceComponentsView):
template_name = 'dcim/device/inventory.html' template_name = 'dcim/device/inventory.html'
tab = ViewTab( tab = ViewTab(
label=_('Inventory Items'), label=_('Inventory Items'),
badge=lambda obj: obj.inventoryitems.count(), badge=lambda obj: obj.inventory_item_count,
permission='dcim.view_inventoryitem', permission='dcim.view_inventoryitem',
weight=590, weight=590,
hide_if_empty=True hide_if_empty=True

View File

@ -8,6 +8,7 @@ from netbox.models.features import *
from utilities.mptt import TreeManager from utilities.mptt import TreeManager
from utilities.querysets import RestrictedQuerySet from utilities.querysets import RestrictedQuerySet
__all__ = ( __all__ = (
'ChangeLoggedModel', 'ChangeLoggedModel',
'NestedGroupModel', 'NestedGroupModel',

View File

@ -21,6 +21,7 @@ class Registry(dict):
# Initialize the global registry # Initialize the global registry
registry = Registry({ registry = Registry({
'counter_fields': collections.defaultdict(dict),
'data_backends': dict(), 'data_backends': dict(),
'denormalized_fields': collections.defaultdict(list), 'denormalized_fields': collections.defaultdict(list),
'model_features': dict(), 'model_features': dict(),

View File

@ -0,0 +1,93 @@
from django.apps import apps
from django.db.models import F
from django.db.models.signals import post_delete, post_save
from netbox.registry import registry
from .fields import CounterCacheField
def get_counters_for_model(model):
"""
Return field mappings for all counters registered to the given model.
"""
return registry['counter_fields'][model].items()
def update_counter(model, pk, counter_name, value):
"""
Increment or decrement a counter field on an object identified by its model and primary key (PK). Positive values
will increment; negative values will decrement.
"""
model.objects.filter(pk=pk).update(
**{counter_name: F(counter_name) + value}
)
#
# Signal handlers
#
def post_save_receiver(sender, instance, **kwargs):
"""
Update counter fields on related objects when a TrackingModelMixin subclass is created or modified.
"""
for field_name, counter_name in get_counters_for_model(sender):
parent_model = sender._meta.get_field(field_name).related_model
new_pk = getattr(instance, field_name, None)
old_pk = instance.tracker.get(field_name) if field_name in instance.tracker else None
# Update the counters on the old and/or new parents as needed
if old_pk is not None:
update_counter(parent_model, old_pk, counter_name, -1)
if new_pk is not None:
update_counter(parent_model, new_pk, counter_name, 1)
def post_delete_receiver(sender, instance, **kwargs):
"""
Update counter fields on related objects when a TrackingModelMixin subclass is deleted.
"""
for field_name, counter_name in get_counters_for_model(sender):
parent_model = sender._meta.get_field(field_name).related_model
parent_pk = getattr(instance, field_name, None)
# Decrement the parent's counter by one
if parent_pk is not None:
update_counter(parent_model, parent_pk, counter_name, -1)
#
# Registration
#
def connect_counters(*models):
"""
Register counter fields and connect post_save & post_delete signal handlers for the affected models.
"""
for model in models:
# Find all CounterCacheFields on the model
counter_fields = [
field for field in model._meta.get_fields() if type(field) is CounterCacheField
]
for field in counter_fields:
to_model = apps.get_model(field.to_model_name)
# Register the counter in the registry
change_tracking_fields = registry['counter_fields'][to_model]
change_tracking_fields[f"{field.to_field_name}_id"] = field.name
# Connect the post_save and post_delete handlers
post_save.connect(
post_save_receiver,
sender=to_model,
weak=False,
dispatch_uid=f'{model._meta.label}.{field.name}'
)
post_delete.connect(
post_delete_receiver,
sender=to_model,
weak=False,
dispatch_uid=f'{model._meta.label}.{field.name}'
)

View File

@ -2,6 +2,7 @@ from collections import defaultdict
from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.fields import GenericForeignKey
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _
from utilities.ordering import naturalize from utilities.ordering import naturalize
from .forms.widgets import ColorSelect from .forms.widgets import ColorSelect
@ -9,6 +10,7 @@ from .validators import ColorValidator
__all__ = ( __all__ = (
'ColorField', 'ColorField',
'CounterCacheField',
'NaturalOrderingField', 'NaturalOrderingField',
'NullableCharField', 'NullableCharField',
'RestrictedGenericForeignKey', 'RestrictedGenericForeignKey',
@ -143,3 +145,43 @@ class RestrictedGenericForeignKey(GenericForeignKey):
self.name, self.name,
False, False,
) )
class CounterCacheField(models.BigIntegerField):
"""
Counter field to keep track of related model counts.
"""
def __init__(self, to_model, to_field, *args, **kwargs):
if not isinstance(to_model, str):
raise TypeError(
_("%s(%r) is invalid. to_model parameter to CounterCacheField must be "
"a string in the format 'app.model'")
% (
self.__class__.__name__,
to_model,
)
)
if not isinstance(to_field, str):
raise TypeError(
_("%s(%r) is invalid. to_field parameter to CounterCacheField must be "
"a string in the format 'field'")
% (
self.__class__.__name__,
to_field,
)
)
self.to_model_name = to_model
self.to_field_name = to_field
kwargs['default'] = kwargs.get('default', 0)
kwargs['editable'] = False
super().__init__(*args, **kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
kwargs["to_model"] = self.to_model_name
kwargs["to_field"] = self.to_field_name
return name, path, args, kwargs

View File

View File

@ -0,0 +1,52 @@
from collections import defaultdict
from django.core.management.base import BaseCommand
from django.db.models import Count, OuterRef, Subquery
from netbox.registry import registry
class Command(BaseCommand):
help = "Force a recalculation of all cached counter fields"
@staticmethod
def collect_models():
"""
Query the registry to find all models which have one or more counter fields. Return a mapping of counter fields
to related query names for each model.
"""
models = defaultdict(dict)
for model, field_mappings in registry['counter_fields'].items():
for field_name, counter_name in field_mappings.items():
fk_field = model._meta.get_field(field_name) # Interface.device
parent_model = fk_field.related_model # Device
related_query_name = fk_field.related_query_name() # 'interfaces'
models[parent_model][counter_name] = related_query_name
return models
def update_counts(self, model, field_name, related_query):
"""
Perform a bulk update for the given model and counter field. For example,
update_counts(Device, '_interface_count', 'interfaces')
will effectively set
Device.objects.update(_interface_count=Count('interfaces'))
"""
self.stdout.write(f'Updating {model.__name__} {field_name}...')
subquery = Subquery(
model.objects.filter(pk=OuterRef('pk')).annotate(_count=Count(related_query)).values('_count')
)
return model.objects.update(**{
field_name: subquery
})
def handle(self, *model_names, **options):
for model, mappings in self.collect_models().items():
for field_name, related_query in mappings.items():
self.update_counts(model, field_name, related_query)
self.stdout.write(self.style.SUCCESS('Finished.'))

View File

@ -0,0 +1,69 @@
from django.test import TestCase
from dcim.models import *
from utilities.testing.utils import create_test_device
class CountersTest(TestCase):
"""
Validate the operation of dict_to_filter_params().
"""
@classmethod
def setUpTestData(cls):
# Create devices
device1 = create_test_device('Device 1')
device2 = create_test_device('Device 2')
# Create interfaces
Interface.objects.create(device=device1, name='Interface 1')
Interface.objects.create(device=device1, name='Interface 2')
Interface.objects.create(device=device2, name='Interface 3')
Interface.objects.create(device=device2, name='Interface 4')
def test_interface_count_creation(self):
"""
When a tracked object (Interface) is added the tracking counter should be updated.
"""
device1, device2 = Device.objects.all()
self.assertEqual(device1.interface_count, 2)
self.assertEqual(device2.interface_count, 2)
Interface.objects.create(device=device1, name='Interface 5')
Interface.objects.create(device=device2, name='Interface 6')
device1.refresh_from_db()
device2.refresh_from_db()
self.assertEqual(device1.interface_count, 3)
self.assertEqual(device2.interface_count, 3)
def test_interface_count_deletion(self):
"""
When a tracked object (Interface) is deleted the tracking counter should be updated.
"""
device1, device2 = Device.objects.all()
self.assertEqual(device1.interface_count, 2)
self.assertEqual(device2.interface_count, 2)
Interface.objects.get(name='Interface 1').delete()
Interface.objects.get(name='Interface 3').delete()
device1.refresh_from_db()
device2.refresh_from_db()
self.assertEqual(device1.interface_count, 1)
self.assertEqual(device2.interface_count, 1)
def test_interface_count_move(self):
"""
When a tracked object (Interface) is moved the tracking counter should be updated.
"""
device1, device2 = Device.objects.all()
self.assertEqual(device1.interface_count, 2)
self.assertEqual(device2.interface_count, 2)
interface1 = Interface.objects.get(name='Interface 1')
interface1.device = device2
interface1.save()
device1.refresh_from_db()
device2.refresh_from_db()
self.assertEqual(device1.interface_count, 1)
self.assertEqual(device2.interface_count, 3)

View File

@ -0,0 +1,78 @@
from django.db.models.query_utils import DeferredAttribute
from netbox.registry import registry
class Tracker:
"""
An ephemeral instance employed to record which tracked fields on an instance have been modified.
"""
def __init__(self):
self._changed_fields = {}
def __contains__(self, item):
return item in self._changed_fields
def set(self, name, value):
"""
Mark an attribute as having been changed and record its original value.
"""
self._changed_fields[name] = value
def get(self, name):
"""
Return the original value of a changed field. Raises KeyError if name is not found.
"""
return self._changed_fields[name]
def clear(self, *names):
"""
Clear any fields that were recorded as having been changed.
"""
for name in names:
self._changed_fields.pop(name, None)
else:
self._changed_fields = {}
class TrackingModelMixin:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Mark the instance as initialized, to enable our custom __setattr__()
self._initialized = True
@property
def tracker(self):
"""
Return the Tracker instance for this instance, first creating it if necessary.
"""
if not hasattr(self._state, "_tracker"):
self._state._tracker = Tracker()
return self._state._tracker
def save(self, *args, **kwargs):
super().save(*args, **kwargs)
# Clear any tracked fields now that changes have been saved
update_fields = kwargs.get('update_fields', [])
self.tracker.clear(*update_fields)
def __setattr__(self, name, value):
if hasattr(self, "_initialized"):
# Record any changes to a tracked field
if name in registry['counter_fields'][self.__class__]:
if name not in self.tracker:
# The attribute has been created or changed
if name in self.__dict__:
old_value = getattr(self, name)
if value != old_value:
self.tracker.set(name, old_value)
else:
self.tracker.set(name, DeferredAttribute)
elif value == self.tracker.get(name):
# A previously changed attribute has been restored
self.tracker.clear(name)
super().__setattr__(name, value)

View File

@ -80,12 +80,15 @@ class VirtualMachineSerializer(NetBoxModelSerializer):
primary_ip4 = NestedIPAddressSerializer(required=False, allow_null=True) primary_ip4 = NestedIPAddressSerializer(required=False, allow_null=True)
primary_ip6 = NestedIPAddressSerializer(required=False, allow_null=True) primary_ip6 = NestedIPAddressSerializer(required=False, allow_null=True)
# Counter fields
interface_count = serializers.IntegerField(read_only=True)
class Meta: class Meta:
model = VirtualMachine model = VirtualMachine
fields = [ fields = [
'id', 'url', 'display', 'name', 'status', 'site', 'cluster', 'device', 'role', 'tenant', 'platform', 'id', 'url', 'display', 'name', 'status', 'site', 'cluster', 'device', 'role', 'tenant', 'platform',
'primary_ip', 'primary_ip4', 'primary_ip6', 'vcpus', 'memory', 'disk', 'description', 'comments', 'primary_ip', 'primary_ip4', 'primary_ip6', 'vcpus', 'memory', 'disk', 'description', 'comments',
'local_context_data', 'tags', 'custom_fields', 'created', 'last_updated', 'local_context_data', 'tags', 'custom_fields', 'created', 'last_updated', 'interface_count',
] ]
validators = [] validators = []
@ -98,6 +101,7 @@ class VirtualMachineWithConfigContextSerializer(VirtualMachineSerializer):
'id', 'url', 'display', 'name', 'status', 'site', 'cluster', 'device', 'role', 'tenant', 'platform', 'id', 'url', 'display', 'name', 'status', 'site', 'cluster', 'device', 'role', 'tenant', 'platform',
'primary_ip', 'primary_ip4', 'primary_ip6', 'vcpus', 'memory', 'disk', 'description', 'comments', 'primary_ip', 'primary_ip4', 'primary_ip6', 'vcpus', 'memory', 'disk', 'description', 'comments',
'local_context_data', 'tags', 'custom_fields', 'config_context', 'created', 'last_updated', 'local_context_data', 'tags', 'custom_fields', 'config_context', 'created', 'last_updated',
'interface_count',
] ]
@extend_schema_field(serializers.JSONField(allow_null=True)) @extend_schema_field(serializers.JSONField(allow_null=True))

View File

@ -6,3 +6,8 @@ class VirtualizationConfig(AppConfig):
def ready(self): def ready(self):
from . import search from . import search
from .models import VirtualMachine
from utilities.counters import connect_counters
# Register counters
connect_counters(VirtualMachine)

View File

@ -0,0 +1,35 @@
from django.db import migrations
from django.db.models import Count
import utilities.fields
def populate_virtualmachine_counts(apps, schema_editor):
VirtualMachine = apps.get_model('virtualization', 'VirtualMachine')
vms = list(VirtualMachine.objects.annotate(_interface_count=Count('interfaces', distinct=True)))
for vm in vms:
vm.interface_count = vm._interface_count
VirtualMachine.objects.bulk_update(vms, ['interface_count'])
class Migration(migrations.Migration):
dependencies = [
('virtualization', '0034_standardize_description_comments'),
]
operations = [
migrations.AddField(
model_name='virtualmachine',
name='interface_count',
field=utilities.fields.CounterCacheField(
default=0, to_field='virtual_machine', to_model='virtualization.VMInterface'
),
),
migrations.RunPython(
code=populate_virtualmachine_counts,
reverse_code=migrations.RunPython.noop
),
]

View File

@ -11,9 +11,10 @@ from extras.models import ConfigContextModel
from extras.querysets import ConfigContextModelQuerySet from extras.querysets import ConfigContextModelQuerySet
from netbox.config import get_config from netbox.config import get_config
from netbox.models import NetBoxModel, PrimaryModel from netbox.models import NetBoxModel, PrimaryModel
from utilities.fields import NaturalOrderingField from utilities.fields import CounterCacheField, NaturalOrderingField
from utilities.ordering import naturalize_interface from utilities.ordering import naturalize_interface
from utilities.query_functions import CollateAsChar from utilities.query_functions import CollateAsChar
from utilities.tracking import TrackingModelMixin
from virtualization.choices import * from virtualization.choices import *
__all__ = ( __all__ = (
@ -120,6 +121,12 @@ class VirtualMachine(PrimaryModel, ConfigContextModel):
verbose_name='Disk (GB)' verbose_name='Disk (GB)'
) )
# Counter fields
interface_count = CounterCacheField(
to_model='virtualization.VMInterface',
to_field='virtual_machine'
)
# Generic relation # Generic relation
contacts = GenericRelation( contacts = GenericRelation(
to='tenancy.ContactAssignment' to='tenancy.ContactAssignment'
@ -222,7 +229,7 @@ class VirtualMachine(PrimaryModel, ConfigContextModel):
return None return None
class VMInterface(NetBoxModel, BaseInterface): class VMInterface(NetBoxModel, BaseInterface, TrackingModelMixin):
virtual_machine = models.ForeignKey( virtual_machine = models.ForeignKey(
to='virtualization.VirtualMachine', to='virtualization.VirtualMachine',
on_delete=models.CASCADE, on_delete=models.CASCADE,

View File

@ -1,10 +1,11 @@
import django_tables2 as tables import django_tables2 as tables
from django.utils.translation import gettext as _
from dcim.tables.devices import BaseInterfaceTable from dcim.tables.devices import BaseInterfaceTable
from netbox.tables import NetBoxTable, columns
from tenancy.tables import ContactsColumnMixin, TenancyColumnsMixin from tenancy.tables import ContactsColumnMixin, TenancyColumnsMixin
from virtualization.models import VirtualMachine, VMInterface from virtualization.models import VirtualMachine, VMInterface
from netbox.tables import NetBoxTable, columns
__all__ = ( __all__ = (
'VirtualMachineTable', 'VirtualMachineTable',
'VirtualMachineVMInterfaceTable', 'VirtualMachineVMInterfaceTable',
@ -70,6 +71,9 @@ class VirtualMachineTable(TenancyColumnsMixin, ContactsColumnMixin, NetBoxTable)
tags = columns.TagColumn( tags = columns.TagColumn(
url_name='virtualization:virtualmachine_list' url_name='virtualization:virtualmachine_list'
) )
interface_count = tables.Column(
verbose_name=_('Interfaces')
)
class Meta(NetBoxTable.Meta): class Meta(NetBoxTable.Meta):
model = VirtualMachine model = VirtualMachine

View File

@ -349,7 +349,7 @@ class VirtualMachineInterfacesView(generic.ObjectChildrenView):
template_name = 'virtualization/virtualmachine/interfaces.html' template_name = 'virtualization/virtualmachine/interfaces.html'
tab = ViewTab( tab = ViewTab(
label=_('Interfaces'), label=_('Interfaces'),
badge=lambda obj: obj.interfaces.count(), badge=lambda obj: obj.interface_count,
permission='virtualization.view_vminterface', permission='virtualization.view_vminterface',
weight=500 weight=500
) )