From 34a17d457194d5da5101586d0b965b00897e58ac Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Fri, 1 May 2020 12:18:04 -0400 Subject: [PATCH] Enable the specifcation of related objects by arbitrary attribute during CSV import --- netbox/utilities/forms.py | 51 +++++++++++++++++++++++++++------------ netbox/utilities/views.py | 22 ++++++++++++++--- 2 files changed, 53 insertions(+), 20 deletions(-) diff --git a/netbox/utilities/forms.py b/netbox/utilities/forms.py index d95c86527..5bff7ad61 100644 --- a/netbox/utilities/forms.py +++ b/netbox/utilities/forms.py @@ -405,10 +405,11 @@ class CSVDataField(forms.CharField): """ widget = forms.Textarea - def __init__(self, fields, required_fields=[], *args, **kwargs): + def __init__(self, model, fields, required_fields=None, *args, **kwargs): + self.model = model self.fields = fields - self.required_fields = required_fields + self.required_fields = required_fields or list() super().__init__(*args, **kwargs) @@ -423,31 +424,49 @@ class CSVDataField(forms.CharField): 'in double quotes.' def to_python(self, value): - records = [] reader = csv.reader(StringIO(value)) - # Consume and validate the first line of CSV data as column headers - headers = next(reader) - for f in self.required_fields: - if f not in headers: - raise forms.ValidationError('Required column header "{}" not found.'.format(f)) - for f in headers: - if f not in self.fields: - raise forms.ValidationError('Unexpected column header "{}" found.'.format(f)) + # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional + # "to" field specifying how the related object is being referenced. For example, importing a Device might use a + # `site.slug` header, to indicate the related site is being referenced by its slug. + headers = {} + for header in next(reader): + if '.' in header: + field, to_field = header.split('.', 1) + headers[field] = to_field + else: + headers[header] = None # Parse CSV data for i, row in enumerate(reader, start=1): if row: if len(row) != len(headers): - raise forms.ValidationError( - "Row {}: Expected {} columns but found {}".format(i, len(headers), len(row)) - ) + raise forms.ValidationError(f"Row {i}: Expected {len(headers)} columns but found {len(row)}") row = [col.strip() for col in row] - record = dict(zip(headers, row)) + record = dict(zip(headers.keys(), row)) records.append(record) - return records + return headers, records + + def validate(self, value): + headers, records = value + + # Validate provided column headers + for field, to_field in headers.items(): + if field not in self.fields: + raise forms.ValidationError(f'Unexpected column header "{field}" found.') + if to_field and not hasattr(self.fields[field], 'to_field_name'): + raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots') + if to_field and not hasattr(self.fields[field].queryset.model, to_field): + raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}') + + # Validate required fields + for f in self.required_fields: + if f not in headers: + raise forms.ValidationError(f'Required column header "{f}" not found.') + + return value class CSVChoiceField(forms.ChoiceField): diff --git a/netbox/utilities/views.py b/netbox/utilities/views.py index 294acb1d1..b1f74a9c6 100644 --- a/netbox/utilities/views.py +++ b/netbox/utilities/views.py @@ -557,11 +557,18 @@ class BulkImportView(GetReturnURLMixin, View): def _import_form(self, *args, **kwargs): - fields = self.model_form().fields.keys() - required_fields = [name for name, field in self.model_form().fields.items() if field.required] + fields = self.model_form().fields + required_fields = [ + name for name, field in self.model_form().fields.items() if field.required + ] class ImportForm(BootstrapMixin, Form): - csv = CSVDataField(fields=fields, required_fields=required_fields, widget=Textarea(attrs=self.widget_attrs)) + csv = CSVDataField( + model=self.model_form.Meta.model, + fields=fields, + required_fields=required_fields, + widget=Textarea(attrs=self.widget_attrs) + ) return ImportForm(*args, **kwargs) @@ -591,8 +598,15 @@ class BulkImportView(GetReturnURLMixin, View): try: # Iterate through CSV data and bind each row to a new model form instance. with transaction.atomic(): - for row, data in enumerate(form.cleaned_data['csv'], start=1): + headers, records = form.cleaned_data['csv'] + for row, data in enumerate(records, start=1): obj_form = self.model_form(data) + + # Modify the model form to accommodate any customized to_field_name properties + for field, to_field in headers.items(): + if to_field is not None: + obj_form.fields[field].to_field_name = to_field + if obj_form.is_valid(): obj = self._save_obj(obj_form, request) new_objs.append(obj)