diff --git a/netbox/dcim/views.py b/netbox/dcim/views.py index 13e8354aa..437162bce 100644 --- a/netbox/dcim/views.py +++ b/netbox/dcim/views.py @@ -807,7 +807,7 @@ class RackReservationImportView(generic.BulkImportView): model_form = forms.RackReservationCSVForm table = tables.RackReservationTable - def _save_obj(self, obj_form, request): + def save_object(self, obj_form, request): """ Assign the currently authenticated user to the RackReservation. """ @@ -1082,7 +1082,7 @@ class DeviceTypeInventoryItemsView(DeviceTypeComponentsView): ) -class DeviceTypeImportView(generic.ObjectImportView): +class DeviceTypeImportView(generic.BulkImportView): additional_permissions = [ 'dcim.add_devicetype', 'dcim.add_consoleporttemplate', @@ -1098,6 +1098,7 @@ class DeviceTypeImportView(generic.ObjectImportView): ] queryset = DeviceType.objects.all() model_form = forms.DeviceTypeImportForm + table = tables.DeviceTypeTable related_object_forms = { 'console-ports': forms.ConsolePortTemplateImportForm, 'console-server-ports': forms.ConsoleServerPortTemplateImportForm, @@ -1267,7 +1268,7 @@ class ModuleTypeRearPortsView(ModuleTypeComponentsView): ) -class ModuleTypeImportView(generic.ObjectImportView): +class ModuleTypeImportView(generic.BulkImportView): additional_permissions = [ 'dcim.add_moduletype', 'dcim.add_consoleporttemplate', @@ -1280,6 +1281,7 @@ class ModuleTypeImportView(generic.ObjectImportView): ] queryset = ModuleType.objects.all() model_form = forms.ModuleTypeImportForm + table = tables.ModuleTypeTable related_object_forms = { 'console-ports': forms.ConsolePortTemplateImportForm, 'console-server-ports': forms.ConsoleServerPortTemplateImportForm, @@ -2026,8 +2028,7 @@ class ChildDeviceBulkImportView(generic.BulkImportView): table = tables.DeviceImportTable template_name = 'dcim/device_import_child.html' - def _save_obj(self, obj_form, request): - + def save_object(self, obj_form, request): obj = obj_form.save() # Save the reverse relation to the parent device bay diff --git a/netbox/extras/tests/test_customfields.py b/netbox/extras/tests/test_customfields.py index 7e7eaeda0..2f3c7932a 100644 --- a/netbox/extras/tests/test_customfields.py +++ b/netbox/extras/tests/test_customfields.py @@ -935,7 +935,7 @@ class CustomFieldImportTest(TestCase): ) csv_data = '\n'.join(','.join(row) for row in data) - response = self.client.post(reverse('dcim:site_import'), {'csv': csv_data}) + response = self.client.post(reverse('dcim:site_import'), {'data': csv_data, 'format': 'csv'}) self.assertEqual(response.status_code, 200) self.assertEqual(Site.objects.count(), 3) diff --git a/netbox/ipam/tests/test_views.py b/netbox/ipam/tests/test_views.py index 25b8af9ae..8bf19ebfa 100644 --- a/netbox/ipam/tests/test_views.py +++ b/netbox/ipam/tests/test_views.py @@ -920,7 +920,11 @@ class L2VPNTerminationTestCase( def setUpTestData(cls): device = create_test_device('Device 1') interface = Interface.objects.create(name='Interface 1', device=device, type='1000baset') - l2vpn = L2VPN.objects.create(name='L2VPN 1', slug='l2vpn-1', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=650001) + l2vpns = ( + L2VPN(name='L2VPN 1', slug='l2vpn-1', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=650001), + L2VPN(name='L2VPN 2', slug='l2vpn-2', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=650002), + ) + L2VPN.objects.bulk_create(l2vpns) vlans = ( VLAN(name='Vlan 1', vid=1001), @@ -933,14 +937,14 @@ class L2VPNTerminationTestCase( VLAN.objects.bulk_create(vlans) terminations = ( - L2VPNTermination(l2vpn=l2vpn, assigned_object=vlans[0]), - L2VPNTermination(l2vpn=l2vpn, assigned_object=vlans[1]), - L2VPNTermination(l2vpn=l2vpn, assigned_object=vlans[2]) + L2VPNTermination(l2vpn=l2vpns[0], assigned_object=vlans[0]), + L2VPNTermination(l2vpn=l2vpns[0], assigned_object=vlans[1]), + L2VPNTermination(l2vpn=l2vpns[0], assigned_object=vlans[2]) ) L2VPNTermination.objects.bulk_create(terminations) cls.form_data = { - 'l2vpn': l2vpn.pk, + 'l2vpn': l2vpns[0].pk, 'device': device.pk, 'interface': interface.pk, } @@ -953,10 +957,10 @@ class L2VPNTerminationTestCase( ) cls.csv_update_data = ( - "id,l2vpn", - f"{terminations[0].pk},L2VPN 2", - f"{terminations[1].pk},L2VPN 2", - f"{terminations[2].pk},L2VPN 2", + f"id,l2vpn", + f"{terminations[0].pk},{l2vpns[0].name}", + f"{terminations[1].pk},{l2vpns[0].name}", + f"{terminations[2].pk},{l2vpns[0].name}", ) cls.bulk_edit_data = {} diff --git a/netbox/netbox/tests/test_import.py b/netbox/netbox/tests/test_import.py index 73f2e0e27..b6f732bfe 100644 --- a/netbox/netbox/tests/test_import.py +++ b/netbox/netbox/tests/test_import.py @@ -3,6 +3,7 @@ from django.test import override_settings from dcim.models import * from users.models import ObjectPermission +from utilities.forms.choices import ImportFormatChoices from utilities.testing import ModelViewTestCase, create_tags @@ -27,7 +28,8 @@ class CSVImportTestCase(ModelViewTestCase): ) data = { - 'csv': self._get_csv_data(csv_data), + 'format': ImportFormatChoices.CSV, + 'data': self._get_csv_data(csv_data), } # Assign model-level permission @@ -67,7 +69,8 @@ class CSVImportTestCase(ModelViewTestCase): ) data = { - 'csv': self._get_csv_data(csv_data), + 'format': ImportFormatChoices.CSV, + 'data': self._get_csv_data(csv_data), } # Assign model-level permission diff --git a/netbox/netbox/views/generic/bulk_views.py b/netbox/netbox/views/generic/bulk_views.py index 5ab9e6da0..1a83c9de2 100644 --- a/netbox/netbox/views/generic/bulk_views.py +++ b/netbox/netbox/views/generic/bulk_views.py @@ -4,23 +4,22 @@ from copy import deepcopy from django.contrib import messages from django.contrib.contenttypes.models import ContentType -from django.core.exceptions import FieldDoesNotExist, ValidationError +from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist, ValidationError from django.db import transaction, IntegrityError from django.db.models import ManyToManyField, ProtectedError from django.db.models.fields.reverse_related import ManyToManyRel -from django.forms import Form, ModelMultipleChoiceField, MultipleHiddenInput +from django.forms import ModelMultipleChoiceField, MultipleHiddenInput from django.http import HttpResponse from django.shortcuts import get_object_or_404, redirect, render from django.utils.safestring import mark_safe from django_tables2.export import TableExport -from extras.models import ExportTemplate, SavedFilter +from extras.models import ExportTemplate from extras.signals import clear_webhooks from utilities.error_handlers import handle_protectederror -from utilities.exceptions import AbortRequest, PermissionsViolation -from utilities.forms import ( - BootstrapMixin, BulkRenameForm, ConfirmationForm, CSVDataField, CSVFileField, restrict_form_fields, -) +from utilities.exceptions import AbortRequest, AbortTransaction, PermissionsViolation +from utilities.forms import BulkRenameForm, ConfirmationForm, ImportForm, restrict_form_fields +from utilities.forms.choices import ImportFormatChoices from utilities.htmx import is_htmx from utilities.permissions import get_permission_for_model from utilities.views import GetReturnURLMixin @@ -295,109 +294,136 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView): """ template_name = 'generic/bulk_import.html' model_form = None + related_object_forms = dict() - def _import_form(self, *args, **kwargs): + def get_required_permission(self): + return get_permission_for_model(self.queryset.model, 'add') - class ImportForm(BootstrapMixin, Form): - csv = CSVDataField( - from_form=self.model_form - ) - csv_file = CSVFileField( - label="CSV file", - from_form=self.model_form, - required=False - ) + def prep_related_object_data(self, parent, data): + """ + Hook to modify the data for related objects before it's passed to the related object form (for example, to + assign a parent object). + """ + return data - def clean(self): - csv_rows = self.cleaned_data['csv'][1] if 'csv' in self.cleaned_data else None - csv_file = self.files.get('csv_file') + def _save_object(self, model_form, request): - # Check that the user has not submitted both text data and a file - if csv_rows and csv_file: - raise ValidationError( - "Cannot process CSV text and file attachment simultaneously. Please choose only one import " - "method." - ) + # Save the primary object + obj = self.save_object(model_form, request) - return ImportForm(*args, **kwargs) + # Enforce object-level permissions + if not self.queryset.filter(pk=obj.pk).first(): + raise PermissionsViolation() - def _get_records(self, form, request): - if request.FILES: - headers, records = form.cleaned_data['csv_file'] - else: - headers, records = form.cleaned_data['csv'] + # Iterate through the related object forms (if any), validating and saving each instance. + for field_name, related_object_form in self.related_object_forms.items(): - return headers, records + related_obj_pks = [] + for i, rel_obj_data in enumerate(model_form.data.get(field_name, list())): + rel_obj_data = self.prep_related_object_data(obj, rel_obj_data) + f = related_object_form(rel_obj_data) - def _update_objects(self, form, request, headers, records): - updated_objs = [] + for subfield_name, field in f.fields.items(): + if subfield_name not in rel_obj_data and hasattr(field, 'initial'): + f.data[subfield_name] = field.initial - ids = [int(record["id"]) for record in records] - qs = self.queryset.model.objects.filter(id__in=ids) - objs = {} - for obj in qs: - objs[obj.id] = obj + if f.is_valid(): + related_obj = f.save() + related_obj_pks.append(related_obj.pk) + else: + # Replicate errors on the related object form to the primary form for display + for subfield_name, errors in f.errors.items(): + for err in errors: + err_msg = "{}[{}] {}: {}".format(field_name, i, subfield_name, err) + model_form.add_error(None, err_msg) + raise AbortTransaction() - for row, data in enumerate(records, start=1): - if int(data["id"]) not in objs: - form.add_error('csv', f'Row {row} id: {data["id"]} Does not exist') - raise ValidationError("") + # Enforce object-level permissions on related objects + model = related_object_form.Meta.model + if model.objects.filter(pk__in=related_obj_pks).count() != len(related_obj_pks): + raise ObjectDoesNotExist - obj = objs[int(data["id"])] - obj_form = self.model_form(data, headers=headers, instance=obj) + return obj - # The form should only contain fields that are in the CSV - for name, field in list(obj_form.fields.items()): - if name not in headers: - del obj_form.fields[name] - - restrict_form_fields(obj_form, request.user) - - if obj_form.is_valid(): - obj = self._save_obj(obj_form, request) - updated_objs.append(obj) - else: - for field, err in obj_form.errors.items(): - form.add_error('csv', f'Row {row} {field}: {err[0]}') - raise ValidationError("") - - return updated_objs - - def _create_objects(self, form, request, headers, records): - new_objs = [] - - for row, data in enumerate(records, start=1): - obj_form = self.model_form(data, headers=headers) - restrict_form_fields(obj_form, request.user) - - if obj_form.is_valid(): - obj = self._save_obj(obj_form, request) - new_objs.append(obj) - else: - for field, err in obj_form.errors.items(): - form.add_error('csv', f'Row {row} {field}: {err[0]}') - raise ValidationError("") - - return new_objs - - def _save_obj(self, obj_form, request): + def save_object(self, obj_form, request): """ Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data). """ return obj_form.save() - def get_required_permission(self): - return get_permission_for_model(self.queryset.model, 'add') + def create_and_update_objects(self, form, request): + saved_objects = [] + + records = list(form.cleaned_data['data']) + + # Prefetch objects to be updated, if any + prefetch_ids = [int(record['id']) for record in records if record.get('id')] + prefetched_objects = { + obj.pk: obj + for obj in self.queryset.model.objects.filter(id__in=prefetch_ids) + } if prefetch_ids else {} + + for i, record in enumerate(records, start=1): + instance = None + object_id = int(record.pop('id')) if record.get('id') else None + + # Determine whether this object is being created or updated + if object_id: + try: + instance = prefetched_objects[object_id] + except KeyError: + form.add_error('data', f"Row {i}: Object with ID {object_id} does not exist") + raise ValidationError('') + + if form.cleaned_data['format'] == ImportFormatChoices.CSV: + model_form = self.model_form(record, instance=instance, headers=form._csv_headers) + else: + model_form = self.model_form(record, instance=instance) + # Assign default values for any fields which were not specified. + # We have to do this manually because passing 'initial=' to the form + # on initialization merely sets default values for the widgets. + # Since widgets are not used for YAML/JSON import, we first bind the + # imported data normally, then update the form's data with the applicable + # field defaults as needed prior to form validation. + for field_name, field in model_form.fields.items(): + if field_name not in record and hasattr(field, 'initial'): + model_form.data[field_name] = field.initial + + # When updating, omit all form fields other than those specified in the record. (No + # fields are required when modifying an existing object.) + if object_id: + unused_fields = [f for f in model_form.fields if f not in record] + for field_name in unused_fields: + del model_form.fields[field_name] + + restrict_form_fields(model_form, request.user) + + if model_form.is_valid(): + obj = self._save_object(model_form, request) + saved_objects.append(obj) + else: + # Replicate model form errors for display + for field, errors in model_form.errors.items(): + for err in errors: + if field == '__all__': + form.add_error(None, f'Record {i}: {err}') + else: + form.add_error(None, f'Record {i} {field}: {err}') + + raise ValidationError("") + + return saved_objects # # Request handlers # def get(self, request): + form = ImportForm() return render(request, self.template_name, { 'model': self.model_form._meta.model, - 'form': self._import_form(), + 'form': form, 'fields': self.model_form().fields, 'return_url': self.get_return_url(request), **self.get_extra_context(request), @@ -405,19 +431,16 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView): def post(self, request): logger = logging.getLogger('netbox.views.BulkImportView') - form = self._import_form(request.POST, request.FILES) + + form = ImportForm(request.POST, request.FILES) if form.is_valid(): - logger.debug("Form validation was successful") + logger.debug("Import form validation was successful") try: - # Iterate through CSV data and bind each row to a new model form instance. + # Iterate through data and bind each record to a new model form instance. with transaction.atomic(): - headers, records = self._get_records(form, request) - if "id" in headers: - new_objs = self._update_objects(form, request, headers, records) - else: - new_objs = self._create_objects(form, request, headers, records) + new_objs = self.create_and_update_objects(form, request) # Enforce object-level permissions if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs): diff --git a/netbox/netbox/views/generic/object_views.py b/netbox/netbox/views/generic/object_views.py index 0d122a41a..738d70786 100644 --- a/netbox/netbox/views/generic/object_views.py +++ b/netbox/netbox/views/generic/object_views.py @@ -2,7 +2,6 @@ import logging from copy import deepcopy from django.contrib import messages -from django.core.exceptions import ObjectDoesNotExist from django.db import transaction from django.db.models import ProtectedError from django.shortcuts import redirect, render @@ -12,8 +11,8 @@ from django.utils.safestring import mark_safe from extras.signals import clear_webhooks from utilities.error_handlers import handle_protectederror -from utilities.exceptions import AbortRequest, AbortTransaction, PermissionsViolation -from utilities.forms import ConfirmationForm, ImportForm, restrict_form_fields +from utilities.exceptions import AbortRequest, PermissionsViolation +from utilities.forms import ConfirmationForm, restrict_form_fields from utilities.htmx import is_htmx from utilities.permissions import get_permission_for_model from utilities.utils import get_viewname, normalize_querydict, prepare_cloned_fields @@ -27,7 +26,6 @@ __all__ = ( 'ObjectChildrenView', 'ObjectDeleteView', 'ObjectEditView', - 'ObjectImportView', 'ObjectView', ) @@ -151,146 +149,6 @@ class ObjectChildrenView(ObjectView, ActionsMixin, TableMixin): }) -class ObjectImportView(GetReturnURLMixin, BaseObjectView): - """ - Import a single object (YAML or JSON format). - - Attributes: - model_form: The ModelForm used to create individual objects - related_object_forms: A dictionary mapping of forms to be used for the creation of related (child) objects - """ - template_name = 'generic/object_import.html' - model_form = None - related_object_forms = dict() - - def get_required_permission(self): - return get_permission_for_model(self.queryset.model, 'add') - - def prep_related_object_data(self, parent, data): - """ - Hook to modify the data for related objects before it's passed to the related object form (for example, to - assign a parent object). - """ - return data - - def _create_object(self, model_form): - - # Save the primary object - obj = model_form.save() - - # Enforce object-level permissions - if not self.queryset.filter(pk=obj.pk).exists(): - raise PermissionsViolation() - - # Iterate through the related object forms (if any), validating and saving each instance. - for field_name, related_object_form in self.related_object_forms.items(): - - related_obj_pks = [] - for i, rel_obj_data in enumerate(model_form.data.get(field_name, list())): - rel_obj_data = self.prep_related_object_data(obj, rel_obj_data) - f = related_object_form(rel_obj_data) - - for subfield_name, field in f.fields.items(): - if subfield_name not in rel_obj_data and hasattr(field, 'initial'): - f.data[subfield_name] = field.initial - - if f.is_valid(): - related_obj = f.save() - related_obj_pks.append(related_obj.pk) - else: - # Replicate errors on the related object form to the primary form for display - for subfield_name, errors in f.errors.items(): - for err in errors: - err_msg = "{}[{}] {}: {}".format(field_name, i, subfield_name, err) - model_form.add_error(None, err_msg) - raise AbortTransaction() - - # Enforce object-level permissions on related objects - model = related_object_form.Meta.model - if model.objects.filter(pk__in=related_obj_pks).count() != len(related_obj_pks): - raise ObjectDoesNotExist - - return obj - - # - # Request handlers - # - - def get(self, request): - form = ImportForm() - - return render(request, self.template_name, { - 'form': form, - 'obj_type': self.queryset.model._meta.verbose_name, - 'return_url': self.get_return_url(request), - }) - - def post(self, request): - logger = logging.getLogger('netbox.views.ObjectImportView') - form = ImportForm(request.POST) - - if form.is_valid(): - logger.debug("Import form validation was successful") - - # Initialize model form - data = form.cleaned_data['data'] - model_form = self.model_form(data) - restrict_form_fields(model_form, request.user) - - # Assign default values for any fields which were not specified. We have to do this manually because passing - # 'initial=' to the form on initialization merely sets default values for the widgets. Since widgets are not - # used for YAML/JSON import, we first bind the imported data normally, then update the form's data with the - # applicable field defaults as needed prior to form validation. - for field_name, field in model_form.fields.items(): - if field_name not in data and hasattr(field, 'initial'): - model_form.data[field_name] = field.initial - - if model_form.is_valid(): - - try: - with transaction.atomic(): - obj = self._create_object(model_form) - - except AbortTransaction: - clear_webhooks.send(sender=self) - - except (AbortRequest, PermissionsViolation) as e: - logger.debug(e.message) - form.add_error(None, e.message) - clear_webhooks.send(sender=self) - - if not model_form.errors: - logger.info(f"Import object {obj} (PK: {obj.pk})") - msg = f'Imported object: {obj}' - messages.success(request, mark_safe(msg)) - - if '_addanother' in request.POST: - return redirect(request.get_full_path()) - - self.get_return_url(request, obj) - return redirect(self.get_return_url(request, obj)) - - else: - logger.debug("Model form validation failed") - - # Replicate model form errors for display - for field, errors in model_form.errors.items(): - for err in errors: - if field == '__all__': - form.add_error(None, err) - else: - form.add_error(None, "{}: {}".format(field, err)) - - else: - logger.debug("Import form validation failed") - - return render(request, self.template_name, { - 'form': form, - 'obj_type': self.queryset.model._meta.verbose_name, - 'return_url': self.get_return_url(request), - }) - - class ObjectEditView(GetReturnURLMixin, BaseObjectView): """ Create or edit a single object. diff --git a/netbox/templates/generic/bulk_import.html b/netbox/templates/generic/bulk_import.html index 1d638cb2c..4ddfb884c 100644 --- a/netbox/templates/generic/bulk_import.html +++ b/netbox/templates/generic/bulk_import.html @@ -15,142 +15,160 @@ Context: {% block tabs %} {% endblock tabs %} {% block content-wrapper %}
- {% block content %} -
-
- -
- {% csrf_token %} -
-
- {% render_field form.csv %} + + {# Data Import Form #} +
+ {% block content %} +
+
+ + {% csrf_token %} + {% render_field form.data %} + {% render_field form.format %} +
+
+ + {% if return_url %} + Cancel + {% endif %} +
-
- {% render_field form.csv_file %} + +
+
+ {% endblock content %} +
+ + {# File Upload Form #} +
+
+
+ {% csrf_token %} + {% render_field form.data_file %} + {% render_field form.format %} +
+
+ + {% if return_url %} + Cancel + {% endif %} +
-
-
-
- - {% if return_url %} - Cancel - {% endif %} -
-
- - {% if fields %} -
-
-
-
- CSV Field Options -
-
- - - - - - - - {% for name, field in fields.items %} - - - - - + + {% endfor %} +
FieldRequiredAccessorDescription
- {{ name }} - - {% if field.required %} - {% checkmark True true="Required" %} - {% else %} - {{ ''|placeholder }} - {% endif %} - - {% if field.to_field_name %} - {{ field.to_field_name }} - {% else %} - {{ ''|placeholder }} - {% endif %} - - {% if field.STATIC_CHOICES %} - - + + + {% if fields %} +
+
+
+
+ Field Options +
+
+ + + + + + + + {% for name, field in fields.items %} + + + + + - - {% endfor %} -
FieldRequiredAccessorDescription
+ {% if field.required %}{% endif %}{{ name }}{% if field.required %}{% endif %} + + {% if field.required %} + {% checkmark True true="Required" %} + {% else %} + {{ ''|placeholder }} + {% endif %} + + {% if field.to_field_name %} + {{ field.to_field_name }} + {% else %} + {{ ''|placeholder }} + {% endif %} + + {% if field.STATIC_CHOICES %} + +
-
-
-
+ +
+ + + {% endif %} + {% if field.help_text %} + {{ field.help_text }}
+ {% elif field.label %} + {{ field.label }}
+ {% endif %} + {% if field|widget_type == 'dateinput' %} + Format: YYYY-MM-DD + {% elif field|widget_type == 'checkboxinput' %} + Specify "true" or "false" + {% endif %} +
-

- Required fields must be specified for all - objects. -

-

- Related objects may be referenced by any unique attribute. - For example, vrf.rd would identify a VRF by its route distinguisher. -

- {% endif %}
- {% endblock content %} +
+

+ Required fields must be specified for all + objects. +

+

+ Related objects may be referenced by any unique attribute. + For example, vrf.rd would identify a VRF by its route distinguisher. +

+ {% endif %} +
{% endblock content-wrapper %} diff --git a/netbox/utilities/forms/choices.py b/netbox/utilities/forms/choices.py new file mode 100644 index 000000000..bf0ea5f94 --- /dev/null +++ b/netbox/utilities/forms/choices.py @@ -0,0 +1,17 @@ +from utilities.choices import ChoiceSet + + +# +# Import Choices +# + +class ImportFormatChoices(ChoiceSet): + CSV = 'csv' + JSON = 'json' + YAML = 'yaml' + + CHOICES = [ + (CSV, 'CSV'), + (JSON, 'JSON'), + (YAML, 'YAML'), + ] diff --git a/netbox/utilities/forms/forms.py b/netbox/utilities/forms/forms.py index 8ad6f103b..0569853b8 100644 --- a/netbox/utilities/forms/forms.py +++ b/netbox/utilities/forms/forms.py @@ -1,12 +1,15 @@ +import csv import json import re +from io import StringIO import yaml from django import forms +from utilities.forms.utils import parse_csv +from .choices import ImportFormatChoices from .widgets import APISelect, APISelectMultiple, ClearableFileInput, StaticSelect - __all__ = ( 'BootstrapMixin', 'BulkEditForm', @@ -120,64 +123,94 @@ class CSVModelForm(forms.ModelForm): """ ModelForm used for the import of objects in CSV format. """ - - def __init__(self, *args, headers=None, **kwargs): + def __init__(self, *args, headers=None, fields=None, **kwargs): + headers = headers or {} + fields = fields or [] super().__init__(*args, **kwargs) # Modify the model form to accommodate any customized to_field_name properties - if headers: - for field, to_field in headers.items(): - if to_field is not None: - self.fields[field].to_field_name = to_field + for field, to_field in headers.items(): + if to_field is not None: + self.fields[field].to_field_name = to_field + + # Omit any fields not specified (e.g. because the form is being used to + # updated rather than create objects) + if fields: + for field in list(self.fields.keys()): + if field not in fields: + del self.fields[field] class ImportForm(BootstrapMixin, forms.Form): - """ - Generic form for creating an object from JSON/YAML data - """ data = forms.CharField( + required=False, widget=forms.Textarea(attrs={'class': 'font-monospace'}), - help_text="Enter object data in JSON or YAML format. Note: Only a single object/document is supported." + help_text="Enter object data in CSV, JSON or YAML format." ) + data_file = forms.FileField( + label="Data file", + required=False + ) + # TODO: Enable auto-detection of format format = forms.ChoiceField( - choices=( - ('json', 'JSON'), - ('yaml', 'YAML') - ), - initial='yaml' + choices=ImportFormatChoices, + initial=ImportFormatChoices.CSV, + widget=StaticSelect() ) + data_field = 'data' + def clean(self): super().clean() - - data = self.cleaned_data['data'] format = self.cleaned_data['format'] - # Process JSON/YAML data - if format == 'json': - try: - self.cleaned_data['data'] = json.loads(data) - # Check for multiple JSON objects - if type(self.cleaned_data['data']) is not dict: - raise forms.ValidationError({ - 'data': "Import is limited to one object at a time." - }) - except json.decoder.JSONDecodeError as err: - raise forms.ValidationError({ - 'data': "Invalid JSON data: {}".format(err) - }) + # Determine whether we're reading from form data or an uploaded file + if self.cleaned_data['data'] and self.cleaned_data['data_file']: + raise forms.ValidationError("Form data must be empty when uploading a file.") + if 'data_file' in self.files: + self.data_field = 'data_file' + file = self.files.get('data_file') + data = file.read().decode('utf-8') else: - # Check for multiple YAML documents - if '\n---' in data: - raise forms.ValidationError({ - 'data': "Import is limited to one object at a time." - }) - try: - self.cleaned_data['data'] = yaml.load(data, Loader=yaml.SafeLoader) - except yaml.error.YAMLError as err: - raise forms.ValidationError({ - 'data': "Invalid YAML data: {}".format(err) - }) + data = self.cleaned_data['data'] + + # Process data according to the selected format + if format == ImportFormatChoices.CSV: + self.cleaned_data['data'] = self._clean_csv(data) + elif format == ImportFormatChoices.JSON: + self.cleaned_data['data'] = self._clean_json(data) + elif format == ImportFormatChoices.YAML: + self.cleaned_data['data'] = self._clean_yaml(data) + + def _clean_csv(self, data): + stream = StringIO(data.strip()) + reader = csv.reader(stream) + headers, records = parse_csv(reader) + + # Set CSV headers for reference by the model form + self._csv_headers = headers + + return records + + def _clean_json(self, data): + try: + data = json.loads(data) + # Accommodate for users entering single objects + if type(data) is not list: + data = [data] + return data + except json.decoder.JSONDecodeError as err: + raise forms.ValidationError({ + self.data_field: f"Invalid JSON data: {err}" + }) + + def _clean_yaml(self, data): + try: + return yaml.load_all(data, Loader=yaml.SafeLoader) + except yaml.error.YAMLError as err: + raise forms.ValidationError({ + self.data_field: f"Invalid YAML data: {err}" + }) class FilterForm(BootstrapMixin, forms.Form): diff --git a/netbox/utilities/testing/views.py b/netbox/utilities/testing/views.py index f51893f74..5e1e207cc 100644 --- a/netbox/utilities/testing/views.py +++ b/netbox/utilities/testing/views.py @@ -9,6 +9,7 @@ from django.urls import reverse from extras.choices import ObjectChangeActionChoices from extras.models import ObjectChange from users.models import ObjectPermission +from utilities.forms.choices import ImportFormatChoices from .base import ModelTestCase from .utils import disable_warnings, post_data @@ -555,7 +556,8 @@ class ViewTestCases: def test_bulk_import_objects_without_permission(self): data = { - 'csv': self._get_csv_data(), + 'data': self._get_csv_data(), + 'format': 'csv', } # Test GET without permission @@ -571,7 +573,8 @@ class ViewTestCases: def test_bulk_import_objects_with_permission(self): initial_count = self._get_queryset().count() data = { - 'csv': self._get_csv_data(), + 'data': self._get_csv_data(), + 'format': 'csv', } # Assign model-level permission @@ -598,7 +601,8 @@ class ViewTestCases: initial_count = self._get_queryset().count() array, csv_data = self._get_update_csv_data() data = { - 'csv': csv_data, + 'format': ImportFormatChoices.CSV, + 'data': csv_data, } # Assign model-level permission @@ -630,7 +634,8 @@ class ViewTestCases: def test_bulk_import_objects_with_constrained_permission(self): initial_count = self._get_queryset().count() data = { - 'csv': self._get_csv_data(), + 'data': self._get_csv_data(), + 'format': 'csv', } # Assign constrained permission