Extend logic for validating filter class

This commit is contained in:
Jeremy Stretch 2024-03-11 15:35:40 -04:00
parent a136030094
commit 313e63622b
6 changed files with 99 additions and 82 deletions

View File

@ -91,8 +91,9 @@ class EventRuleFilterSet(NetBoxModelFilterSet):
method='search', method='search',
label=_('Search'), label=_('Search'),
) )
object_type_id = MultiValueNumberFilter( object_type_id = django_filters.ModelMultipleChoiceFilter(
field_name='object_types__id' queryset=ObjectType.objects.all(),
field_name='object_types'
) )
object_type = ContentTypeFilter( object_type = ContentTypeFilter(
field_name='object_types' field_name='object_types'
@ -128,14 +129,16 @@ class CustomFieldFilterSet(ChangeLoggedModelFilterSet):
type = django_filters.MultipleChoiceFilter( type = django_filters.MultipleChoiceFilter(
choices=CustomFieldTypeChoices choices=CustomFieldTypeChoices
) )
object_type_id = MultiValueNumberFilter( object_type_id = django_filters.ModelMultipleChoiceFilter(
field_name='object_types__id' queryset=ObjectType.objects.all(),
field_name='object_types'
) )
object_type = ContentTypeFilter( object_type = ContentTypeFilter(
field_name='object_types' field_name='object_types'
) )
related_object_type_id = MultiValueNumberFilter( related_object_type_id = django_filters.ModelMultipleChoiceFilter(
field_name='related_object_type__id' queryset=ObjectType.objects.all(),
field_name='related_object_type'
) )
related_object_type = ContentTypeFilter() related_object_type = ContentTypeFilter()
choice_set_id = django_filters.ModelMultipleChoiceFilter( choice_set_id = django_filters.ModelMultipleChoiceFilter(
@ -199,8 +202,9 @@ class CustomLinkFilterSet(ChangeLoggedModelFilterSet):
method='search', method='search',
label=_('Search'), label=_('Search'),
) )
object_type_id = MultiValueNumberFilter( object_type_id = django_filters.ModelMultipleChoiceFilter(
field_name='object_types__id' queryset=ObjectType.objects.all(),
field_name='object_types'
) )
object_type = ContentTypeFilter( object_type = ContentTypeFilter(
field_name='object_types' field_name='object_types'
@ -228,8 +232,9 @@ class ExportTemplateFilterSet(ChangeLoggedModelFilterSet):
method='search', method='search',
label=_('Search'), label=_('Search'),
) )
object_type_id = MultiValueNumberFilter( object_type_id = django_filters.ModelMultipleChoiceFilter(
field_name='object_types__id' queryset=ObjectType.objects.all(),
field_name='object_types'
) )
object_type = ContentTypeFilter( object_type = ContentTypeFilter(
field_name='object_types' field_name='object_types'
@ -264,8 +269,9 @@ class SavedFilterFilterSet(ChangeLoggedModelFilterSet):
method='search', method='search',
label=_('Search'), label=_('Search'),
) )
object_type_id = MultiValueNumberFilter( object_type_id = django_filters.ModelMultipleChoiceFilter(
field_name='object_types__id' queryset=ObjectType.objects.all(),
field_name='object_types'
) )
object_type = ContentTypeFilter( object_type = ContentTypeFilter(
field_name='object_types' field_name='object_types'

View File

@ -198,8 +198,7 @@ class VRFTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = VRF.objects.all() queryset = VRF.objects.all()
filterset = VRFFilterSet filterset = VRFFilterSet
@staticmethod def get_m2m_filter_name(self, field):
def get_m2m_filter_name(field):
# Override filter names for import & export RouteTargets # Override filter names for import & export RouteTargets
if field.name == 'import_targets': if field.name == 'import_targets':
return 'import_target' return 'import_target'
@ -303,8 +302,7 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = RouteTarget.objects.all() queryset = RouteTarget.objects.all()
filterset = RouteTargetFilterSet filterset = RouteTargetFilterSet
@staticmethod def get_m2m_filter_name(self, field):
def get_m2m_filter_name(field):
# Override filter names for import & export VRFs and L2VPNs # Override filter names for import & export VRFs and L2VPNs
if field.name == 'importing_vrfs': if field.name == 'importing_vrfs':
return 'importing_vrf' return 'importing_vrf'

View File

@ -3,9 +3,10 @@ from django.contrib.auth import get_user_model
from django.db.models import Q from django.db.models import Q
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from core.models import ObjectType
from netbox.filtersets import BaseFilterSet from netbox.filtersets import BaseFilterSet
from users.models import Group, ObjectPermission, Token from users.models import Group, ObjectPermission, Token
from utilities.filters import ContentTypeFilter, MultiValueNumberFilter from utilities.filters import ContentTypeFilter
__all__ = ( __all__ = (
'GroupFilterSet', 'GroupFilterSet',
@ -134,8 +135,9 @@ class ObjectPermissionFilterSet(BaseFilterSet):
method='search', method='search',
label=_('Search'), label=_('Search'),
) )
object_type_id = MultiValueNumberFilter( object_type_id = django_filters.ModelMultipleChoiceFilter(
field_name='object_types__id' queryset=ObjectType.objects.all(),
field_name='object_types'
) )
object_type = ContentTypeFilter( object_type = ContentTypeFilter(
field_name='object_types' field_name='object_types'

View File

@ -8,7 +8,9 @@ from django.contrib.contenttypes.models import ContentType
from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
from django.utils.module_loading import import_string from django.utils.module_loading import import_string
from taggit.managers import TaggableManager from taggit.managers import TaggableManager
from utilities.filters import TreeNodeMultipleChoiceFilter
from extras.filters import TagFilter
from utilities.filters import ContentTypeFilter, TreeNodeMultipleChoiceFilter
from core.models import ObjectType from core.models import ObjectType
@ -46,8 +48,7 @@ class BaseFilterSetTests:
filterset = None filterset = None
ignore_fields = tuple() ignore_fields = tuple()
@staticmethod def get_m2m_filter_name(self, field):
def get_m2m_filter_name(field):
""" """
Given a ManyToManyField, determine the correct name for its corresponding Filter. Individual test Given a ManyToManyField, determine the correct name for its corresponding Filter. Individual test
cases may override this method to prescribe deviations for specific fields. cases may override this method to prescribe deviations for specific fields.
@ -55,20 +56,50 @@ class BaseFilterSetTests:
related_model_name = field.related_model._meta.verbose_name related_model_name = field.related_model._meta.verbose_name
return related_model_name.lower().replace(' ', '_') return related_model_name.lower().replace(' ', '_')
@staticmethod def get_filters_for_model_field(self, field):
def get_filter_class_for_field(field): """
Given a model field, return an iterable of (name, class) for each filter that should be defined on
the model's FilterSet class. If the appropriate filter class cannot be determined, it will be None.
"""
# ForeignKey & OneToOneField # ForeignKey & OneToOneField
if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel: if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel:
# Relationships to ContentType (used as part of a GFK) do not need a filter
if field.related_model is ContentType:
return [(None, None)]
# ForeignKeys to ObjectType need two filters: 'app.model' & PK
if field.related_model is ObjectType:
return [
(field.name, ContentTypeFilter),
(f'{field.name}_id', django_filters.ModelMultipleChoiceFilter),
]
# ForeignKey to an MPTT-enabled model # ForeignKey to an MPTT-enabled model
if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model: if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model:
return TreeNodeMultipleChoiceFilter return [(f'{field.name}_id', TreeNodeMultipleChoiceFilter)]
return django_filters.ModelMultipleChoiceFilter return [(f'{field.name}_id', django_filters.ModelMultipleChoiceFilter)]
# Many-to-many relationships (forward & backward)
elif type(field) in (ManyToManyField, ManyToManyRel):
filter_name = self.get_m2m_filter_name(field)
# ManyToManyFields to ObjectType need two filters: 'app.model' & PK
if field.related_model is ObjectType:
return [
(filter_name, ContentTypeFilter),
(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter),
]
return [(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter)]
# Tag manager
if type(field) is TaggableManager:
return [('tag', TagFilter)]
# Unable to determine the correct filter class # Unable to determine the correct filter class
return None return [(field.name, None)]
def test_id(self): def test_id(self):
""" """
@ -111,57 +142,32 @@ class BaseFilterSetTests:
if type(model_field) is ManyToOneRel: if type(model_field) is ManyToOneRel:
continue continue
# One-to-one & one-to-many relationships
if issubclass(model_field.__class__, ForeignKey) or type(model_field) is OneToOneRel:
# Relationships to ContentType (used as part of a GFK) do not need a filter
if model_field.related_model is ContentType:
continue
# Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix
if model_field.related_model is ObjectType:
filter_name = model_field.name
else:
filter_name = f'{model_field.name}_id'
self.assertIn(
filter_name,
filters,
f'No filter defined for {filter_name} ({model_field.name})!'
)
if filter_class := self.get_filter_class_for_field(model_field):
self.assertIs(
type(filters[filter_name]),
filter_class,
f"Invalid filter class for {filter_name}!"
)
# Many-to-many relationships (forward & backward)
elif type(model_field) in (ManyToManyField, ManyToManyRel):
filter_name = self.get_m2m_filter_name(model_field)
filter_name = f'{filter_name}_id'
self.assertIn(
filter_name,
filters,
f'No filter defined for {filter_name} ({model_field.name})!'
)
# TODO: Generic relationships # TODO: Generic relationships
elif type(model_field) in (GenericForeignKey, GenericRelation): if type(model_field) in (GenericForeignKey, GenericRelation):
continue continue
# Tags for filter_name, filter_class in self.get_filters_for_model_field(model_field):
elif type(model_field) is TaggableManager:
self.assertIn('tag', filters, f'No filter defined for {model_field.name}!')
# All other fields if filter_name is None:
else: # Field is exempt
continue
# Check that the filter is defined
self.assertIn( self.assertIn(
model_field.name, filter_name,
filters, filters.keys(),
f'No defined found for {model_field.name} ({type(model_field)})!' f'No filter defined for {filter_name} ({model_field.name})!'
) )
# Check that the filter class is correct
filter = filters[filter_name]
if filter_class is not None:
self.assertIs(
type(filter),
filter_class,
f"Invalid filter class {type(filter)} for {filter_name} (should be {filter_class})!"
)
class ChangeLoggedFilterSetTests(BaseFilterSetTests): class ChangeLoggedFilterSetTests(BaseFilterSetTests):

View File

@ -169,11 +169,14 @@ class IKEPolicyFilterSet(NetBoxModelFilterSet):
mode = django_filters.MultipleChoiceFilter( mode = django_filters.MultipleChoiceFilter(
choices=IKEModeChoices choices=IKEModeChoices
) )
ike_proposal_id = MultiValueNumberFilter( ike_proposal_id = django_filters.ModelMultipleChoiceFilter(
field_name='proposals__id' field_name='proposals',
queryset=IKEProposal.objects.all()
) )
ike_proposal = MultiValueCharFilter( ike_proposal = django_filters.ModelMultipleChoiceFilter(
field_name='proposals__name' field_name='proposals__name',
queryset=IKEProposal.objects.all(),
to_field_name='name'
) )
# TODO: Remove in v4.1 # TODO: Remove in v4.1
@ -231,11 +234,14 @@ class IPSecPolicyFilterSet(NetBoxModelFilterSet):
pfs_group = django_filters.MultipleChoiceFilter( pfs_group = django_filters.MultipleChoiceFilter(
choices=DHGroupChoices choices=DHGroupChoices
) )
ipsec_proposal_id = MultiValueNumberFilter( ipsec_proposal_id = django_filters.ModelMultipleChoiceFilter(
field_name='proposals__id' field_name='proposals',
queryset=IPSecProposal.objects.all()
) )
ipsec_proposal = MultiValueCharFilter( ipsec_proposal = django_filters.ModelMultipleChoiceFilter(
field_name='proposals__name' field_name='proposals__name',
queryset=IPSecProposal.objects.all(),
to_field_name='name'
) )
# TODO: Remove in v4.1 # TODO: Remove in v4.1

View File

@ -743,8 +743,7 @@ class L2VPNTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = L2VPN.objects.all() queryset = L2VPN.objects.all()
filterset = L2VPNFilterSet filterset = L2VPNFilterSet
@staticmethod def get_m2m_filter_name(self, field):
def get_m2m_filter_name(field):
# Override filter names for import & export RouteTargets # Override filter names for import & export RouteTargets
if field.name == 'import_targets': if field.name == 'import_targets':
return 'import_target' return 'import_target'