Clean up bulk import view

This commit is contained in:
jeremystretch 2022-11-10 09:05:16 -05:00
parent 797d839c9a
commit 6793e384df
3 changed files with 18 additions and 16 deletions

View File

@ -807,7 +807,7 @@ class RackReservationImportView(generic.BulkImportView):
model_form = forms.RackReservationCSVForm model_form = forms.RackReservationCSVForm
table = tables.RackReservationTable table = tables.RackReservationTable
def _save_obj(self, obj_form, request): def save_object(self, obj_form, request):
""" """
Assign the currently authenticated user to the RackReservation. Assign the currently authenticated user to the RackReservation.
""" """
@ -2028,8 +2028,7 @@ class ChildDeviceBulkImportView(generic.BulkImportView):
table = tables.DeviceImportTable table = tables.DeviceImportTable
template_name = 'dcim/device_import_child.html' template_name = 'dcim/device_import_child.html'
def _save_obj(self, obj_form, request): def save_object(self, obj_form, request):
obj = obj_form.save() obj = obj_form.save()
# Save the reverse relation to the parent device bay # Save the reverse relation to the parent device bay

View File

@ -306,10 +306,10 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
""" """
return data return data
def _create_object(self, request, model_form): def _save_object(self, model_form, request):
# Save the primary object # Save the primary object
obj = self._save_obj(model_form, request) obj = self.save_object(model_form, request)
# Enforce object-level permissions # Enforce object-level permissions
if not self.queryset.filter(pk=obj.pk).first(): if not self.queryset.filter(pk=obj.pk).first():
@ -345,15 +345,14 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
return obj return obj
def _save_obj(self, obj_form, request): def save_object(self, obj_form, request):
""" """
Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data). Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data).
""" """
return obj_form.save() return obj_form.save()
def create_and_update_objects(self, form, request): def create_and_update_objects(self, form, request):
created_objects = [] saved_objects = []
updated_objects = []
records = list(form.cleaned_data['data']) records = list(form.cleaned_data['data'])
@ -400,12 +399,8 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
restrict_form_fields(model_form, request.user) restrict_form_fields(model_form, request.user)
if model_form.is_valid(): if model_form.is_valid():
if object_id: obj = self._save_object(model_form, request)
obj = self._save_obj(model_form, request) saved_objects.append(obj)
updated_objects.append(obj)
else:
obj = self._create_object(request, model_form)
created_objects.append(obj)
else: else:
# Replicate model form errors for display # Replicate model form errors for display
for field, errors in model_form.errors.items(): for field, errors in model_form.errors.items():
@ -417,7 +412,7 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
raise ValidationError("") raise ValidationError("")
return [*created_objects, *updated_objects] return saved_objects
# #
# Request handlers # Request handlers

View File

@ -123,8 +123,9 @@ class CSVModelForm(forms.ModelForm):
""" """
ModelForm used for the import of objects in CSV format. ModelForm used for the import of objects in CSV format.
""" """
def __init__(self, *args, headers=None, **kwargs): def __init__(self, *args, headers=None, fields=None, **kwargs):
headers = headers or {} headers = headers or {}
fields = fields or []
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Modify the model form to accommodate any customized to_field_name properties # Modify the model form to accommodate any customized to_field_name properties
@ -132,6 +133,13 @@ class CSVModelForm(forms.ModelForm):
if to_field is not None: if to_field is not None:
self.fields[field].to_field_name = to_field self.fields[field].to_field_name = to_field
# Omit any fields not specified (e.g. because the form is being used to
# updated rather than create objects)
if fields:
for field in list(self.fields.keys()):
if field not in fields:
del self.fields[field]
class ImportForm(BootstrapMixin, forms.Form): class ImportForm(BootstrapMixin, forms.Form):
data = forms.CharField( data = forms.CharField(