From fec3ee6f08b00231ae0065702a94abce09618faa Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Wed, 8 Jul 2020 12:50:12 -0400 Subject: [PATCH] Closes #4835: Support passing multiple initial values for multiple choice fields --- docs/release-notes/version-2.8.md | 1 + netbox/utilities/forms.py | 32 +++++++++++++++++----------- netbox/utilities/tests/test_utils.py | 20 +++++++++++------ netbox/utilities/utils.py | 18 ++++++++++++++++ netbox/utilities/views.py | 4 ++-- 5 files changed, 54 insertions(+), 21 deletions(-) diff --git a/docs/release-notes/version-2.8.md b/docs/release-notes/version-2.8.md index f2dc5374c..7b6d545ee 100644 --- a/docs/release-notes/version-2.8.md +++ b/docs/release-notes/version-2.8.md @@ -5,6 +5,7 @@ ### Bug Fixes * [#4821](https://github.com/netbox-community/netbox/issues/4821) - Restrict group options by selected site when bulk editing VLANs +* [#4835](https://github.com/netbox-community/netbox/issues/4835) - Support passing multiple initial values for multiple choice fields --- diff --git a/netbox/utilities/forms.py b/netbox/utilities/forms.py index 0cc928e83..9ed0cca5c 100644 --- a/netbox/utilities/forms.py +++ b/netbox/utilities/forms.py @@ -594,21 +594,24 @@ class DynamicModelChoiceMixin: filter = django_filters.ModelChoiceFilter widget = APISelect - def _get_initial_value(self, initial_data, field_name): - return initial_data.get(field_name) + def filter_queryset(self, data): + field_name = getattr(self, 'to_field_name') or 'pk' + # If multiple values have been provided, use only the last. + if type(data) in (list, tuple): + data = data[-1] + filter = self.filter( + field_name=field_name + ) + return filter.filter(self.queryset, data) def get_bound_field(self, form, field_name): bound_field = BoundField(form, self, field_name) - # Override initial() to allow passing multiple values - bound_field.initial = self._get_initial_value(form.initial, field_name) - # Modify the QuerySet of the field before we return it. Limit choices to any data already bound: Options # will be populated on-demand via the APISelect widget. data = bound_field.value() if data: - filter = self.filter(field_name=self.to_field_name or 'pk', queryset=self.queryset) - self.queryset = filter.filter(self.queryset, data) + self.queryset = self.filter_queryset(data) else: self.queryset = self.queryset.none() @@ -638,11 +641,16 @@ class DynamicModelMultipleChoiceField(DynamicModelChoiceMixin, forms.ModelMultip filter = django_filters.ModelMultipleChoiceFilter widget = APISelectMultiple - def _get_initial_value(self, initial_data, field_name): - # If a QueryDict has been passed as initial form data, get *all* listed values - if hasattr(initial_data, 'getlist'): - return initial_data.getlist(field_name) - return initial_data.get(field_name) + def filter_queryset(self, data): + field_name = getattr(self, 'to_field_name') or 'pk' + # Normalize data to a list + if type(data) not in (list, tuple): + data = [data] + filter = self.filter( + field_name=field_name, + lookup_expr='in' + ) + return filter.filter(self.queryset, data) class LaxURLField(forms.URLField): diff --git a/netbox/utilities/tests/test_utils.py b/netbox/utilities/tests/test_utils.py index 5d9a98ad5..0a0c3ad2c 100644 --- a/netbox/utilities/tests/test_utils.py +++ b/netbox/utilities/tests/test_utils.py @@ -1,15 +1,13 @@ +from django.http import QueryDict from django.test import TestCase -from utilities.utils import deepmerge, dict_to_filter_params +from utilities.utils import deepmerge, dict_to_filter_params, normalize_querydict class DictToFilterParamsTest(TestCase): """ Validate the operation of dict_to_filter_params(). """ - def setUp(self): - return - def test_dict_to_filter_params(self): input = { @@ -39,13 +37,21 @@ class DictToFilterParamsTest(TestCase): self.assertNotEqual(dict_to_filter_params(input), output) +class NormalizeQueryDictTest(TestCase): + """ + Validate normalize_querydict() utility function. + """ + def test_normalize_querydict(self): + self.assertDictEqual( + normalize_querydict(QueryDict('foo=1&bar=2&bar=3&baz=')), + {'foo': '1', 'bar': ['2', '3'], 'baz': ''} + ) + + class DeepMergeTest(TestCase): """ Validate the behavior of the deepmerge() utility. """ - def setUp(self): - return - def test_deepmerge(self): dict1 = { diff --git a/netbox/utilities/utils.py b/netbox/utilities/utils.py index 4c07f5520..cb44a93b1 100644 --- a/netbox/utilities/utils.py +++ b/netbox/utilities/utils.py @@ -150,6 +150,24 @@ def dict_to_filter_params(d, prefix=''): return params +def normalize_querydict(querydict): + """ + Convert a QueryDict to a normal, mutable dictionary, preserving list values. For example, + + QueryDict('foo=1&bar=2&bar=3&baz=') + + becomes: + + {'foo': '1', 'bar': ['2', '3'], 'baz': ''} + + This function is necessary because QueryDict does not provide any built-in mechanism which preserves multiple + values. + """ + return { + k: v if len(v) > 1 else v[0] for k, v in querydict.lists() + } + + def deepmerge(original, new): """ Deep merge two dictionaries (new into original) and return a new dict diff --git a/netbox/utilities/views.py b/netbox/utilities/views.py index a4ed54b03..38fb6d963 100644 --- a/netbox/utilities/views.py +++ b/netbox/utilities/views.py @@ -27,7 +27,7 @@ from extras.models import CustomField, CustomFieldValue, ExportTemplate from extras.querysets import CustomFieldQueryset from utilities.exceptions import AbortTransaction from utilities.forms import BootstrapMixin, CSVDataField, TableConfigForm -from utilities.utils import csv_format, prepare_cloned_fields +from utilities.utils import csv_format, normalize_querydict, prepare_cloned_fields from .error_handlers import handle_protectederror from .forms import ConfirmationForm, ImportForm from .paginator import EnhancedPaginator, get_paginate_count @@ -250,7 +250,7 @@ class ObjectEditView(GetReturnURLMixin, View): def get(self, request, *args, **kwargs): # Parse initial data manually to avoid setting field values as lists - initial_data = {k: request.GET[k] for k in request.GET} + initial_data = normalize_querydict(request.GET) form = self.model_form(instance=self.obj, initial=initial_data) return render(request, self.template_name, {