mirror of
https://github.com/netbox-community/netbox.git
synced 2025-07-25 18:08:38 -06:00
Extend logic for validating filter class
This commit is contained in:
parent
a136030094
commit
313e63622b
@ -91,8 +91,9 @@ class EventRuleFilterSet(NetBoxModelFilterSet):
|
|||||||
method='search',
|
method='search',
|
||||||
label=_('Search'),
|
label=_('Search'),
|
||||||
)
|
)
|
||||||
object_type_id = MultiValueNumberFilter(
|
object_type_id = django_filters.ModelMultipleChoiceFilter(
|
||||||
field_name='object_types__id'
|
queryset=ObjectType.objects.all(),
|
||||||
|
field_name='object_types'
|
||||||
)
|
)
|
||||||
object_type = ContentTypeFilter(
|
object_type = ContentTypeFilter(
|
||||||
field_name='object_types'
|
field_name='object_types'
|
||||||
@ -128,14 +129,16 @@ class CustomFieldFilterSet(ChangeLoggedModelFilterSet):
|
|||||||
type = django_filters.MultipleChoiceFilter(
|
type = django_filters.MultipleChoiceFilter(
|
||||||
choices=CustomFieldTypeChoices
|
choices=CustomFieldTypeChoices
|
||||||
)
|
)
|
||||||
object_type_id = MultiValueNumberFilter(
|
object_type_id = django_filters.ModelMultipleChoiceFilter(
|
||||||
field_name='object_types__id'
|
queryset=ObjectType.objects.all(),
|
||||||
|
field_name='object_types'
|
||||||
)
|
)
|
||||||
object_type = ContentTypeFilter(
|
object_type = ContentTypeFilter(
|
||||||
field_name='object_types'
|
field_name='object_types'
|
||||||
)
|
)
|
||||||
related_object_type_id = MultiValueNumberFilter(
|
related_object_type_id = django_filters.ModelMultipleChoiceFilter(
|
||||||
field_name='related_object_type__id'
|
queryset=ObjectType.objects.all(),
|
||||||
|
field_name='related_object_type'
|
||||||
)
|
)
|
||||||
related_object_type = ContentTypeFilter()
|
related_object_type = ContentTypeFilter()
|
||||||
choice_set_id = django_filters.ModelMultipleChoiceFilter(
|
choice_set_id = django_filters.ModelMultipleChoiceFilter(
|
||||||
@ -199,8 +202,9 @@ class CustomLinkFilterSet(ChangeLoggedModelFilterSet):
|
|||||||
method='search',
|
method='search',
|
||||||
label=_('Search'),
|
label=_('Search'),
|
||||||
)
|
)
|
||||||
object_type_id = MultiValueNumberFilter(
|
object_type_id = django_filters.ModelMultipleChoiceFilter(
|
||||||
field_name='object_types__id'
|
queryset=ObjectType.objects.all(),
|
||||||
|
field_name='object_types'
|
||||||
)
|
)
|
||||||
object_type = ContentTypeFilter(
|
object_type = ContentTypeFilter(
|
||||||
field_name='object_types'
|
field_name='object_types'
|
||||||
@ -228,8 +232,9 @@ class ExportTemplateFilterSet(ChangeLoggedModelFilterSet):
|
|||||||
method='search',
|
method='search',
|
||||||
label=_('Search'),
|
label=_('Search'),
|
||||||
)
|
)
|
||||||
object_type_id = MultiValueNumberFilter(
|
object_type_id = django_filters.ModelMultipleChoiceFilter(
|
||||||
field_name='object_types__id'
|
queryset=ObjectType.objects.all(),
|
||||||
|
field_name='object_types'
|
||||||
)
|
)
|
||||||
object_type = ContentTypeFilter(
|
object_type = ContentTypeFilter(
|
||||||
field_name='object_types'
|
field_name='object_types'
|
||||||
@ -264,8 +269,9 @@ class SavedFilterFilterSet(ChangeLoggedModelFilterSet):
|
|||||||
method='search',
|
method='search',
|
||||||
label=_('Search'),
|
label=_('Search'),
|
||||||
)
|
)
|
||||||
object_type_id = MultiValueNumberFilter(
|
object_type_id = django_filters.ModelMultipleChoiceFilter(
|
||||||
field_name='object_types__id'
|
queryset=ObjectType.objects.all(),
|
||||||
|
field_name='object_types'
|
||||||
)
|
)
|
||||||
object_type = ContentTypeFilter(
|
object_type = ContentTypeFilter(
|
||||||
field_name='object_types'
|
field_name='object_types'
|
||||||
|
@ -198,8 +198,7 @@ class VRFTestCase(TestCase, ChangeLoggedFilterSetTests):
|
|||||||
queryset = VRF.objects.all()
|
queryset = VRF.objects.all()
|
||||||
filterset = VRFFilterSet
|
filterset = VRFFilterSet
|
||||||
|
|
||||||
@staticmethod
|
def get_m2m_filter_name(self, field):
|
||||||
def get_m2m_filter_name(field):
|
|
||||||
# Override filter names for import & export RouteTargets
|
# Override filter names for import & export RouteTargets
|
||||||
if field.name == 'import_targets':
|
if field.name == 'import_targets':
|
||||||
return 'import_target'
|
return 'import_target'
|
||||||
@ -303,8 +302,7 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
|
|||||||
queryset = RouteTarget.objects.all()
|
queryset = RouteTarget.objects.all()
|
||||||
filterset = RouteTargetFilterSet
|
filterset = RouteTargetFilterSet
|
||||||
|
|
||||||
@staticmethod
|
def get_m2m_filter_name(self, field):
|
||||||
def get_m2m_filter_name(field):
|
|
||||||
# Override filter names for import & export VRFs and L2VPNs
|
# Override filter names for import & export VRFs and L2VPNs
|
||||||
if field.name == 'importing_vrfs':
|
if field.name == 'importing_vrfs':
|
||||||
return 'importing_vrf'
|
return 'importing_vrf'
|
||||||
|
@ -3,9 +3,10 @@ from django.contrib.auth import get_user_model
|
|||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
|
|
||||||
|
from core.models import ObjectType
|
||||||
from netbox.filtersets import BaseFilterSet
|
from netbox.filtersets import BaseFilterSet
|
||||||
from users.models import Group, ObjectPermission, Token
|
from users.models import Group, ObjectPermission, Token
|
||||||
from utilities.filters import ContentTypeFilter, MultiValueNumberFilter
|
from utilities.filters import ContentTypeFilter
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'GroupFilterSet',
|
'GroupFilterSet',
|
||||||
@ -134,8 +135,9 @@ class ObjectPermissionFilterSet(BaseFilterSet):
|
|||||||
method='search',
|
method='search',
|
||||||
label=_('Search'),
|
label=_('Search'),
|
||||||
)
|
)
|
||||||
object_type_id = MultiValueNumberFilter(
|
object_type_id = django_filters.ModelMultipleChoiceFilter(
|
||||||
field_name='object_types__id'
|
queryset=ObjectType.objects.all(),
|
||||||
|
field_name='object_types'
|
||||||
)
|
)
|
||||||
object_type = ContentTypeFilter(
|
object_type = ContentTypeFilter(
|
||||||
field_name='object_types'
|
field_name='object_types'
|
||||||
|
@ -8,7 +8,9 @@ from django.contrib.contenttypes.models import ContentType
|
|||||||
from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
|
from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
|
||||||
from django.utils.module_loading import import_string
|
from django.utils.module_loading import import_string
|
||||||
from taggit.managers import TaggableManager
|
from taggit.managers import TaggableManager
|
||||||
from utilities.filters import TreeNodeMultipleChoiceFilter
|
|
||||||
|
from extras.filters import TagFilter
|
||||||
|
from utilities.filters import ContentTypeFilter, TreeNodeMultipleChoiceFilter
|
||||||
|
|
||||||
from core.models import ObjectType
|
from core.models import ObjectType
|
||||||
|
|
||||||
@ -46,8 +48,7 @@ class BaseFilterSetTests:
|
|||||||
filterset = None
|
filterset = None
|
||||||
ignore_fields = tuple()
|
ignore_fields = tuple()
|
||||||
|
|
||||||
@staticmethod
|
def get_m2m_filter_name(self, field):
|
||||||
def get_m2m_filter_name(field):
|
|
||||||
"""
|
"""
|
||||||
Given a ManyToManyField, determine the correct name for its corresponding Filter. Individual test
|
Given a ManyToManyField, determine the correct name for its corresponding Filter. Individual test
|
||||||
cases may override this method to prescribe deviations for specific fields.
|
cases may override this method to prescribe deviations for specific fields.
|
||||||
@ -55,20 +56,50 @@ class BaseFilterSetTests:
|
|||||||
related_model_name = field.related_model._meta.verbose_name
|
related_model_name = field.related_model._meta.verbose_name
|
||||||
return related_model_name.lower().replace(' ', '_')
|
return related_model_name.lower().replace(' ', '_')
|
||||||
|
|
||||||
@staticmethod
|
def get_filters_for_model_field(self, field):
|
||||||
def get_filter_class_for_field(field):
|
"""
|
||||||
|
Given a model field, return an iterable of (name, class) for each filter that should be defined on
|
||||||
|
the model's FilterSet class. If the appropriate filter class cannot be determined, it will be None.
|
||||||
|
"""
|
||||||
# ForeignKey & OneToOneField
|
# ForeignKey & OneToOneField
|
||||||
if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel:
|
if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel:
|
||||||
|
|
||||||
|
# Relationships to ContentType (used as part of a GFK) do not need a filter
|
||||||
|
if field.related_model is ContentType:
|
||||||
|
return [(None, None)]
|
||||||
|
|
||||||
|
# ForeignKeys to ObjectType need two filters: 'app.model' & PK
|
||||||
|
if field.related_model is ObjectType:
|
||||||
|
return [
|
||||||
|
(field.name, ContentTypeFilter),
|
||||||
|
(f'{field.name}_id', django_filters.ModelMultipleChoiceFilter),
|
||||||
|
]
|
||||||
|
|
||||||
# ForeignKey to an MPTT-enabled model
|
# ForeignKey to an MPTT-enabled model
|
||||||
if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model:
|
if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model:
|
||||||
return TreeNodeMultipleChoiceFilter
|
return [(f'{field.name}_id', TreeNodeMultipleChoiceFilter)]
|
||||||
|
|
||||||
return django_filters.ModelMultipleChoiceFilter
|
return [(f'{field.name}_id', django_filters.ModelMultipleChoiceFilter)]
|
||||||
|
|
||||||
|
# Many-to-many relationships (forward & backward)
|
||||||
|
elif type(field) in (ManyToManyField, ManyToManyRel):
|
||||||
|
filter_name = self.get_m2m_filter_name(field)
|
||||||
|
|
||||||
|
# ManyToManyFields to ObjectType need two filters: 'app.model' & PK
|
||||||
|
if field.related_model is ObjectType:
|
||||||
|
return [
|
||||||
|
(filter_name, ContentTypeFilter),
|
||||||
|
(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter),
|
||||||
|
]
|
||||||
|
|
||||||
|
return [(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter)]
|
||||||
|
|
||||||
|
# Tag manager
|
||||||
|
if type(field) is TaggableManager:
|
||||||
|
return [('tag', TagFilter)]
|
||||||
|
|
||||||
# Unable to determine the correct filter class
|
# Unable to determine the correct filter class
|
||||||
return None
|
return [(field.name, None)]
|
||||||
|
|
||||||
def test_id(self):
|
def test_id(self):
|
||||||
"""
|
"""
|
||||||
@ -111,57 +142,32 @@ class BaseFilterSetTests:
|
|||||||
if type(model_field) is ManyToOneRel:
|
if type(model_field) is ManyToOneRel:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# One-to-one & one-to-many relationships
|
|
||||||
if issubclass(model_field.__class__, ForeignKey) or type(model_field) is OneToOneRel:
|
|
||||||
|
|
||||||
# Relationships to ContentType (used as part of a GFK) do not need a filter
|
|
||||||
if model_field.related_model is ContentType:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix
|
|
||||||
if model_field.related_model is ObjectType:
|
|
||||||
filter_name = model_field.name
|
|
||||||
else:
|
|
||||||
filter_name = f'{model_field.name}_id'
|
|
||||||
|
|
||||||
self.assertIn(
|
|
||||||
filter_name,
|
|
||||||
filters,
|
|
||||||
f'No filter defined for {filter_name} ({model_field.name})!'
|
|
||||||
)
|
|
||||||
if filter_class := self.get_filter_class_for_field(model_field):
|
|
||||||
self.assertIs(
|
|
||||||
type(filters[filter_name]),
|
|
||||||
filter_class,
|
|
||||||
f"Invalid filter class for {filter_name}!"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Many-to-many relationships (forward & backward)
|
|
||||||
elif type(model_field) in (ManyToManyField, ManyToManyRel):
|
|
||||||
filter_name = self.get_m2m_filter_name(model_field)
|
|
||||||
filter_name = f'{filter_name}_id'
|
|
||||||
self.assertIn(
|
|
||||||
filter_name,
|
|
||||||
filters,
|
|
||||||
f'No filter defined for {filter_name} ({model_field.name})!'
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: Generic relationships
|
# TODO: Generic relationships
|
||||||
elif type(model_field) in (GenericForeignKey, GenericRelation):
|
if type(model_field) in (GenericForeignKey, GenericRelation):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Tags
|
for filter_name, filter_class in self.get_filters_for_model_field(model_field):
|
||||||
elif type(model_field) is TaggableManager:
|
|
||||||
self.assertIn('tag', filters, f'No filter defined for {model_field.name}!')
|
|
||||||
|
|
||||||
# All other fields
|
if filter_name is None:
|
||||||
else:
|
# Field is exempt
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check that the filter is defined
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
model_field.name,
|
filter_name,
|
||||||
filters,
|
filters.keys(),
|
||||||
f'No defined found for {model_field.name} ({type(model_field)})!'
|
f'No filter defined for {filter_name} ({model_field.name})!'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check that the filter class is correct
|
||||||
|
filter = filters[filter_name]
|
||||||
|
if filter_class is not None:
|
||||||
|
self.assertIs(
|
||||||
|
type(filter),
|
||||||
|
filter_class,
|
||||||
|
f"Invalid filter class {type(filter)} for {filter_name} (should be {filter_class})!"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChangeLoggedFilterSetTests(BaseFilterSetTests):
|
class ChangeLoggedFilterSetTests(BaseFilterSetTests):
|
||||||
|
|
||||||
|
@ -169,11 +169,14 @@ class IKEPolicyFilterSet(NetBoxModelFilterSet):
|
|||||||
mode = django_filters.MultipleChoiceFilter(
|
mode = django_filters.MultipleChoiceFilter(
|
||||||
choices=IKEModeChoices
|
choices=IKEModeChoices
|
||||||
)
|
)
|
||||||
ike_proposal_id = MultiValueNumberFilter(
|
ike_proposal_id = django_filters.ModelMultipleChoiceFilter(
|
||||||
field_name='proposals__id'
|
field_name='proposals',
|
||||||
|
queryset=IKEProposal.objects.all()
|
||||||
)
|
)
|
||||||
ike_proposal = MultiValueCharFilter(
|
ike_proposal = django_filters.ModelMultipleChoiceFilter(
|
||||||
field_name='proposals__name'
|
field_name='proposals__name',
|
||||||
|
queryset=IKEProposal.objects.all(),
|
||||||
|
to_field_name='name'
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Remove in v4.1
|
# TODO: Remove in v4.1
|
||||||
@ -231,11 +234,14 @@ class IPSecPolicyFilterSet(NetBoxModelFilterSet):
|
|||||||
pfs_group = django_filters.MultipleChoiceFilter(
|
pfs_group = django_filters.MultipleChoiceFilter(
|
||||||
choices=DHGroupChoices
|
choices=DHGroupChoices
|
||||||
)
|
)
|
||||||
ipsec_proposal_id = MultiValueNumberFilter(
|
ipsec_proposal_id = django_filters.ModelMultipleChoiceFilter(
|
||||||
field_name='proposals__id'
|
field_name='proposals',
|
||||||
|
queryset=IPSecProposal.objects.all()
|
||||||
)
|
)
|
||||||
ipsec_proposal = MultiValueCharFilter(
|
ipsec_proposal = django_filters.ModelMultipleChoiceFilter(
|
||||||
field_name='proposals__name'
|
field_name='proposals__name',
|
||||||
|
queryset=IPSecProposal.objects.all(),
|
||||||
|
to_field_name='name'
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Remove in v4.1
|
# TODO: Remove in v4.1
|
||||||
|
@ -743,8 +743,7 @@ class L2VPNTestCase(TestCase, ChangeLoggedFilterSetTests):
|
|||||||
queryset = L2VPN.objects.all()
|
queryset = L2VPN.objects.all()
|
||||||
filterset = L2VPNFilterSet
|
filterset = L2VPNFilterSet
|
||||||
|
|
||||||
@staticmethod
|
def get_m2m_filter_name(self, field):
|
||||||
def get_m2m_filter_name(field):
|
|
||||||
# Override filter names for import & export RouteTargets
|
# Override filter names for import & export RouteTargets
|
||||||
if field.name == 'import_targets':
|
if field.name == 'import_targets':
|
||||||
return 'import_target'
|
return 'import_target'
|
||||||
|
Loading…
Reference in New Issue
Block a user