diff --git a/netbox/netbox/views/generic/bulk_views.py b/netbox/netbox/views/generic/bulk_views.py index 9836bde88..8e49a9c92 100644 --- a/netbox/netbox/views/generic/bulk_views.py +++ b/netbox/netbox/views/generic/bulk_views.py @@ -8,7 +8,7 @@ from django.core.exceptions import FieldDoesNotExist, ValidationError from django.db import transaction, IntegrityError from django.db.models import ManyToManyField, ProtectedError from django.db.models.fields.reverse_related import ManyToManyRel -from django.forms import Form, ModelMultipleChoiceField, MultipleHiddenInput +from django.forms import Form, ModelMultipleChoiceField, MultipleHiddenInput, model_to_dict from django.http import HttpResponse from django.shortcuts import get_object_or_404, redirect, render from django_tables2.export import TableExport @@ -330,13 +330,34 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView): return headers, records def _update_objects(self, form, request, headers, records): + from utilities.forms import CSVModelChoiceField new_objs = [] for row, data in enumerate(records, start=1): - data = self.queryset.model.get(pk=data["pk"]) | data - obj_form = self.model_form(data, headers=headers) + obj = self.queryset.model.objects.get(pk=data["pk"]) + obj_form = self.model_form(instance=obj) + + save_data = model_to_dict(obj) + new_data = data + for name, field in obj_form.fields.items(): + if name == "manufacturer": + breakpoint() + if field.required and name not in obj_form.data: + if type(field) == CSVModelChoiceField and name in save_data: + # rel_field = field.queryset.get(pk=save_data[name]) + # to_name = getattr(field, 'to_field_name') or 'pk' + # obj_form.data[name] = getattr(rel_field, to_name) + new_data[name] = getattr(field.queryset.get(pk=save_data[name]), getattr(field, 'to_field_name') or 'pk') + else: + if name in save_data: + new_data[name] = save_data[name] + + obj_form = self.model_form(new_data, headers=headers, instance=obj) + # obj_form = self.model_form(save_data, instance=obj) + restrict_form_fields(obj_form, request.user) + breakpoint() if obj_form.is_valid(): obj = self._save_obj(obj_form, request) new_objs.append(obj) diff --git a/netbox/utilities/forms/utils.py b/netbox/utilities/forms/utils.py index a6f037e0b..48a68f13e 100644 --- a/netbox/utilities/forms/utils.py +++ b/netbox/utilities/forms/utils.py @@ -220,7 +220,11 @@ def validate_csv(headers, fields, required_fields): if parsed csv data contains invalid headers or does not contain required headers. """ # Validate provided column headers + is_update = False for field, to_field in headers.items(): + if field == "pk": + is_update = True + continue if field not in fields: raise forms.ValidationError(f'Unexpected column header "{field}" found.') if to_field and not hasattr(fields[field], 'to_field_name'): @@ -228,7 +232,8 @@ def validate_csv(headers, fields, required_fields): if to_field and not hasattr(fields[field].queryset.model, to_field): raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}') - # Validate required fields - for f in required_fields: - if f not in headers: - raise forms.ValidationError(f'Required column header "{f}" not found.') + # Validate required fields (if not an update) + if not is_update: + for f in required_fields: + if f not in headers: + raise forms.ValidationError(f'Required column header "{f}" not found.')