Closes #4835: Support passing multiple initial values for multiple choice fields

This commit is contained in:
Jeremy Stretch 2020-07-08 12:50:12 -04:00
parent 20ee8ec107
commit fec3ee6f08
5 changed files with 54 additions and 21 deletions

View File

@ -5,6 +5,7 @@
### Bug Fixes ### Bug Fixes
* [#4821](https://github.com/netbox-community/netbox/issues/4821) - Restrict group options by selected site when bulk editing VLANs * [#4821](https://github.com/netbox-community/netbox/issues/4821) - Restrict group options by selected site when bulk editing VLANs
* [#4835](https://github.com/netbox-community/netbox/issues/4835) - Support passing multiple initial values for multiple choice fields
--- ---

View File

@ -594,21 +594,24 @@ class DynamicModelChoiceMixin:
filter = django_filters.ModelChoiceFilter filter = django_filters.ModelChoiceFilter
widget = APISelect widget = APISelect
def _get_initial_value(self, initial_data, field_name): def filter_queryset(self, data):
return initial_data.get(field_name) field_name = getattr(self, 'to_field_name') or 'pk'
# If multiple values have been provided, use only the last.
if type(data) in (list, tuple):
data = data[-1]
filter = self.filter(
field_name=field_name
)
return filter.filter(self.queryset, data)
def get_bound_field(self, form, field_name): def get_bound_field(self, form, field_name):
bound_field = BoundField(form, self, field_name) bound_field = BoundField(form, self, field_name)
# Override initial() to allow passing multiple values
bound_field.initial = self._get_initial_value(form.initial, field_name)
# Modify the QuerySet of the field before we return it. Limit choices to any data already bound: Options # Modify the QuerySet of the field before we return it. Limit choices to any data already bound: Options
# will be populated on-demand via the APISelect widget. # will be populated on-demand via the APISelect widget.
data = bound_field.value() data = bound_field.value()
if data: if data:
filter = self.filter(field_name=self.to_field_name or 'pk', queryset=self.queryset) self.queryset = self.filter_queryset(data)
self.queryset = filter.filter(self.queryset, data)
else: else:
self.queryset = self.queryset.none() self.queryset = self.queryset.none()
@ -638,11 +641,16 @@ class DynamicModelMultipleChoiceField(DynamicModelChoiceMixin, forms.ModelMultip
filter = django_filters.ModelMultipleChoiceFilter filter = django_filters.ModelMultipleChoiceFilter
widget = APISelectMultiple widget = APISelectMultiple
def _get_initial_value(self, initial_data, field_name): def filter_queryset(self, data):
# If a QueryDict has been passed as initial form data, get *all* listed values field_name = getattr(self, 'to_field_name') or 'pk'
if hasattr(initial_data, 'getlist'): # Normalize data to a list
return initial_data.getlist(field_name) if type(data) not in (list, tuple):
return initial_data.get(field_name) data = [data]
filter = self.filter(
field_name=field_name,
lookup_expr='in'
)
return filter.filter(self.queryset, data)
class LaxURLField(forms.URLField): class LaxURLField(forms.URLField):

View File

@ -1,15 +1,13 @@
from django.http import QueryDict
from django.test import TestCase from django.test import TestCase
from utilities.utils import deepmerge, dict_to_filter_params from utilities.utils import deepmerge, dict_to_filter_params, normalize_querydict
class DictToFilterParamsTest(TestCase): class DictToFilterParamsTest(TestCase):
""" """
Validate the operation of dict_to_filter_params(). Validate the operation of dict_to_filter_params().
""" """
def setUp(self):
return
def test_dict_to_filter_params(self): def test_dict_to_filter_params(self):
input = { input = {
@ -39,13 +37,21 @@ class DictToFilterParamsTest(TestCase):
self.assertNotEqual(dict_to_filter_params(input), output) self.assertNotEqual(dict_to_filter_params(input), output)
class NormalizeQueryDictTest(TestCase):
"""
Validate normalize_querydict() utility function.
"""
def test_normalize_querydict(self):
self.assertDictEqual(
normalize_querydict(QueryDict('foo=1&bar=2&bar=3&baz=')),
{'foo': '1', 'bar': ['2', '3'], 'baz': ''}
)
class DeepMergeTest(TestCase): class DeepMergeTest(TestCase):
""" """
Validate the behavior of the deepmerge() utility. Validate the behavior of the deepmerge() utility.
""" """
def setUp(self):
return
def test_deepmerge(self): def test_deepmerge(self):
dict1 = { dict1 = {

View File

@ -150,6 +150,24 @@ def dict_to_filter_params(d, prefix=''):
return params return params
def normalize_querydict(querydict):
"""
Convert a QueryDict to a normal, mutable dictionary, preserving list values. For example,
QueryDict('foo=1&bar=2&bar=3&baz=')
becomes:
{'foo': '1', 'bar': ['2', '3'], 'baz': ''}
This function is necessary because QueryDict does not provide any built-in mechanism which preserves multiple
values.
"""
return {
k: v if len(v) > 1 else v[0] for k, v in querydict.lists()
}
def deepmerge(original, new): def deepmerge(original, new):
""" """
Deep merge two dictionaries (new into original) and return a new dict Deep merge two dictionaries (new into original) and return a new dict

View File

@ -27,7 +27,7 @@ from extras.models import CustomField, CustomFieldValue, ExportTemplate
from extras.querysets import CustomFieldQueryset from extras.querysets import CustomFieldQueryset
from utilities.exceptions import AbortTransaction from utilities.exceptions import AbortTransaction
from utilities.forms import BootstrapMixin, CSVDataField, TableConfigForm from utilities.forms import BootstrapMixin, CSVDataField, TableConfigForm
from utilities.utils import csv_format, prepare_cloned_fields from utilities.utils import csv_format, normalize_querydict, prepare_cloned_fields
from .error_handlers import handle_protectederror from .error_handlers import handle_protectederror
from .forms import ConfirmationForm, ImportForm from .forms import ConfirmationForm, ImportForm
from .paginator import EnhancedPaginator, get_paginate_count from .paginator import EnhancedPaginator, get_paginate_count
@ -250,7 +250,7 @@ class ObjectEditView(GetReturnURLMixin, View):
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
# Parse initial data manually to avoid setting field values as lists # Parse initial data manually to avoid setting field values as lists
initial_data = {k: request.GET[k] for k in request.GET} initial_data = normalize_querydict(request.GET)
form = self.model_form(instance=self.obj, initial=initial_data) form = self.model_form(instance=self.obj, initial=initial_data)
return render(request, self.template_name, { return render(request, self.template_name, {