mirror of
https://github.com/netbox-community/netbox.git
synced 2025-08-25 08:46:10 -06:00
Consolidate object import/update logic
This commit is contained in:
parent
ec053f550b
commit
797d839c9a
@ -345,14 +345,41 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
|
||||
|
||||
return obj
|
||||
|
||||
def _create_objects(self, form, request):
|
||||
new_objs = []
|
||||
def _save_obj(self, obj_form, request):
|
||||
"""
|
||||
Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data).
|
||||
"""
|
||||
return obj_form.save()
|
||||
|
||||
def create_and_update_objects(self, form, request):
|
||||
created_objects = []
|
||||
updated_objects = []
|
||||
|
||||
records = list(form.cleaned_data['data'])
|
||||
|
||||
# Prefetch objects to be updated, if any
|
||||
prefetch_ids = [int(record['id']) for record in records if record.get('id')]
|
||||
prefetched_objects = {
|
||||
obj.pk: obj
|
||||
for obj in self.queryset.model.objects.filter(id__in=prefetch_ids)
|
||||
} if prefetch_ids else {}
|
||||
|
||||
for i, record in enumerate(records, start=1):
|
||||
instance = None
|
||||
object_id = int(record.pop('id')) if record.get('id') else None
|
||||
|
||||
# Determine whether this object is being created or updated
|
||||
if object_id:
|
||||
try:
|
||||
instance = prefetched_objects[object_id]
|
||||
except KeyError:
|
||||
form.add_error('data', f"Row {i}: Object with ID {object_id} does not exist")
|
||||
raise ValidationError('')
|
||||
|
||||
for i, record in enumerate(form.cleaned_data['data'], start=1):
|
||||
if form.cleaned_data['format'] == ImportFormatChoices.CSV:
|
||||
model_form = self.model_form(record, headers=form._csv_headers)
|
||||
model_form = self.model_form(record, instance=instance, headers=form._csv_headers)
|
||||
else:
|
||||
model_form = self.model_form(record)
|
||||
model_form = self.model_form(record, instance=instance)
|
||||
# Assign default values for any fields which were not specified.
|
||||
# We have to do this manually because passing 'initial=' to the form
|
||||
# on initialization merely sets default values for the widgets.
|
||||
@ -362,11 +389,23 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
|
||||
for field_name, field in model_form.fields.items():
|
||||
if field_name not in record and hasattr(field, 'initial'):
|
||||
model_form.data[field_name] = field.initial
|
||||
|
||||
# When updating, omit all form fields other than those specified in the record. (No
|
||||
# fields are required when modifying an existing object.)
|
||||
if object_id:
|
||||
unused_fields = [f for f in model_form.fields if f not in record]
|
||||
for field_name in unused_fields:
|
||||
del model_form.fields[field_name]
|
||||
|
||||
restrict_form_fields(model_form, request.user)
|
||||
|
||||
if model_form.is_valid():
|
||||
obj = self._create_object(request, model_form)
|
||||
new_objs.append(obj)
|
||||
if object_id:
|
||||
obj = self._save_obj(model_form, request)
|
||||
updated_objects.append(obj)
|
||||
else:
|
||||
obj = self._create_object(request, model_form)
|
||||
created_objects.append(obj)
|
||||
else:
|
||||
# Replicate model form errors for display
|
||||
for field, errors in model_form.errors.items():
|
||||
@ -378,49 +417,7 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
|
||||
|
||||
raise ValidationError("")
|
||||
|
||||
return new_objs
|
||||
|
||||
def _update_objects(self, form, request):
|
||||
updated_objs = []
|
||||
records = form.cleaned_data['data']
|
||||
headers = form._csv_headers
|
||||
|
||||
ids = [int(record["id"]) for record in records]
|
||||
qs = self.queryset.model.objects.filter(id__in=ids)
|
||||
objs = {}
|
||||
for obj in qs:
|
||||
objs[obj.id] = obj
|
||||
|
||||
for row, data in enumerate(records, start=1):
|
||||
if int(data["id"]) not in objs:
|
||||
form.add_error('csv', f'Row {row} id: {data["id"]} Does not exist')
|
||||
raise ValidationError("")
|
||||
|
||||
obj = objs[int(data["id"])]
|
||||
obj_form = self.model_form(data, headers=headers, instance=obj)
|
||||
|
||||
# The form should only contain fields that are in the CSV
|
||||
for name, field in list(obj_form.fields.items()):
|
||||
if name not in headers:
|
||||
del obj_form.fields[name]
|
||||
|
||||
restrict_form_fields(obj_form, request.user)
|
||||
|
||||
if obj_form.is_valid():
|
||||
obj = self._save_obj(obj_form, request)
|
||||
updated_objs.append(obj)
|
||||
else:
|
||||
for field, err in obj_form.errors.items():
|
||||
form.add_error('data', f'Row {row} {field}: {err[0]}')
|
||||
raise ValidationError("")
|
||||
|
||||
return updated_objs
|
||||
|
||||
def _save_obj(self, obj_form, request):
|
||||
"""
|
||||
Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data).
|
||||
"""
|
||||
return obj_form.save()
|
||||
return [*created_objects, *updated_objects]
|
||||
|
||||
#
|
||||
# Request handlers
|
||||
@ -448,13 +445,7 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
|
||||
try:
|
||||
# Iterate through data and bind each record to a new model form instance.
|
||||
with transaction.atomic():
|
||||
if form.cleaned_data['format'] == 'csv':
|
||||
if 'id' in form._csv_headers:
|
||||
new_objs = self._update_objects(form, request)
|
||||
else:
|
||||
new_objs = self._create_objects(form, request)
|
||||
else:
|
||||
new_objs = self._create_objects(form, request)
|
||||
new_objs = self.create_and_update_objects(form, request)
|
||||
|
||||
# Enforce object-level permissions
|
||||
if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs):
|
||||
|
@ -123,15 +123,14 @@ class CSVModelForm(forms.ModelForm):
|
||||
"""
|
||||
ModelForm used for the import of objects in CSV format.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, headers=None, **kwargs):
|
||||
headers = headers or {}
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Modify the model form to accommodate any customized to_field_name properties
|
||||
if headers:
|
||||
for field, to_field in headers.items():
|
||||
if to_field is not None:
|
||||
self.fields[field].to_field_name = to_field
|
||||
for field, to_field in headers.items():
|
||||
if to_field is not None:
|
||||
self.fields[field].to_field_name = to_field
|
||||
|
||||
|
||||
class ImportForm(BootstrapMixin, forms.Form):
|
||||
|
Loading…
Reference in New Issue
Block a user