Fixes #13606: Fix filtering by null for multiselect custom fields

This commit is contained in:
Jeremy Stretch 2023-12-26 13:15:23 -05:00
parent 031b7540b3
commit 634681a72e
3 changed files with 24 additions and 9 deletions

View File

@ -10,7 +10,6 @@ from django.contrib.postgres.fields import ArrayField
from django.core.validators import RegexValidator, ValidationError from django.core.validators import RegexValidator, ValidationError
from django.db import models from django.db import models
from django.urls import reverse from django.urls import reverse
from django.utils.html import escape
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -571,8 +570,7 @@ class CustomField(CloningMixin, ExportTemplatesMixin, ChangeLoggedModel):
# Multiselect # Multiselect
elif self.type == CustomFieldTypeChoices.TYPE_MULTISELECT: elif self.type == CustomFieldTypeChoices.TYPE_MULTISELECT:
filter_class = filters.MultiValueCharFilter filter_class = filters.MultiValueArrayFilter
kwargs['lookup_expr'] = 'has_key'
# Object # Object
elif self.type == CustomFieldTypeChoices.TYPE_OBJECT: elif self.type == CustomFieldTypeChoices.TYPE_OBJECT:

View File

@ -1329,7 +1329,7 @@ class CustomFieldModelFilterTest(TestCase):
choice_set = CustomFieldChoiceSet.objects.create( choice_set = CustomFieldChoiceSet.objects.create(
name='Custom Field Choice Set 1', name='Custom Field Choice Set 1',
extra_choices=(('a', 'A'), ('b', 'B'), ('c', 'C'), ('x', 'X')) extra_choices=(('a', 'A'), ('b', 'B'), ('c', 'C'))
) )
# Integer filtering # Integer filtering
@ -1435,7 +1435,7 @@ class CustomFieldModelFilterTest(TestCase):
'cf7': 'http://a.example.com', 'cf7': 'http://a.example.com',
'cf8': 'http://a.example.com', 'cf8': 'http://a.example.com',
'cf9': 'A', 'cf9': 'A',
'cf10': ['A', 'X'], 'cf10': ['A', 'B'],
'cf11': manufacturers[0].pk, 'cf11': manufacturers[0].pk,
'cf12': [manufacturers[0].pk, manufacturers[3].pk], 'cf12': [manufacturers[0].pk, manufacturers[3].pk],
}), }),
@ -1449,7 +1449,7 @@ class CustomFieldModelFilterTest(TestCase):
'cf7': 'http://b.example.com', 'cf7': 'http://b.example.com',
'cf8': 'http://b.example.com', 'cf8': 'http://b.example.com',
'cf9': 'B', 'cf9': 'B',
'cf10': ['B', 'X'], 'cf10': ['B', 'C'],
'cf11': manufacturers[1].pk, 'cf11': manufacturers[1].pk,
'cf12': [manufacturers[1].pk, manufacturers[3].pk], 'cf12': [manufacturers[1].pk, manufacturers[3].pk],
}), }),
@ -1463,7 +1463,7 @@ class CustomFieldModelFilterTest(TestCase):
'cf7': 'http://c.example.com', 'cf7': 'http://c.example.com',
'cf8': 'http://c.example.com', 'cf8': 'http://c.example.com',
'cf9': 'C', 'cf9': 'C',
'cf10': ['C', 'X'], 'cf10': None,
'cf11': manufacturers[2].pk, 'cf11': manufacturers[2].pk,
'cf12': [manufacturers[2].pk, manufacturers[3].pk], 'cf12': [manufacturers[2].pk, manufacturers[3].pk],
}), }),
@ -1531,8 +1531,9 @@ class CustomFieldModelFilterTest(TestCase):
self.assertEqual(self.filterset({'cf_cf9': ['A', 'B']}, self.queryset).qs.count(), 2) self.assertEqual(self.filterset({'cf_cf9': ['A', 'B']}, self.queryset).qs.count(), 2)
def test_filter_multiselect(self): def test_filter_multiselect(self):
self.assertEqual(self.filterset({'cf_cf10': ['A', 'B']}, self.queryset).qs.count(), 2) self.assertEqual(self.filterset({'cf_cf10': ['A']}, self.queryset).qs.count(), 1)
self.assertEqual(self.filterset({'cf_cf10': ['X']}, self.queryset).qs.count(), 3) self.assertEqual(self.filterset({'cf_cf10': ['A', 'C']}, self.queryset).qs.count(), 2)
self.assertEqual(self.filterset({'cf_cf10': ['null']}, self.queryset).qs.count(), 1)
def test_filter_object(self): def test_filter_object(self):
manufacturer_ids = Manufacturer.objects.values_list('id', flat=True) manufacturer_ids = Manufacturer.objects.values_list('id', flat=True)

View File

@ -9,6 +9,7 @@ from drf_spectacular.types import OpenApiTypes
__all__ = ( __all__ = (
'ContentTypeFilter', 'ContentTypeFilter',
'MACAddressFilter', 'MACAddressFilter',
'MultiValueArrayFilter',
'MultiValueCharFilter', 'MultiValueCharFilter',
'MultiValueDateFilter', 'MultiValueDateFilter',
'MultiValueDateTimeFilter', 'MultiValueDateTimeFilter',
@ -85,6 +86,21 @@ class MultiValueTimeFilter(django_filters.MultipleChoiceFilter):
field_class = multivalue_field_factory(forms.TimeField) field_class = multivalue_field_factory(forms.TimeField)
@extend_schema_field(OpenApiTypes.STR)
class MultiValueArrayFilter(django_filters.MultipleChoiceFilter):
field_class = multivalue_field_factory(forms.CharField)
def __init__(self, *args, lookup_expr='contains', **kwargs):
# Set default lookup_expr to 'contains'
super().__init__(*args, lookup_expr=lookup_expr, **kwargs)
def get_filter_predicate(self, v):
# If filtering for null values, ignore lookup_expr
if v is None:
return {self.field_name: None}
return super().get_filter_predicate(v)
class MACAddressFilter(django_filters.CharFilter): class MACAddressFilter(django_filters.CharFilter):
pass pass