diff --git a/netbox/utilities/constants.py b/netbox/utilities/constants.py index d7f819e8c..1f8e13553 100644 --- a/netbox/utilities/constants.py +++ b/netbox/utilities/constants.py @@ -53,6 +53,10 @@ FILTER_NUMERIC_BASED_LOOKUP_MAP = dict( gt='gt' ) +FILTER_NEGATION_LOOKUP_MAP = dict( + n='exact' +) + FILTER_LOOKUP_HELP_TEXT_MAP = dict( icontains='case insensitive contains', iendswith='case insensitive ends with', diff --git a/netbox/utilities/filters.py b/netbox/utilities/filters.py index d1364cf4b..5e9e1f4c1 100644 --- a/netbox/utilities/filters.py +++ b/netbox/utilities/filters.py @@ -8,7 +8,8 @@ from django_filters.utils import get_model_field, resolve_field from extras.models import Tag from utilities.constants import ( - FILTER_CHAR_BASED_LOOKUP_MAP, FILTER_LOOKUP_HELP_TEXT_MAP, FILTER_NUMERIC_BASED_LOOKUP_MAP + FILTER_CHAR_BASED_LOOKUP_MAP, FILTER_LOOKUP_HELP_TEXT_MAP, FILTER_NEGATION_LOOKUP_MAP, + FILTER_NUMERIC_BASED_LOOKUP_MAP ) @@ -193,15 +194,6 @@ class BaseFilterSet(django_filters.FilterSet): # Choose the lookup expression map based on the filter type if isinstance(existing_filter, ( - django_filters.filters.CharFilter, - django_filters.MultipleChoiceFilter, - MultiValueCharFilter, - MultiValueMACAddressFilter, - TagFilter - )): - lookup_map = FILTER_CHAR_BASED_LOOKUP_MAP - - elif isinstance(existing_filter, ( MultiValueDateFilter, MultiValueDateTimeFilter, MultiValueNumberFilter, @@ -212,13 +204,19 @@ class BaseFilterSet(django_filters.FilterSet): elif isinstance(existing_filter, ( django_filters.ModelChoiceFilter, django_filters.ModelMultipleChoiceFilter, - NumericInFilter, TreeNodeMultipleChoiceFilter, - )): + TagFilter + )) or existing_filter.extra.get('choices'): # These filter types support only negation - lookup_map = dict( - n='exact' - ) + lookup_map = FILTER_NEGATION_LOOKUP_MAP + + elif isinstance(existing_filter, ( + django_filters.filters.CharFilter, + django_filters.MultipleChoiceFilter, + MultiValueCharFilter, + MultiValueMACAddressFilter + )): + lookup_map = FILTER_CHAR_BASED_LOOKUP_MAP else: # Do not augment any other filter types with more lookup expressions @@ -231,6 +229,8 @@ class BaseFilterSet(django_filters.FilterSet): # Create new filters for each lookup expression in the map for lookup_name, lookup_expr in lookup_map.items(): new_filter_name = '{}__{}'.format(existing_filter_name, lookup_name) + if existing_filter.lookup_expr == 'in': + lookup_expr = 'in' # 'in' lookups must remain to avoid unwanted slicing on certain querysets try: if existing_filter_name in cls.declared_filters: @@ -255,7 +255,8 @@ class BaseFilterSet(django_filters.FilterSet): if lookup_name.startswith('n'): # This is a negation filter which requires a queryset.exclude() clause - new_filter.exclude = True + # Of course setting the negation of the existing filter's exclude attribute handles both cases + new_filter.exclude = not existing_filter.exclude new_filters[new_filter_name] = new_filter diff --git a/netbox/utilities/tests/test_filters.py b/netbox/utilities/tests/test_filters.py index 513e11bca..a1cb771a1 100644 --- a/netbox/utilities/tests/test_filters.py +++ b/netbox/utilities/tests/test_filters.py @@ -2,8 +2,9 @@ from django.conf import settings from django.test import TestCase import django_filters +from dcim.filters import SiteFilterSet from dcim.models import Region, Site -from utilities.filters import TreeNodeMultipleChoiceFilter +from utilities.filters import BaseFilterSet, TreeNodeMultipleChoiceFilter class TreeNodeMultipleChoiceFilterTest(TestCase): @@ -60,3 +61,74 @@ class TreeNodeMultipleChoiceFilterTest(TestCase): self.assertEqual(qs.count(), 2) self.assertEqual(qs[0], self.site1) self.assertEqual(qs[1], self.site3) + + +class DynamicFilterLookupExpressionTest(TestCase): + """ + These tests ensure of the utilities.filters.BaseFilterSet.get_filters() method + correctly generates dynamic filter expressions + """ + + def setUp(self): + + super().setUp() + + self.region1 = Region.objects.create(name='Test Region 1', slug='test-region-1') + self.region2 = Region.objects.create(name='Test Region 2', slug='test-region-2') + self.site1 = Site.objects.create(region=self.region1, name='Test Site 1', slug='ABC-test-site1-ABC', asn=65001) + self.site2 = Site.objects.create(region=self.region2, name='Test Site 2', slug='def-test-site2-def', asn=65101) + self.site3 = Site.objects.create(region=None, name='Test Site 3', slug='ghi-test-site3-ghi', asn=65201) + + self.queryset = Site.objects.all() + + def test_site_name_negation(self): + params = {'name__n': ['Test Site 1']} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 2) + + def test_site_slug_icontains(self): + params = {'slug__ic': ['abc']} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 1) + + def test_site_slug_icontains_negation(self): + params = {'slug__nic': ['abc']} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 2) + + def test_site_slug_startswith(self): + params = {'slug__isw': ['abc']} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 1) + + def test_site_slug_startswith_negation(self): + params = {'slug__nisw': ['abc']} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 2) + + def test_site_slug_endswith(self): + params = {'slug__iew': ['abc']} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 1) + + def test_site_slug_endswith_negation(self): + params = {'slug__niew': ['abc']} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 2) + + def test_site_asn_lt(self): + params = {'asn__lt': [65101]} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 1) + + def test_site_asn_lte(self): + params = {'asn__lte': [65101]} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 2) + + def test_site_asn_gt(self): + params = {'asn__lt': [65101]} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 1) + + def test_site_asn_gte(self): + params = {'asn__gte': [65101]} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 2) + + def test_site_region_negation(self): + params = {'region__n': ['test-region-1']} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 2) + + def test_site_region_id_negation(self): + params = {'region_id__n': [self.region1.pk]} + self.assertEqual(SiteFilterSet(params, self.queryset).qs.count(), 2)