From 797d839c9a20b732e12984869d5735b4a65f6133 Mon Sep 17 00:00:00 2001 From: jeremystretch Date: Wed, 9 Nov 2022 17:44:02 -0500 Subject: [PATCH] Consolidate object import/update logic --- netbox/netbox/views/generic/bulk_views.py | 105 ++++++++++------------ netbox/utilities/forms/forms.py | 9 +- 2 files changed, 52 insertions(+), 62 deletions(-) diff --git a/netbox/netbox/views/generic/bulk_views.py b/netbox/netbox/views/generic/bulk_views.py index 2f9d5354e..33ef071a8 100644 --- a/netbox/netbox/views/generic/bulk_views.py +++ b/netbox/netbox/views/generic/bulk_views.py @@ -345,14 +345,41 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView): return obj - def _create_objects(self, form, request): - new_objs = [] + def _save_obj(self, obj_form, request): + """ + Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data). + """ + return obj_form.save() + + def create_and_update_objects(self, form, request): + created_objects = [] + updated_objects = [] + + records = list(form.cleaned_data['data']) + + # Prefetch objects to be updated, if any + prefetch_ids = [int(record['id']) for record in records if record.get('id')] + prefetched_objects = { + obj.pk: obj + for obj in self.queryset.model.objects.filter(id__in=prefetch_ids) + } if prefetch_ids else {} + + for i, record in enumerate(records, start=1): + instance = None + object_id = int(record.pop('id')) if record.get('id') else None + + # Determine whether this object is being created or updated + if object_id: + try: + instance = prefetched_objects[object_id] + except KeyError: + form.add_error('data', f"Row {i}: Object with ID {object_id} does not exist") + raise ValidationError('') - for i, record in enumerate(form.cleaned_data['data'], start=1): if form.cleaned_data['format'] == ImportFormatChoices.CSV: - model_form = self.model_form(record, headers=form._csv_headers) + model_form = self.model_form(record, instance=instance, headers=form._csv_headers) else: - model_form = self.model_form(record) + model_form = self.model_form(record, instance=instance) # Assign default values for any fields which were not specified. # We have to do this manually because passing 'initial=' to the form # on initialization merely sets default values for the widgets. @@ -362,11 +389,23 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView): for field_name, field in model_form.fields.items(): if field_name not in record and hasattr(field, 'initial'): model_form.data[field_name] = field.initial + + # When updating, omit all form fields other than those specified in the record. (No + # fields are required when modifying an existing object.) + if object_id: + unused_fields = [f for f in model_form.fields if f not in record] + for field_name in unused_fields: + del model_form.fields[field_name] + restrict_form_fields(model_form, request.user) if model_form.is_valid(): - obj = self._create_object(request, model_form) - new_objs.append(obj) + if object_id: + obj = self._save_obj(model_form, request) + updated_objects.append(obj) + else: + obj = self._create_object(request, model_form) + created_objects.append(obj) else: # Replicate model form errors for display for field, errors in model_form.errors.items(): @@ -378,49 +417,7 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView): raise ValidationError("") - return new_objs - - def _update_objects(self, form, request): - updated_objs = [] - records = form.cleaned_data['data'] - headers = form._csv_headers - - ids = [int(record["id"]) for record in records] - qs = self.queryset.model.objects.filter(id__in=ids) - objs = {} - for obj in qs: - objs[obj.id] = obj - - for row, data in enumerate(records, start=1): - if int(data["id"]) not in objs: - form.add_error('csv', f'Row {row} id: {data["id"]} Does not exist') - raise ValidationError("") - - obj = objs[int(data["id"])] - obj_form = self.model_form(data, headers=headers, instance=obj) - - # The form should only contain fields that are in the CSV - for name, field in list(obj_form.fields.items()): - if name not in headers: - del obj_form.fields[name] - - restrict_form_fields(obj_form, request.user) - - if obj_form.is_valid(): - obj = self._save_obj(obj_form, request) - updated_objs.append(obj) - else: - for field, err in obj_form.errors.items(): - form.add_error('data', f'Row {row} {field}: {err[0]}') - raise ValidationError("") - - return updated_objs - - def _save_obj(self, obj_form, request): - """ - Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data). - """ - return obj_form.save() + return [*created_objects, *updated_objects] # # Request handlers @@ -448,13 +445,7 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView): try: # Iterate through data and bind each record to a new model form instance. with transaction.atomic(): - if form.cleaned_data['format'] == 'csv': - if 'id' in form._csv_headers: - new_objs = self._update_objects(form, request) - else: - new_objs = self._create_objects(form, request) - else: - new_objs = self._create_objects(form, request) + new_objs = self.create_and_update_objects(form, request) # Enforce object-level permissions if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs): diff --git a/netbox/utilities/forms/forms.py b/netbox/utilities/forms/forms.py index 096de0acb..b63b78895 100644 --- a/netbox/utilities/forms/forms.py +++ b/netbox/utilities/forms/forms.py @@ -123,15 +123,14 @@ class CSVModelForm(forms.ModelForm): """ ModelForm used for the import of objects in CSV format. """ - def __init__(self, *args, headers=None, **kwargs): + headers = headers or {} super().__init__(*args, **kwargs) # Modify the model form to accommodate any customized to_field_name properties - if headers: - for field, to_field in headers.items(): - if to_field is not None: - self.fields[field].to_field_name = to_field + for field, to_field in headers.items(): + if to_field is not None: + self.fields[field].to_field_name = to_field class ImportForm(BootstrapMixin, forms.Form):