diff --git a/netbox/utilities/forms.py b/netbox/utilities/forms.py index 5bff7ad61..61ab28ec8 100644 --- a/netbox/utilities/forms.py +++ b/netbox/utilities/forms.py @@ -405,11 +405,14 @@ class CSVDataField(forms.CharField): """ widget = forms.Textarea - def __init__(self, model, fields, required_fields=None, *args, **kwargs): + def __init__(self, from_form, *args, **kwargs): - self.model = model - self.fields = fields - self.required_fields = required_fields or list() + form = from_form() + self.model = form.Meta.model + self.fields = form.fields + self.required_fields = [ + name for name, field in form.fields.items() if field.required + ] super().__init__(*args, **kwargs) @@ -417,15 +420,16 @@ class CSVDataField(forms.CharField): if not self.label: self.label = '' if not self.initial: - self.initial = ','.join(required_fields) + '\n' + self.initial = ','.join(self.required_fields) + '\n' if not self.help_text: self.help_text = 'Enter the list of column headers followed by one line per record to be imported, using ' \ 'commas to separate values. Multi-line data and values containing commas may be wrapped ' \ 'in double quotes.' def to_python(self, value): + records = [] - reader = csv.reader(StringIO(value)) + reader = csv.reader(StringIO(value.strip())) # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional # "to" field specifying how the related object is being referenced. For example, importing a Device might use a @@ -440,12 +444,11 @@ class CSVDataField(forms.CharField): # Parse CSV data for i, row in enumerate(reader, start=1): - if row: - if len(row) != len(headers): - raise forms.ValidationError(f"Row {i}: Expected {len(headers)} columns but found {len(row)}") - row = [col.strip() for col in row] - record = dict(zip(headers.keys(), row)) - records.append(record) + if len(row) != len(headers): + raise forms.ValidationError(f"Row {i}: Expected {len(headers)} columns but found {len(row)}") + row = [col.strip() for col in row] + record = dict(zip(headers.keys(), row)) + records.append(record) return headers, records diff --git a/netbox/utilities/tests/test_forms.py b/netbox/utilities/tests/test_forms.py index 2d7235505..d6af27b93 100644 --- a/netbox/utilities/tests/test_forms.py +++ b/netbox/utilities/tests/test_forms.py @@ -1,6 +1,8 @@ from django import forms from django.test import TestCase +from ipam.forms import IPAddressCSVForm +from ipam.models import VRF from utilities.forms import * @@ -281,3 +283,85 @@ class ExpandAlphanumeric(TestCase): with self.assertRaises(ValueError): sorted(expand_alphanumeric_pattern('r[a,,b]a')) + + +class CSVDataFieldTest(TestCase): + + def setUp(self): + self.field = CSVDataField(from_form=IPAddressCSVForm) + + def test_clean(self): + input = """ + address,status,vrf + 192.0.2.1/32,Active,Test VRF + """ + output = ( + {'address': None, 'status': None, 'vrf': None}, + [{'address': '192.0.2.1/32', 'status': 'Active', 'vrf': 'Test VRF'}] + ) + self.assertEqual(self.field.clean(input), output) + + def test_clean_invalid_header(self): + input = """ + address,status,vrf,xxx + 192.0.2.1/32,Active,Test VRF,123 + """ + with self.assertRaises(forms.ValidationError): + self.field.clean(input) + + def test_clean_missing_required_header(self): + input = """ + status,vrf + Active,Test VRF + """ + with self.assertRaises(forms.ValidationError): + self.field.clean(input) + + def test_clean_default_to_field(self): + input = """ + address,status,vrf.name + 192.0.2.1/32,Active,Test VRF + """ + output = ( + {'address': None, 'status': None, 'vrf': 'name'}, + [{'address': '192.0.2.1/32', 'status': 'Active', 'vrf': 'Test VRF'}] + ) + self.assertEqual(self.field.clean(input), output) + + def test_clean_pk_to_field(self): + input = """ + address,status,vrf.pk + 192.0.2.1/32,Active,123 + """ + output = ( + {'address': None, 'status': None, 'vrf': 'pk'}, + [{'address': '192.0.2.1/32', 'status': 'Active', 'vrf': '123'}] + ) + self.assertEqual(self.field.clean(input), output) + + def test_clean_custom_to_field(self): + input = """ + address,status,vrf.rd + 192.0.2.1/32,Active,123:456 + """ + output = ( + {'address': None, 'status': None, 'vrf': 'rd'}, + [{'address': '192.0.2.1/32', 'status': 'Active', 'vrf': '123:456'}] + ) + self.assertEqual(self.field.clean(input), output) + + def test_clean_invalid_to_field(self): + input = """ + address,status,vrf.xxx + 192.0.2.1/32,Active,123:456 + """ + with self.assertRaises(forms.ValidationError): + self.field.clean(input) + + def test_clean_to_field_on_non_object(self): + input = """ + address,status.foo,vrf + 192.0.2.1/32,Bar,Test VRF + """ + with self.assertRaises(forms.ValidationError): + self.field.clean(input) diff --git a/netbox/utilities/views.py b/netbox/utilities/views.py index b1f74a9c6..964d9490c 100644 --- a/netbox/utilities/views.py +++ b/netbox/utilities/views.py @@ -557,16 +557,9 @@ class BulkImportView(GetReturnURLMixin, View): def _import_form(self, *args, **kwargs): - fields = self.model_form().fields - required_fields = [ - name for name, field in self.model_form().fields.items() if field.required - ] - class ImportForm(BootstrapMixin, Form): csv = CSVDataField( - model=self.model_form.Meta.model, - fields=fields, - required_fields=required_fields, + from_form=self.model_form, widget=Textarea(attrs=self.widget_attrs) )