diff --git a/netbox/dcim/views.py b/netbox/dcim/views.py index 13e8354aa..437162bce 100644 --- a/netbox/dcim/views.py +++ b/netbox/dcim/views.py @@ -807,7 +807,7 @@ class RackReservationImportView(generic.BulkImportView): model_form = forms.RackReservationCSVForm table = tables.RackReservationTable - def _save_obj(self, obj_form, request): + def save_object(self, obj_form, request): """ Assign the currently authenticated user to the RackReservation. """ @@ -1082,7 +1082,7 @@ class DeviceTypeInventoryItemsView(DeviceTypeComponentsView): ) -class DeviceTypeImportView(generic.ObjectImportView): +class DeviceTypeImportView(generic.BulkImportView): additional_permissions = [ 'dcim.add_devicetype', 'dcim.add_consoleporttemplate', @@ -1098,6 +1098,7 @@ class DeviceTypeImportView(generic.ObjectImportView): ] queryset = DeviceType.objects.all() model_form = forms.DeviceTypeImportForm + table = tables.DeviceTypeTable related_object_forms = { 'console-ports': forms.ConsolePortTemplateImportForm, 'console-server-ports': forms.ConsoleServerPortTemplateImportForm, @@ -1267,7 +1268,7 @@ class ModuleTypeRearPortsView(ModuleTypeComponentsView): ) -class ModuleTypeImportView(generic.ObjectImportView): +class ModuleTypeImportView(generic.BulkImportView): additional_permissions = [ 'dcim.add_moduletype', 'dcim.add_consoleporttemplate', @@ -1280,6 +1281,7 @@ class ModuleTypeImportView(generic.ObjectImportView): ] queryset = ModuleType.objects.all() model_form = forms.ModuleTypeImportForm + table = tables.ModuleTypeTable related_object_forms = { 'console-ports': forms.ConsolePortTemplateImportForm, 'console-server-ports': forms.ConsoleServerPortTemplateImportForm, @@ -2026,8 +2028,7 @@ class ChildDeviceBulkImportView(generic.BulkImportView): table = tables.DeviceImportTable template_name = 'dcim/device_import_child.html' - def _save_obj(self, obj_form, request): - + def save_object(self, obj_form, request): obj = obj_form.save() # Save the reverse relation to the parent device bay diff --git a/netbox/extras/tests/test_customfields.py b/netbox/extras/tests/test_customfields.py index 7e7eaeda0..2f3c7932a 100644 --- a/netbox/extras/tests/test_customfields.py +++ b/netbox/extras/tests/test_customfields.py @@ -935,7 +935,7 @@ class CustomFieldImportTest(TestCase): ) csv_data = '\n'.join(','.join(row) for row in data) - response = self.client.post(reverse('dcim:site_import'), {'csv': csv_data}) + response = self.client.post(reverse('dcim:site_import'), {'data': csv_data, 'format': 'csv'}) self.assertEqual(response.status_code, 200) self.assertEqual(Site.objects.count(), 3) diff --git a/netbox/ipam/tests/test_views.py b/netbox/ipam/tests/test_views.py index 25b8af9ae..8bf19ebfa 100644 --- a/netbox/ipam/tests/test_views.py +++ b/netbox/ipam/tests/test_views.py @@ -920,7 +920,11 @@ class L2VPNTerminationTestCase( def setUpTestData(cls): device = create_test_device('Device 1') interface = Interface.objects.create(name='Interface 1', device=device, type='1000baset') - l2vpn = L2VPN.objects.create(name='L2VPN 1', slug='l2vpn-1', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=650001) + l2vpns = ( + L2VPN(name='L2VPN 1', slug='l2vpn-1', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=650001), + L2VPN(name='L2VPN 2', slug='l2vpn-2', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=650002), + ) + L2VPN.objects.bulk_create(l2vpns) vlans = ( VLAN(name='Vlan 1', vid=1001), @@ -933,14 +937,14 @@ class L2VPNTerminationTestCase( VLAN.objects.bulk_create(vlans) terminations = ( - L2VPNTermination(l2vpn=l2vpn, assigned_object=vlans[0]), - L2VPNTermination(l2vpn=l2vpn, assigned_object=vlans[1]), - L2VPNTermination(l2vpn=l2vpn, assigned_object=vlans[2]) + L2VPNTermination(l2vpn=l2vpns[0], assigned_object=vlans[0]), + L2VPNTermination(l2vpn=l2vpns[0], assigned_object=vlans[1]), + L2VPNTermination(l2vpn=l2vpns[0], assigned_object=vlans[2]) ) L2VPNTermination.objects.bulk_create(terminations) cls.form_data = { - 'l2vpn': l2vpn.pk, + 'l2vpn': l2vpns[0].pk, 'device': device.pk, 'interface': interface.pk, } @@ -953,10 +957,10 @@ class L2VPNTerminationTestCase( ) cls.csv_update_data = ( - "id,l2vpn", - f"{terminations[0].pk},L2VPN 2", - f"{terminations[1].pk},L2VPN 2", - f"{terminations[2].pk},L2VPN 2", + f"id,l2vpn", + f"{terminations[0].pk},{l2vpns[0].name}", + f"{terminations[1].pk},{l2vpns[0].name}", + f"{terminations[2].pk},{l2vpns[0].name}", ) cls.bulk_edit_data = {} diff --git a/netbox/netbox/tests/test_import.py b/netbox/netbox/tests/test_import.py index 73f2e0e27..b6f732bfe 100644 --- a/netbox/netbox/tests/test_import.py +++ b/netbox/netbox/tests/test_import.py @@ -3,6 +3,7 @@ from django.test import override_settings from dcim.models import * from users.models import ObjectPermission +from utilities.forms.choices import ImportFormatChoices from utilities.testing import ModelViewTestCase, create_tags @@ -27,7 +28,8 @@ class CSVImportTestCase(ModelViewTestCase): ) data = { - 'csv': self._get_csv_data(csv_data), + 'format': ImportFormatChoices.CSV, + 'data': self._get_csv_data(csv_data), } # Assign model-level permission @@ -67,7 +69,8 @@ class CSVImportTestCase(ModelViewTestCase): ) data = { - 'csv': self._get_csv_data(csv_data), + 'format': ImportFormatChoices.CSV, + 'data': self._get_csv_data(csv_data), } # Assign model-level permission diff --git a/netbox/netbox/views/generic/bulk_views.py b/netbox/netbox/views/generic/bulk_views.py index 5ab9e6da0..1a83c9de2 100644 --- a/netbox/netbox/views/generic/bulk_views.py +++ b/netbox/netbox/views/generic/bulk_views.py @@ -4,23 +4,22 @@ from copy import deepcopy from django.contrib import messages from django.contrib.contenttypes.models import ContentType -from django.core.exceptions import FieldDoesNotExist, ValidationError +from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist, ValidationError from django.db import transaction, IntegrityError from django.db.models import ManyToManyField, ProtectedError from django.db.models.fields.reverse_related import ManyToManyRel -from django.forms import Form, ModelMultipleChoiceField, MultipleHiddenInput +from django.forms import ModelMultipleChoiceField, MultipleHiddenInput from django.http import HttpResponse from django.shortcuts import get_object_or_404, redirect, render from django.utils.safestring import mark_safe from django_tables2.export import TableExport -from extras.models import ExportTemplate, SavedFilter +from extras.models import ExportTemplate from extras.signals import clear_webhooks from utilities.error_handlers import handle_protectederror -from utilities.exceptions import AbortRequest, PermissionsViolation -from utilities.forms import ( - BootstrapMixin, BulkRenameForm, ConfirmationForm, CSVDataField, CSVFileField, restrict_form_fields, -) +from utilities.exceptions import AbortRequest, AbortTransaction, PermissionsViolation +from utilities.forms import BulkRenameForm, ConfirmationForm, ImportForm, restrict_form_fields +from utilities.forms.choices import ImportFormatChoices from utilities.htmx import is_htmx from utilities.permissions import get_permission_for_model from utilities.views import GetReturnURLMixin @@ -295,109 +294,136 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView): """ template_name = 'generic/bulk_import.html' model_form = None + related_object_forms = dict() - def _import_form(self, *args, **kwargs): + def get_required_permission(self): + return get_permission_for_model(self.queryset.model, 'add') - class ImportForm(BootstrapMixin, Form): - csv = CSVDataField( - from_form=self.model_form - ) - csv_file = CSVFileField( - label="CSV file", - from_form=self.model_form, - required=False - ) + def prep_related_object_data(self, parent, data): + """ + Hook to modify the data for related objects before it's passed to the related object form (for example, to + assign a parent object). + """ + return data - def clean(self): - csv_rows = self.cleaned_data['csv'][1] if 'csv' in self.cleaned_data else None - csv_file = self.files.get('csv_file') + def _save_object(self, model_form, request): - # Check that the user has not submitted both text data and a file - if csv_rows and csv_file: - raise ValidationError( - "Cannot process CSV text and file attachment simultaneously. Please choose only one import " - "method." - ) + # Save the primary object + obj = self.save_object(model_form, request) - return ImportForm(*args, **kwargs) + # Enforce object-level permissions + if not self.queryset.filter(pk=obj.pk).first(): + raise PermissionsViolation() - def _get_records(self, form, request): - if request.FILES: - headers, records = form.cleaned_data['csv_file'] - else: - headers, records = form.cleaned_data['csv'] + # Iterate through the related object forms (if any), validating and saving each instance. + for field_name, related_object_form in self.related_object_forms.items(): - return headers, records + related_obj_pks = [] + for i, rel_obj_data in enumerate(model_form.data.get(field_name, list())): + rel_obj_data = self.prep_related_object_data(obj, rel_obj_data) + f = related_object_form(rel_obj_data) - def _update_objects(self, form, request, headers, records): - updated_objs = [] + for subfield_name, field in f.fields.items(): + if subfield_name not in rel_obj_data and hasattr(field, 'initial'): + f.data[subfield_name] = field.initial - ids = [int(record["id"]) for record in records] - qs = self.queryset.model.objects.filter(id__in=ids) - objs = {} - for obj in qs: - objs[obj.id] = obj + if f.is_valid(): + related_obj = f.save() + related_obj_pks.append(related_obj.pk) + else: + # Replicate errors on the related object form to the primary form for display + for subfield_name, errors in f.errors.items(): + for err in errors: + err_msg = "{}[{}] {}: {}".format(field_name, i, subfield_name, err) + model_form.add_error(None, err_msg) + raise AbortTransaction() - for row, data in enumerate(records, start=1): - if int(data["id"]) not in objs: - form.add_error('csv', f'Row {row} id: {data["id"]} Does not exist') - raise ValidationError("") + # Enforce object-level permissions on related objects + model = related_object_form.Meta.model + if model.objects.filter(pk__in=related_obj_pks).count() != len(related_obj_pks): + raise ObjectDoesNotExist - obj = objs[int(data["id"])] - obj_form = self.model_form(data, headers=headers, instance=obj) + return obj - # The form should only contain fields that are in the CSV - for name, field in list(obj_form.fields.items()): - if name not in headers: - del obj_form.fields[name] - - restrict_form_fields(obj_form, request.user) - - if obj_form.is_valid(): - obj = self._save_obj(obj_form, request) - updated_objs.append(obj) - else: - for field, err in obj_form.errors.items(): - form.add_error('csv', f'Row {row} {field}: {err[0]}') - raise ValidationError("") - - return updated_objs - - def _create_objects(self, form, request, headers, records): - new_objs = [] - - for row, data in enumerate(records, start=1): - obj_form = self.model_form(data, headers=headers) - restrict_form_fields(obj_form, request.user) - - if obj_form.is_valid(): - obj = self._save_obj(obj_form, request) - new_objs.append(obj) - else: - for field, err in obj_form.errors.items(): - form.add_error('csv', f'Row {row} {field}: {err[0]}') - raise ValidationError("") - - return new_objs - - def _save_obj(self, obj_form, request): + def save_object(self, obj_form, request): """ Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data). """ return obj_form.save() - def get_required_permission(self): - return get_permission_for_model(self.queryset.model, 'add') + def create_and_update_objects(self, form, request): + saved_objects = [] + + records = list(form.cleaned_data['data']) + + # Prefetch objects to be updated, if any + prefetch_ids = [int(record['id']) for record in records if record.get('id')] + prefetched_objects = { + obj.pk: obj + for obj in self.queryset.model.objects.filter(id__in=prefetch_ids) + } if prefetch_ids else {} + + for i, record in enumerate(records, start=1): + instance = None + object_id = int(record.pop('id')) if record.get('id') else None + + # Determine whether this object is being created or updated + if object_id: + try: + instance = prefetched_objects[object_id] + except KeyError: + form.add_error('data', f"Row {i}: Object with ID {object_id} does not exist") + raise ValidationError('') + + if form.cleaned_data['format'] == ImportFormatChoices.CSV: + model_form = self.model_form(record, instance=instance, headers=form._csv_headers) + else: + model_form = self.model_form(record, instance=instance) + # Assign default values for any fields which were not specified. + # We have to do this manually because passing 'initial=' to the form + # on initialization merely sets default values for the widgets. + # Since widgets are not used for YAML/JSON import, we first bind the + # imported data normally, then update the form's data with the applicable + # field defaults as needed prior to form validation. + for field_name, field in model_form.fields.items(): + if field_name not in record and hasattr(field, 'initial'): + model_form.data[field_name] = field.initial + + # When updating, omit all form fields other than those specified in the record. (No + # fields are required when modifying an existing object.) + if object_id: + unused_fields = [f for f in model_form.fields if f not in record] + for field_name in unused_fields: + del model_form.fields[field_name] + + restrict_form_fields(model_form, request.user) + + if model_form.is_valid(): + obj = self._save_object(model_form, request) + saved_objects.append(obj) + else: + # Replicate model form errors for display + for field, errors in model_form.errors.items(): + for err in errors: + if field == '__all__': + form.add_error(None, f'Record {i}: {err}') + else: + form.add_error(None, f'Record {i} {field}: {err}') + + raise ValidationError("") + + return saved_objects # # Request handlers # def get(self, request): + form = ImportForm() return render(request, self.template_name, { 'model': self.model_form._meta.model, - 'form': self._import_form(), + 'form': form, 'fields': self.model_form().fields, 'return_url': self.get_return_url(request), **self.get_extra_context(request), @@ -405,19 +431,16 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView): def post(self, request): logger = logging.getLogger('netbox.views.BulkImportView') - form = self._import_form(request.POST, request.FILES) + + form = ImportForm(request.POST, request.FILES) if form.is_valid(): - logger.debug("Form validation was successful") + logger.debug("Import form validation was successful") try: - # Iterate through CSV data and bind each row to a new model form instance. + # Iterate through data and bind each record to a new model form instance. with transaction.atomic(): - headers, records = self._get_records(form, request) - if "id" in headers: - new_objs = self._update_objects(form, request, headers, records) - else: - new_objs = self._create_objects(form, request, headers, records) + new_objs = self.create_and_update_objects(form, request) # Enforce object-level permissions if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs): diff --git a/netbox/netbox/views/generic/object_views.py b/netbox/netbox/views/generic/object_views.py index 0d122a41a..738d70786 100644 --- a/netbox/netbox/views/generic/object_views.py +++ b/netbox/netbox/views/generic/object_views.py @@ -2,7 +2,6 @@ import logging from copy import deepcopy from django.contrib import messages -from django.core.exceptions import ObjectDoesNotExist from django.db import transaction from django.db.models import ProtectedError from django.shortcuts import redirect, render @@ -12,8 +11,8 @@ from django.utils.safestring import mark_safe from extras.signals import clear_webhooks from utilities.error_handlers import handle_protectederror -from utilities.exceptions import AbortRequest, AbortTransaction, PermissionsViolation -from utilities.forms import ConfirmationForm, ImportForm, restrict_form_fields +from utilities.exceptions import AbortRequest, PermissionsViolation +from utilities.forms import ConfirmationForm, restrict_form_fields from utilities.htmx import is_htmx from utilities.permissions import get_permission_for_model from utilities.utils import get_viewname, normalize_querydict, prepare_cloned_fields @@ -27,7 +26,6 @@ __all__ = ( 'ObjectChildrenView', 'ObjectDeleteView', 'ObjectEditView', - 'ObjectImportView', 'ObjectView', ) @@ -151,146 +149,6 @@ class ObjectChildrenView(ObjectView, ActionsMixin, TableMixin): }) -class ObjectImportView(GetReturnURLMixin, BaseObjectView): - """ - Import a single object (YAML or JSON format). - - Attributes: - model_form: The ModelForm used to create individual objects - related_object_forms: A dictionary mapping of forms to be used for the creation of related (child) objects - """ - template_name = 'generic/object_import.html' - model_form = None - related_object_forms = dict() - - def get_required_permission(self): - return get_permission_for_model(self.queryset.model, 'add') - - def prep_related_object_data(self, parent, data): - """ - Hook to modify the data for related objects before it's passed to the related object form (for example, to - assign a parent object). - """ - return data - - def _create_object(self, model_form): - - # Save the primary object - obj = model_form.save() - - # Enforce object-level permissions - if not self.queryset.filter(pk=obj.pk).exists(): - raise PermissionsViolation() - - # Iterate through the related object forms (if any), validating and saving each instance. - for field_name, related_object_form in self.related_object_forms.items(): - - related_obj_pks = [] - for i, rel_obj_data in enumerate(model_form.data.get(field_name, list())): - rel_obj_data = self.prep_related_object_data(obj, rel_obj_data) - f = related_object_form(rel_obj_data) - - for subfield_name, field in f.fields.items(): - if subfield_name not in rel_obj_data and hasattr(field, 'initial'): - f.data[subfield_name] = field.initial - - if f.is_valid(): - related_obj = f.save() - related_obj_pks.append(related_obj.pk) - else: - # Replicate errors on the related object form to the primary form for display - for subfield_name, errors in f.errors.items(): - for err in errors: - err_msg = "{}[{}] {}: {}".format(field_name, i, subfield_name, err) - model_form.add_error(None, err_msg) - raise AbortTransaction() - - # Enforce object-level permissions on related objects - model = related_object_form.Meta.model - if model.objects.filter(pk__in=related_obj_pks).count() != len(related_obj_pks): - raise ObjectDoesNotExist - - return obj - - # - # Request handlers - # - - def get(self, request): - form = ImportForm() - - return render(request, self.template_name, { - 'form': form, - 'obj_type': self.queryset.model._meta.verbose_name, - 'return_url': self.get_return_url(request), - }) - - def post(self, request): - logger = logging.getLogger('netbox.views.ObjectImportView') - form = ImportForm(request.POST) - - if form.is_valid(): - logger.debug("Import form validation was successful") - - # Initialize model form - data = form.cleaned_data['data'] - model_form = self.model_form(data) - restrict_form_fields(model_form, request.user) - - # Assign default values for any fields which were not specified. We have to do this manually because passing - # 'initial=' to the form on initialization merely sets default values for the widgets. Since widgets are not - # used for YAML/JSON import, we first bind the imported data normally, then update the form's data with the - # applicable field defaults as needed prior to form validation. - for field_name, field in model_form.fields.items(): - if field_name not in data and hasattr(field, 'initial'): - model_form.data[field_name] = field.initial - - if model_form.is_valid(): - - try: - with transaction.atomic(): - obj = self._create_object(model_form) - - except AbortTransaction: - clear_webhooks.send(sender=self) - - except (AbortRequest, PermissionsViolation) as e: - logger.debug(e.message) - form.add_error(None, e.message) - clear_webhooks.send(sender=self) - - if not model_form.errors: - logger.info(f"Import object {obj} (PK: {obj.pk})") - msg = f'Imported object: {obj}' - messages.success(request, mark_safe(msg)) - - if '_addanother' in request.POST: - return redirect(request.get_full_path()) - - self.get_return_url(request, obj) - return redirect(self.get_return_url(request, obj)) - - else: - logger.debug("Model form validation failed") - - # Replicate model form errors for display - for field, errors in model_form.errors.items(): - for err in errors: - if field == '__all__': - form.add_error(None, err) - else: - form.add_error(None, "{}: {}".format(field, err)) - - else: - logger.debug("Import form validation failed") - - return render(request, self.template_name, { - 'form': form, - 'obj_type': self.queryset.model._meta.verbose_name, - 'return_url': self.get_return_url(request), - }) - - class ObjectEditView(GetReturnURLMixin, BaseObjectView): """ Create or edit a single object. diff --git a/netbox/templates/generic/bulk_import.html b/netbox/templates/generic/bulk_import.html index 1d638cb2c..4ddfb884c 100644 --- a/netbox/templates/generic/bulk_import.html +++ b/netbox/templates/generic/bulk_import.html @@ -15,142 +15,160 @@ Context: {% block tabs %}
{% endblock tabs %} {% block content-wrapper %}