Move GenericArrayForeignKey to fields module

This commit is contained in:
Andrey Tikhonov 2025-03-12 16:22:11 +01:00
parent d103e13732
commit 7d64d3b5ed
4 changed files with 134 additions and 136 deletions

View File

@ -14,8 +14,7 @@ from dcim.fields import PathField
from dcim.utils import decompile_path_node, object_to_path_node from dcim.utils import decompile_path_node, object_to_path_node
from netbox.models import ChangeLoggedModel, PrimaryModel from netbox.models import ChangeLoggedModel, PrimaryModel
from utilities.conversion import to_meters from utilities.conversion import to_meters
from utilities.fields import ColorField from utilities.fields import ColorField, GenericArrayForeignKey
from utilities.generics.field import GenericArrayForeignKey
from utilities.querysets import RestrictedQuerySet from utilities.querysets import RestrictedQuerySet
from wireless.models import WirelessLink from wireless.models import WirelessLink
from .device_components import FrontPort, RearPort, PathEndpoint from .device_components import FrontPort, RearPort, PathEndpoint

View File

@ -1,7 +1,11 @@
from collections import defaultdict from collections import defaultdict
from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist
from django.db import models from django.db import models
from django.db.models.fields.mixins import FieldCacheMixin
from django.utils.functional import cached_property
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -11,6 +15,7 @@ from .validators import ColorValidator
__all__ = ( __all__ = (
'ColorField', 'ColorField',
'CounterCacheField', 'CounterCacheField',
'GenericArrayForeignKey',
'NaturalOrderingField', 'NaturalOrderingField',
'RestrictedGenericForeignKey', 'RestrictedGenericForeignKey',
) )
@ -186,3 +191,131 @@ class CounterCacheField(models.BigIntegerField):
kwargs["to_model"] = self.to_model_name kwargs["to_model"] = self.to_model_name
kwargs["to_field"] = self.to_field_name kwargs["to_field"] = self.to_field_name
return name, path, args, kwargs return name, path, args, kwargs
class GenericArrayForeignKey(FieldCacheMixin, models.Field):
"""
Provide a generic many-to-many relation through an array field
"""
many_to_many = True
many_to_one = False
one_to_many = False
one_to_one = False
def __init__(self, field, for_concrete_model=True):
super().__init__(editable=False)
self.field = field
self.for_concrete_model = for_concrete_model
self.is_relation = True
def contribute_to_class(self, cls, name, **kwargs):
super().contribute_to_class(cls, name, private_only=True, **kwargs)
# GenericForeignKey is its own descriptor.
setattr(cls, self.attname, self)
@cached_property
def cache_name(self):
return self.name
def get_cache_name(self):
return self.cache_name
def _get_ids(self, instance):
return getattr(instance, self.field)
def get_content_type_by_id(self, id=None, using=None):
return ContentType.objects.db_manager(using).get_for_id(id)
def get_content_type_of_obj(self, obj=None):
return ContentType.objects.db_manager(obj._state.db).get_for_model(
obj, for_concrete_model=self.for_concrete_model
)
def get_content_type_for_model(self, using=None, model=None):
return ContentType.objects.db_manager(using).get_for_model(
model, for_concrete_model=self.for_concrete_model
)
def get_prefetch_querysets(self, instances, querysets=None):
custom_queryset_dict = {}
if querysets is not None:
for queryset in querysets:
ct_id = self.get_content_type_for_model(
model=queryset.query.model, using=queryset.db
).pk
if ct_id in custom_queryset_dict:
raise ValueError(
"Only one queryset is allowed for each content type."
)
custom_queryset_dict[ct_id] = queryset
# For efficiency, group the instances by content type and then do one
# query per model
fk_dict = defaultdict(set) # type id, db -> model ids
for instance in instances:
for step in self._get_ids(instance):
for ct_id, fk_val in step:
fk_dict[(ct_id, instance._state.db)].add(fk_val)
rel_objects = []
for (ct_id, db), fkeys in fk_dict.items():
if ct_id in custom_queryset_dict:
rel_objects.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
else:
ct = self.get_content_type_by_id(id=ct_id, using=db)
rel_objects.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
# reorganize objects to fix usage
items = {
(self.get_content_type_of_obj(obj=rel_obj).pk, rel_obj.pk, rel_obj._state.db): rel_obj
for rel_obj in rel_objects
}
lists = []
lists_keys = {}
for instance in instances:
data = []
lists.append(data)
lists_keys[instance] = id(data)
for step in self._get_ids(instance):
nodes = []
for ct, fk in step:
if rel_obj := items.get((ct, fk, instance._state.db)):
nodes.append(rel_obj)
data.append(nodes)
return (
lists,
lambda obj: id(obj),
lambda obj: lists_keys[obj],
True,
self.cache_name,
False,
)
def __get__(self, instance, cls=None):
if instance is None:
return self
rel_objects = self.get_cached_value(instance, default=...)
expected_ids = self._get_ids(instance)
# we do not check if cache actual
if rel_objects is not ...:
return rel_objects
# load value
if expected_ids is None:
self.set_cached_value(instance, rel_objects)
return rel_objects
data = []
for step in self._get_ids(instance):
rel_objects = []
for ct_id, pk_val in step:
ct = self.get_content_type_by_id(id=ct_id, using=instance._state.db)
try:
rel_obj = ct.get_object_for_this_type(pk=pk_val)
rel_objects.append(rel_obj)
except ObjectDoesNotExist:
pass
data.append(rel_objects)
self.set_cached_value(instance, data)
return data

