diff --git a/netbox/dcim/api/views.py b/netbox/dcim/api/views.py index d7dbbef91..5ca851bca 100644 --- a/netbox/dcim/api/views.py +++ b/netbox/dcim/api/views.py @@ -1,3 +1,4 @@ +from django.contrib.contenttypes.prefetch import GenericPrefetch from django.http import Http404, HttpResponse from django.shortcuts import get_object_or_404 from drf_spectacular.types import OpenApiTypes @@ -442,7 +443,18 @@ class PowerOutletViewSet(PathEndpointMixin, NetBoxModelViewSet): class InterfaceViewSet(PathEndpointMixin, NetBoxModelViewSet): queryset = Interface.objects.prefetch_related( - '_path', 'cable__terminations', + GenericPrefetch( + "cable__terminations__termination", + [ + Interface.objects.select_related("device", "cable"), + ], + ), + GenericPrefetch( + "_path__path_objects", + [ + Interface.objects.select_related("device", "cable"), + ], + ), 'l2vpn_terminations', # Referenced by InterfaceSerializer.l2vpn_termination 'ip_addresses', # Referenced by Interface.count_ipaddresses() 'fhrp_group_assignments', # Referenced by Interface.count_fhrp_groups() diff --git a/netbox/dcim/models/cables.py b/netbox/dcim/models/cables.py index 81a742fe6..5d707375f 100644 --- a/netbox/dcim/models/cables.py +++ b/netbox/dcim/models/cables.py @@ -1,5 +1,4 @@ import itertools -from collections import defaultdict from django.contrib.contenttypes.fields import GenericForeignKey from django.core.exceptions import ValidationError @@ -16,7 +15,7 @@ from dcim.utils import decompile_path_node, object_to_path_node from netbox.models import ChangeLoggedModel, PrimaryModel from utilities.conversion import to_meters from utilities.exceptions import AbortRequest -from utilities.fields import ColorField +from utilities.fields import ColorField, GenericArrayForeignKey from utilities.querysets import RestrictedQuerySet from wireless.models import WirelessLink from .device_components import FrontPort, RearPort, PathEndpoint @@ -494,13 +493,16 @@ class CablePath(models.Model): return ObjectType.objects.get_for_id(ct_id) @property - def path_objects(self): - """ - Cache and return the complete path as lists of objects, derived from their annotation within the path. - """ - if not hasattr(self, '_path_objects'): - self._path_objects = self._get_path() - return self._path_objects + def _path_decompiled(self): + res = [] + for step in self.path: + nodes = [] + for node in step: + nodes.append(decompile_path_node(node)) + res.append(nodes) + return res + + path_objects = GenericArrayForeignKey("_path_decompiled") @property def origins(self): @@ -757,42 +759,6 @@ class CablePath(models.Model): self.delete() retrace.alters_data = True - def _get_path(self): - """ - Return the path as a list of prefetched objects. - """ - # Compile a list of IDs to prefetch for each type of model in the path - to_prefetch = defaultdict(list) - for node in self._nodes: - ct_id, object_id = decompile_path_node(node) - to_prefetch[ct_id].append(object_id) - - # Prefetch path objects using one query per model type. Prefetch related devices where appropriate. - prefetched = {} - for ct_id, object_ids in to_prefetch.items(): - model_class = ObjectType.objects.get_for_id(ct_id).model_class() - queryset = model_class.objects.filter(pk__in=object_ids) - if hasattr(model_class, 'device'): - queryset = queryset.prefetch_related('device') - prefetched[ct_id] = { - obj.id: obj for obj in queryset - } - - # Replicate the path using the prefetched objects. - path = [] - for step in self.path: - nodes = [] - for node in step: - ct_id, object_id = decompile_path_node(node) - try: - nodes.append(prefetched[ct_id][object_id]) - except KeyError: - # Ignore stale (deleted) object IDs - pass - path.append(nodes) - - return path - def get_cable_ids(self): """ Return all Cable IDs within the path. diff --git a/netbox/dcim/models/device_components.py b/netbox/dcim/models/device_components.py index 8a8e8f4cc..632121dc2 100644 --- a/netbox/dcim/models/device_components.py +++ b/netbox/dcim/models/device_components.py @@ -184,8 +184,11 @@ class CabledObjectModel(models.Model): @cached_property def link_peers(self): if self.cable: - peers = self.cable.terminations.exclude(cable_end=self.cable_end).prefetch_related('termination') - return [peer.termination for peer in peers] + return [ + peer.termination + for peer in self.cable.terminations.all() + if peer.cable_end != self.cable_end + ] return [] @property diff --git a/netbox/utilities/fields.py b/netbox/utilities/fields.py index 1d16a1d3f..05f61a147 100644 --- a/netbox/utilities/fields.py +++ b/netbox/utilities/fields.py @@ -1,7 +1,11 @@ from collections import defaultdict from django.contrib.contenttypes.fields import GenericForeignKey +from django.contrib.contenttypes.models import ContentType +from django.core.exceptions import ObjectDoesNotExist from django.db import models +from django.db.models.fields.mixins import FieldCacheMixin +from django.utils.functional import cached_property from django.utils.safestring import mark_safe from django.utils.translation import gettext_lazy as _ @@ -11,6 +15,7 @@ from .validators import ColorValidator __all__ = ( 'ColorField', 'CounterCacheField', + 'GenericArrayForeignKey', 'NaturalOrderingField', 'RestrictedGenericForeignKey', ) @@ -186,3 +191,130 @@ class CounterCacheField(models.BigIntegerField): kwargs["to_model"] = self.to_model_name kwargs["to_field"] = self.to_field_name return name, path, args, kwargs + + +class GenericArrayForeignKey(FieldCacheMixin, models.Field): + """ + Provide a generic many-to-many relation through an 2d array field + """ + + many_to_many = True + many_to_one = False + one_to_many = False + one_to_one = False + + def __init__(self, field, for_concrete_model=True): + super().__init__(editable=False) + self.field = field + self.for_concrete_model = for_concrete_model + self.is_relation = True + + def contribute_to_class(self, cls, name, **kwargs): + super().contribute_to_class(cls, name, private_only=True, **kwargs) + # GenericArrayForeignKey is its own descriptor. + setattr(cls, self.attname, self) + + @cached_property + def cache_name(self): + return self.name + + def get_cache_name(self): + return self.cache_name + + def _get_ids(self, instance): + return getattr(instance, self.field) + + def get_content_type_by_id(self, id=None, using=None): + return ContentType.objects.db_manager(using).get_for_id(id) + + def get_content_type_of_obj(self, obj=None): + return ContentType.objects.db_manager(obj._state.db).get_for_model( + obj, for_concrete_model=self.for_concrete_model + ) + + def get_content_type_for_model(self, using=None, model=None): + return ContentType.objects.db_manager(using).get_for_model( + model, for_concrete_model=self.for_concrete_model + ) + + def get_prefetch_querysets(self, instances, querysets=None): + custom_queryset_dict = {} + if querysets is not None: + for queryset in querysets: + ct_id = self.get_content_type_for_model( + model=queryset.query.model, using=queryset.db + ).pk + if ct_id in custom_queryset_dict: + raise ValueError( + "Only one queryset is allowed for each content type." + ) + custom_queryset_dict[ct_id] = queryset + + # For efficiency, group the instances by content type and then do one + # query per model + fk_dict = defaultdict(set) # type id, db -> model ids + for instance in instances: + for step in self._get_ids(instance): + for ct_id, fk_val in step: + fk_dict[(ct_id, instance._state.db)].add(fk_val) + + rel_objects = [] + for (ct_id, db), fkeys in fk_dict.items(): + if ct_id in custom_queryset_dict: + rel_objects.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys)) + else: + ct = self.get_content_type_by_id(id=ct_id, using=db) + rel_objects.extend(ct.get_all_objects_for_this_type(pk__in=fkeys)) + + # reorganize objects to fix usage + items = { + (self.get_content_type_of_obj(obj=rel_obj).pk, rel_obj.pk, rel_obj._state.db): rel_obj + for rel_obj in rel_objects + } + lists = [] + lists_keys = {} + for instance in instances: + data = [] + lists.append(data) + lists_keys[instance] = id(data) + for step in self._get_ids(instance): + nodes = [] + for ct, fk in step: + if rel_obj := items.get((ct, fk, instance._state.db)): + nodes.append(rel_obj) + data.append(nodes) + + return ( + lists, + lambda obj: id(obj), + lambda obj: lists_keys[obj], + True, + self.cache_name, + False, + ) + + def __get__(self, instance, cls=None): + if instance is None: + return self + rel_objects = self.get_cached_value(instance, default=...) + expected_ids = self._get_ids(instance) + # we do not check if cache actual + if rel_objects is not ...: + return rel_objects + # load value + if expected_ids is None: + self.set_cached_value(instance, rel_objects) + return rel_objects + data = [] + for step in self._get_ids(instance): + rel_objects = [] + for ct_id, pk_val in step: + ct = self.get_content_type_by_id(id=ct_id, using=instance._state.db) + try: + rel_obj = ct.get_object_for_this_type(pk=pk_val) + rel_objects.append(rel_obj) + except ObjectDoesNotExist: + pass + data.append(rel_objects) + self.set_cached_value(instance, data) + return data