Merge pull request #18826 from Tishka17/fix/generic_prefetch_4.2

Prefetch interface data for REST API on netbox 4.2
This commit is contained in:
bctiemann 2025-03-12 18:55:58 -04:00 committed by GitHub
commit b1e7d7c76b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 161 additions and 48 deletions

View File

@ -1,3 +1,4 @@
from django.contrib.contenttypes.prefetch import GenericPrefetch
from django.http import Http404, HttpResponse from django.http import Http404, HttpResponse
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
@ -442,7 +443,18 @@ class PowerOutletViewSet(PathEndpointMixin, NetBoxModelViewSet):
class InterfaceViewSet(PathEndpointMixin, NetBoxModelViewSet): class InterfaceViewSet(PathEndpointMixin, NetBoxModelViewSet):
queryset = Interface.objects.prefetch_related( queryset = Interface.objects.prefetch_related(
'_path', 'cable__terminations', GenericPrefetch(
"cable__terminations__termination",
[
Interface.objects.select_related("device", "cable"),
],
),
GenericPrefetch(
"_path__path_objects",
[
Interface.objects.select_related("device", "cable"),
],
),
'l2vpn_terminations', # Referenced by InterfaceSerializer.l2vpn_termination 'l2vpn_terminations', # Referenced by InterfaceSerializer.l2vpn_termination
'ip_addresses', # Referenced by Interface.count_ipaddresses() 'ip_addresses', # Referenced by Interface.count_ipaddresses()
'fhrp_group_assignments', # Referenced by Interface.count_fhrp_groups() 'fhrp_group_assignments', # Referenced by Interface.count_fhrp_groups()

View File

@ -1,5 +1,4 @@
import itertools import itertools
from collections import defaultdict
from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.fields import GenericForeignKey
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
@ -16,7 +15,7 @@ from dcim.utils import decompile_path_node, object_to_path_node
from netbox.models import ChangeLoggedModel, PrimaryModel from netbox.models import ChangeLoggedModel, PrimaryModel
from utilities.conversion import to_meters from utilities.conversion import to_meters
from utilities.exceptions import AbortRequest from utilities.exceptions import AbortRequest
from utilities.fields import ColorField from utilities.fields import ColorField, GenericArrayForeignKey
from utilities.querysets import RestrictedQuerySet from utilities.querysets import RestrictedQuerySet
from wireless.models import WirelessLink from wireless.models import WirelessLink
from .device_components import FrontPort, RearPort, PathEndpoint from .device_components import FrontPort, RearPort, PathEndpoint
@ -494,13 +493,16 @@ class CablePath(models.Model):
return ObjectType.objects.get_for_id(ct_id) return ObjectType.objects.get_for_id(ct_id)
@property @property
def path_objects(self): def _path_decompiled(self):
""" res = []
Cache and return the complete path as lists of objects, derived from their annotation within the path. for step in self.path:
""" nodes = []
if not hasattr(self, '_path_objects'): for node in step:
self._path_objects = self._get_path() nodes.append(decompile_path_node(node))
return self._path_objects res.append(nodes)
return res
path_objects = GenericArrayForeignKey("_path_decompiled")
@property @property
def origins(self): def origins(self):
@ -757,42 +759,6 @@ class CablePath(models.Model):
self.delete() self.delete()
retrace.alters_data = True retrace.alters_data = True
def _get_path(self):
"""
Return the path as a list of prefetched objects.
"""
# Compile a list of IDs to prefetch for each type of model in the path
to_prefetch = defaultdict(list)
for node in self._nodes:
ct_id, object_id = decompile_path_node(node)
to_prefetch[ct_id].append(object_id)
# Prefetch path objects using one query per model type. Prefetch related devices where appropriate.
prefetched = {}
for ct_id, object_ids in to_prefetch.items():
model_class = ObjectType.objects.get_for_id(ct_id).model_class()
queryset = model_class.objects.filter(pk__in=object_ids)
if hasattr(model_class, 'device'):
queryset = queryset.prefetch_related('device')
prefetched[ct_id] = {
obj.id: obj for obj in queryset
}
# Replicate the path using the prefetched objects.
path = []
for step in self.path:
nodes = []
for node in step:
ct_id, object_id = decompile_path_node(node)
try:
nodes.append(prefetched[ct_id][object_id])
except KeyError:
# Ignore stale (deleted) object IDs
pass
path.append(nodes)
return path
def get_cable_ids(self): def get_cable_ids(self):
""" """
Return all Cable IDs within the path. Return all Cable IDs within the path.

View File

@ -184,8 +184,11 @@ class CabledObjectModel(models.Model):
@cached_property @cached_property
def link_peers(self): def link_peers(self):
if self.cable: if self.cable:
peers = self.cable.terminations.exclude(cable_end=self.cable_end).prefetch_related('termination') return [
return [peer.termination for peer in peers] peer.termination
for peer in self.cable.terminations.all()
if peer.cable_end != self.cable_end
]
return [] return []
@property @property

View File

@ -1,7 +1,11 @@
from collections import defaultdict from collections import defaultdict
from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist
from django.db import models from django.db import models
from django.db.models.fields.mixins import FieldCacheMixin
from django.utils.functional import cached_property
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -11,6 +15,7 @@ from .validators import ColorValidator
__all__ = ( __all__ = (
'ColorField', 'ColorField',
'CounterCacheField', 'CounterCacheField',
'GenericArrayForeignKey',
'NaturalOrderingField', 'NaturalOrderingField',
'RestrictedGenericForeignKey', 'RestrictedGenericForeignKey',
) )
@ -186,3 +191,130 @@ class CounterCacheField(models.BigIntegerField):
kwargs["to_model"] = self.to_model_name kwargs["to_model"] = self.to_model_name
kwargs["to_field"] = self.to_field_name kwargs["to_field"] = self.to_field_name
return name, path, args, kwargs return name, path, args, kwargs
class GenericArrayForeignKey(FieldCacheMixin, models.Field):
"""
Provide a generic many-to-many relation through an 2d array field
"""
many_to_many = True
many_to_one = False
one_to_many = False
one_to_one = False
def __init__(self, field, for_concrete_model=True):
super().__init__(editable=False)
self.field = field
self.for_concrete_model = for_concrete_model
self.is_relation = True
def contribute_to_class(self, cls, name, **kwargs):
super().contribute_to_class(cls, name, private_only=True, **kwargs)
# GenericArrayForeignKey is its own descriptor.
setattr(cls, self.attname, self)
@cached_property
def cache_name(self):
return self.name
def get_cache_name(self):
return self.cache_name
def _get_ids(self, instance):
return getattr(instance, self.field)
def get_content_type_by_id(self, id=None, using=None):
return ContentType.objects.db_manager(using).get_for_id(id)
def get_content_type_of_obj(self, obj=None):
return ContentType.objects.db_manager(obj._state.db).get_for_model(
obj, for_concrete_model=self.for_concrete_model
)
def get_content_type_for_model(self, using=None, model=None):
return ContentType.objects.db_manager(using).get_for_model(
model, for_concrete_model=self.for_concrete_model
)
def get_prefetch_querysets(self, instances, querysets=None):
custom_queryset_dict = {}
if querysets is not None:
for queryset in querysets:
ct_id = self.get_content_type_for_model(
model=queryset.query.model, using=queryset.db
).pk
if ct_id in custom_queryset_dict:
raise ValueError(
"Only one queryset is allowed for each content type."
)
custom_queryset_dict[ct_id] = queryset
# For efficiency, group the instances by content type and then do one
# query per model
fk_dict = defaultdict(set) # type id, db -> model ids
for instance in instances:
for step in self._get_ids(instance):
for ct_id, fk_val in step:
fk_dict[(ct_id, instance._state.db)].add(fk_val)
rel_objects = []
for (ct_id, db), fkeys in fk_dict.items():
if ct_id in custom_queryset_dict:
rel_objects.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
else:
ct = self.get_content_type_by_id(id=ct_id, using=db)
rel_objects.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
# reorganize objects to fix usage
items = {
(self.get_content_type_of_obj(obj=rel_obj).pk, rel_obj.pk, rel_obj._state.db): rel_obj
for rel_obj in rel_objects
}
lists = []
lists_keys = {}
for instance in instances:
data = []
lists.append(data)
lists_keys[instance] = id(data)
for step in self._get_ids(instance):
nodes = []
for ct, fk in step:
if rel_obj := items.get((ct, fk, instance._state.db)):
nodes.append(rel_obj)
data.append(nodes)
return (
lists,
lambda obj: id(obj),
lambda obj: lists_keys[obj],
True,
self.cache_name,
False,
)
def __get__(self, instance, cls=None):
if instance is None:
return self
rel_objects = self.get_cached_value(instance, default=...)
expected_ids = self._get_ids(instance)
# we do not check if cache actual
if rel_objects is not ...:
return rel_objects
# load value
if expected_ids is None:
self.set_cached_value(instance, rel_objects)
return rel_objects
data = []
for step in self._get_ids(instance):
rel_objects = []
for ct_id, pk_val in step:
ct = self.get_content_type_by_id(id=ct_id, using=instance._state.db)
try:
rel_obj = ct.get_object_for_this_type(pk=pk_val)
rel_objects.append(rel_obj)
except ObjectDoesNotExist:
pass
data.append(rel_objects)
self.set_cached_value(instance, data)
return data