View File

@ -1,134 +0,0 @@
from collections import defaultdict
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist
from django.db.models.fields import Field
from django.db.models.fields.mixins import FieldCacheMixin
from django.utils.functional import cached_property
class GenericArrayForeignKey(FieldCacheMixin, Field):
"""
Provide a generic many-to-many relation through an array field
"""
many_to_many = True
many_to_one = False
one_to_many = False
one_to_one = False
def __init__(self, field, for_concrete_model=True):
super().__init__(editable=False)
self.field = field
self.for_concrete_model = for_concrete_model
self.is_relation = True
def contribute_to_class(self, cls, name, **kwargs):
super().contribute_to_class(cls, name, private_only=True, **kwargs)
# GenericForeignKey is its own descriptor.
setattr(cls, self.attname, self)
@cached_property
def cache_name(self):
return self.name
def get_cache_name(self):
return self.cache_name
def _get_ids(self, instance):
return getattr(instance, self.field)
def get_content_type_by_id(self, id=None, using=None):
return ContentType.objects.db_manager(using).get_for_id(id)
def get_content_type_of_obj(self, obj=None):
return ContentType.objects.db_manager(obj._state.db).get_for_model(
obj, for_concrete_model=self.for_concrete_model
)
def get_content_type_for_model(self, using=None, model=None):
return ContentType.objects.db_manager(using).get_for_model(
model, for_concrete_model=self.for_concrete_model
)
def get_prefetch_querysets(self, instances, querysets=None):
custom_queryset_dict = {}
if querysets is not None:
for queryset in querysets:
ct_id = self.get_content_type_for_model(
model=queryset.query.model, using=queryset.db
).pk
if ct_id in custom_queryset_dict:
raise ValueError(
"Only one queryset is allowed for each content type."
)
custom_queryset_dict[ct_id] = queryset
# For efficiency, group the instances by content type and then do one
# query per model
fk_dict = defaultdict(set) # type id, db -> model ids
for instance in instances:
for step in self._get_ids(instance):
for ct_id, fk_val in step:
fk_dict[(ct_id, instance._state.db)].add(fk_val)
rel_objects = []
for (ct_id, db), fkeys in fk_dict.items():
if ct_id in custom_queryset_dict:
rel_objects.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
else:
ct = self.get_content_type_by_id(id=ct_id, using=db)
rel_objects.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
# reorganize objects to fix usage
items = {
(self.get_content_type_of_obj(obj=rel_obj).pk, rel_obj.pk, rel_obj._state.db): rel_obj
for rel_obj in rel_objects
}
lists = []
lists_keys = {}
for instance in instances:
data = []
lists.append(data)
lists_keys[instance] = id(data)
for step in self._get_ids(instance):
nodes = []
for ct, fk in step:
if rel_obj := items.get((ct, fk, instance._state.db)):
nodes.append(rel_obj)
data.append(nodes)
return (
lists,
lambda obj: id(obj),
lambda obj: lists_keys[obj],
True,
self.cache_name,
False,
)
def __get__(self, instance, cls=None):
if instance is None:
return self
rel_objects = self.get_cached_value(instance, default=...)
expected_ids = self._get_ids(instance)
# we do not check if cache actual
if rel_objects is not ...:
return rel_objects
# load value
if expected_ids is None:
self.set_cached_value(instance, rel_objects)
return rel_objects
data = []
for step in self._get_ids(instance):
rel_objects = []
for ct_id, pk_val in step:
ct = self.get_content_type_by_id(id=ct_id, using=instance._state.db)
try:
rel_obj = ct.get_object_for_this_type(pk=pk_val)
rel_objects.append(rel_obj)
except ObjectDoesNotExist:
pass
data.append(rel_objects)
self.set_cached_value(instance, data)
return data