diff --git a/netbox/netbox/views/generic.py b/netbox/netbox/views/generic.py index c43519e01..a312441e5 100644 --- a/netbox/netbox/views/generic.py +++ b/netbox/netbox/views/generic.py @@ -20,7 +20,7 @@ from extras.models import CustomField, ExportTemplate from utilities.error_handlers import handle_protectederror from utilities.exceptions import AbortTransaction, PermissionsViolation from utilities.forms import ( - BootstrapMixin, BulkRenameForm, ConfirmationForm, CSVDataField, ImportForm, TableConfigForm, restrict_form_fields, + BootstrapMixin, BulkRenameForm, ConfirmationForm, CSVDataField, ImportForm, TableConfigForm, restrict_form_fields, CSVFileField ) from utilities.permissions import get_permission_for_model from utilities.tables import paginate_table @@ -667,6 +667,14 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): from_form=self.model_form, widget=Textarea(attrs=self.widget_attrs) ) + csv_file = CSVFileField( + label="CSV file", + from_form=self.model_form, + required=False + ) + + def used_both_csv_fields(self): + return self.cleaned_data['csv_file'][1] and self.cleaned_data['csv'][1] return ImportForm(*args, **kwargs) @@ -691,15 +699,21 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): def post(self, request): logger = logging.getLogger('netbox.views.BulkImportView') new_objs = [] - form = self._import_form(request.POST) + form = self._import_form(request.POST, request.FILES) if form.is_valid(): logger.debug("Form validation was successful") try: + if form.used_both_csv_fields(): + form.add_error('csv_file', "Choose one of two import methods") + raise ValidationError("") # Iterate through CSV data and bind each row to a new model form instance. with transaction.atomic(): - headers, records = form.cleaned_data['csv'] + if request.FILES: + headers, records = form.cleaned_data['csv_file'] + else: + headers, records = form.cleaned_data['csv'] for row, data in enumerate(records, start=1): obj_form = self.model_form(data, headers=headers) restrict_form_fields(obj_form, request.user) diff --git a/netbox/templates/generic/object_bulk_import.html b/netbox/templates/generic/object_bulk_import.html index 170cf3665..e9972e919 100644 --- a/netbox/templates/generic/object_bulk_import.html +++ b/netbox/templates/generic/object_bulk_import.html @@ -20,7 +20,7 @@
-
+ {% csrf_token %} {% render_form form %}
diff --git a/netbox/utilities/forms/fields.py b/netbox/utilities/forms/fields.py index 93b9e6c44..2b9b57ea3 100644 --- a/netbox/utilities/forms/fields.py +++ b/netbox/utilities/forms/fields.py @@ -17,7 +17,7 @@ from utilities.utils import content_type_name from utilities.validators import EnhancedURLValidator from . import widgets from .constants import * -from .utils import expand_alphanumeric_pattern, expand_ipaddress_pattern +from .utils import expand_alphanumeric_pattern, expand_ipaddress_pattern, parse_csv, validate_csv __all__ = ( 'CommentField', @@ -26,6 +26,7 @@ __all__ = ( 'CSVChoiceField', 'CSVContentTypeField', 'CSVDataField', + 'CSVFileField', 'CSVModelChoiceField', 'CSVTypedChoiceField', 'DynamicModelChoiceField', @@ -174,49 +175,56 @@ class CSVDataField(forms.CharField): 'in double quotes.' def to_python(self, value): - - records = [] reader = csv.reader(StringIO(value.strip())) - # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional - # "to" field specifying how the related object is being referenced. For example, importing a Device might use a - # `site.slug` header, to indicate the related site is being referenced by its slug. - headers = {} - for header in next(reader): - if '.' in header: - field, to_field = header.split('.', 1) - headers[field] = to_field - else: - headers[header] = None + return parse_csv(reader) - # Parse CSV rows into a list of dictionaries mapped from the column headers. - for i, row in enumerate(reader, start=1): - if len(row) != len(headers): - raise forms.ValidationError( - f"Row {i}: Expected {len(headers)} columns but found {len(row)}" - ) - row = [col.strip() for col in row] - record = dict(zip(headers.keys(), row)) - records.append(record) + def validate(self, value): + headers, records = value + validate_csv(headers, self.fields, self.required_fields) + + return value + + +class CSVFileField(forms.FileField): + """ + A FileField (rendered as a file input button) which accepts a file containing CSV-formatted data. It returns + data as a two-tuple: The first item is a dictionary of column headers, mapping field names to the attribute + by which they match a related object (where applicable). The second item is a list of dictionaries, each + representing a discrete row of CSV data. + + :param from_form: The form from which the field derives its validation rules. + """ + + def __init__(self, from_form, *args, **kwargs): + + form = from_form() + self.model = form.Meta.model + self.fields = form.fields + self.required_fields = [ + name for name, field in form.fields.items() if field.required + ] + + super().__init__(*args, **kwargs) + + def to_python(self, file): + if file: + csv_str = file.read().decode('utf-8') + reader = csv.reader(csv_str.splitlines()) + + headers = {} + records = [] + if file: + headers, records = parse_csv(reader) return headers, records def validate(self, value): headers, records = value + if not headers and not records: + return value - # Validate provided column headers - for field, to_field in headers.items(): - if field not in self.fields: - raise forms.ValidationError(f'Unexpected column header "{field}" found.') - if to_field and not hasattr(self.fields[field], 'to_field_name'): - raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots') - if to_field and not hasattr(self.fields[field].queryset.model, to_field): - raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}') - - # Validate required fields - for f in self.required_fields: - if f not in headers: - raise forms.ValidationError(f'Required column header "{f}" not found.') + validate_csv(headers, self.fields, self.required_fields) return value diff --git a/netbox/utilities/forms/utils.py b/netbox/utilities/forms/utils.py index dc001be1a..50a6a416e 100644 --- a/netbox/utilities/forms/utils.py +++ b/netbox/utilities/forms/utils.py @@ -14,6 +14,8 @@ __all__ = ( 'parse_alphanumeric_range', 'parse_numeric_range', 'restrict_form_fields', + 'parse_csv', + 'validate_csv', ) @@ -134,3 +136,54 @@ def restrict_form_fields(form, user, action='view'): for field in form.fields.values(): if hasattr(field, 'queryset') and issubclass(field.queryset.__class__, RestrictedQuerySet): field.queryset = field.queryset.restrict(user, action) + + +def parse_csv(reader): + """ + Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Raise an error + if the records are formatted incorrectly. Return headers and records as a tuple. + """ + records = [] + headers = {} + + # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional + # "to" field specifying how the related object is being referenced. For example, importing a Device might use a + # `site.slug` header, to indicate the related site is being referenced by its slug. + + for header in next(reader): + if '.' in header: + field, to_field = header.split('.', 1) + headers[field] = to_field + else: + headers[header] = None + + # Parse CSV rows into a list of dictionaries mapped from the column headers. + for i, row in enumerate(reader, start=1): + if len(row) != len(headers): + raise forms.ValidationError( + f"Row {i}: Expected {len(headers)} columns but found {len(row)}" + ) + row = [col.strip() for col in row] + record = dict(zip(headers.keys(), row)) + records.append(record) + return headers, records + + +def validate_csv(headers, fields, required_fields): + """ + Validate that parsed csv data conforms to the object's available fields. Raise validation errors + if parsed csv data contains invalid headers or does not contain required headers. + """ + # Validate provided column headers + for field, to_field in headers.items(): + if field not in fields: + raise forms.ValidationError(f'Unexpected column header "{field}" found.') + if to_field and not hasattr(fields[field], 'to_field_name'): + raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots') + if to_field and not hasattr(fields[field].queryset.model, to_field): + raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}') + + # Validate required fields + for f in required_fields: + if f not in headers: + raise forms.ValidationError(f'Required column header "{f}" not found.')