Consolidate object import/update logic

This commit is contained in:
jeremystretch 2022-11-09 17:44:02 -05:00
parent ec053f550b
commit 797d839c9a
2 changed files with 52 additions and 62 deletions

View File

@ -345,14 +345,41 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
return obj
def _create_objects(self, form, request):
new_objs = []
def _save_obj(self, obj_form, request):
"""
Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data).
"""
return obj_form.save()
def create_and_update_objects(self, form, request):
created_objects = []
updated_objects = []
records = list(form.cleaned_data['data'])
# Prefetch objects to be updated, if any
prefetch_ids = [int(record['id']) for record in records if record.get('id')]
prefetched_objects = {
obj.pk: obj
for obj in self.queryset.model.objects.filter(id__in=prefetch_ids)
} if prefetch_ids else {}
for i, record in enumerate(records, start=1):
instance = None
object_id = int(record.pop('id')) if record.get('id') else None
# Determine whether this object is being created or updated
if object_id:
try:
instance = prefetched_objects[object_id]
except KeyError:
form.add_error('data', f"Row {i}: Object with ID {object_id} does not exist")
raise ValidationError('')
for i, record in enumerate(form.cleaned_data['data'], start=1):
if form.cleaned_data['format'] == ImportFormatChoices.CSV:
model_form = self.model_form(record, headers=form._csv_headers)
model_form = self.model_form(record, instance=instance, headers=form._csv_headers)
else:
model_form = self.model_form(record)
model_form = self.model_form(record, instance=instance)
# Assign default values for any fields which were not specified.
# We have to do this manually because passing 'initial=' to the form
# on initialization merely sets default values for the widgets.
@ -362,11 +389,23 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
for field_name, field in model_form.fields.items():
if field_name not in record and hasattr(field, 'initial'):
model_form.data[field_name] = field.initial
# When updating, omit all form fields other than those specified in the record. (No
# fields are required when modifying an existing object.)
if object_id:
unused_fields = [f for f in model_form.fields if f not in record]
for field_name in unused_fields:
del model_form.fields[field_name]
restrict_form_fields(model_form, request.user)
if model_form.is_valid():
if object_id:
obj = self._save_obj(model_form, request)
updated_objects.append(obj)
else:
obj = self._create_object(request, model_form)
new_objs.append(obj)
created_objects.append(obj)
else:
# Replicate model form errors for display
for field, errors in model_form.errors.items():
@ -378,49 +417,7 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
raise ValidationError("")
return new_objs
def _update_objects(self, form, request):
updated_objs = []
records = form.cleaned_data['data']
headers = form._csv_headers
ids = [int(record["id"]) for record in records]
qs = self.queryset.model.objects.filter(id__in=ids)
objs = {}
for obj in qs:
objs[obj.id] = obj
for row, data in enumerate(records, start=1):
if int(data["id"]) not in objs:
form.add_error('csv', f'Row {row} id: {data["id"]} Does not exist')
raise ValidationError("")
obj = objs[int(data["id"])]
obj_form = self.model_form(data, headers=headers, instance=obj)
# The form should only contain fields that are in the CSV
for name, field in list(obj_form.fields.items()):
if name not in headers:
del obj_form.fields[name]
restrict_form_fields(obj_form, request.user)
if obj_form.is_valid():
obj = self._save_obj(obj_form, request)
updated_objs.append(obj)
else:
for field, err in obj_form.errors.items():
form.add_error('data', f'Row {row} {field}: {err[0]}')
raise ValidationError("")
return updated_objs
def _save_obj(self, obj_form, request):
"""
Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data).
"""
return obj_form.save()
return [*created_objects, *updated_objects]
#
# Request handlers
@ -448,13 +445,7 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
try:
# Iterate through data and bind each record to a new model form instance.
with transaction.atomic():
if form.cleaned_data['format'] == 'csv':
if 'id' in form._csv_headers:
new_objs = self._update_objects(form, request)
else:
new_objs = self._create_objects(form, request)
else:
new_objs = self._create_objects(form, request)
new_objs = self.create_and_update_objects(form, request)
# Enforce object-level permissions
if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs):

View File

@ -123,12 +123,11 @@ class CSVModelForm(forms.ModelForm):
"""
ModelForm used for the import of objects in CSV format.
"""
def __init__(self, *args, headers=None, **kwargs):
headers = headers or {}
super().__init__(*args, **kwargs)
# Modify the model form to accommodate any customized to_field_name properties
if headers:
for field, to_field in headers.items():
if to_field is not None:
self.fields[field].to_field_name = to_field