mirror of
https://github.com/netbox-community/netbox.git
synced 2025-08-25 08:46:10 -06:00
Consolidate object import/update logic
This commit is contained in:
parent
ec053f550b
commit
797d839c9a
@ -345,14 +345,41 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
|
|||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def _create_objects(self, form, request):
|
def _save_obj(self, obj_form, request):
|
||||||
new_objs = []
|
"""
|
||||||
|
Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data).
|
||||||
|
"""
|
||||||
|
return obj_form.save()
|
||||||
|
|
||||||
|
def create_and_update_objects(self, form, request):
|
||||||
|
created_objects = []
|
||||||
|
updated_objects = []
|
||||||
|
|
||||||
|
records = list(form.cleaned_data['data'])
|
||||||
|
|
||||||
|
# Prefetch objects to be updated, if any
|
||||||
|
prefetch_ids = [int(record['id']) for record in records if record.get('id')]
|
||||||
|
prefetched_objects = {
|
||||||
|
obj.pk: obj
|
||||||
|
for obj in self.queryset.model.objects.filter(id__in=prefetch_ids)
|
||||||
|
} if prefetch_ids else {}
|
||||||
|
|
||||||
|
for i, record in enumerate(records, start=1):
|
||||||
|
instance = None
|
||||||
|
object_id = int(record.pop('id')) if record.get('id') else None
|
||||||
|
|
||||||
|
# Determine whether this object is being created or updated
|
||||||
|
if object_id:
|
||||||
|
try:
|
||||||
|
instance = prefetched_objects[object_id]
|
||||||
|
except KeyError:
|
||||||
|
form.add_error('data', f"Row {i}: Object with ID {object_id} does not exist")
|
||||||
|
raise ValidationError('')
|
||||||
|
|
||||||
for i, record in enumerate(form.cleaned_data['data'], start=1):
|
|
||||||
if form.cleaned_data['format'] == ImportFormatChoices.CSV:
|
if form.cleaned_data['format'] == ImportFormatChoices.CSV:
|
||||||
model_form = self.model_form(record, headers=form._csv_headers)
|
model_form = self.model_form(record, instance=instance, headers=form._csv_headers)
|
||||||
else:
|
else:
|
||||||
model_form = self.model_form(record)
|
model_form = self.model_form(record, instance=instance)
|
||||||
# Assign default values for any fields which were not specified.
|
# Assign default values for any fields which were not specified.
|
||||||
# We have to do this manually because passing 'initial=' to the form
|
# We have to do this manually because passing 'initial=' to the form
|
||||||
# on initialization merely sets default values for the widgets.
|
# on initialization merely sets default values for the widgets.
|
||||||
@ -362,11 +389,23 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
|
|||||||
for field_name, field in model_form.fields.items():
|
for field_name, field in model_form.fields.items():
|
||||||
if field_name not in record and hasattr(field, 'initial'):
|
if field_name not in record and hasattr(field, 'initial'):
|
||||||
model_form.data[field_name] = field.initial
|
model_form.data[field_name] = field.initial
|
||||||
|
|
||||||
|
# When updating, omit all form fields other than those specified in the record. (No
|
||||||
|
# fields are required when modifying an existing object.)
|
||||||
|
if object_id:
|
||||||
|
unused_fields = [f for f in model_form.fields if f not in record]
|
||||||
|
for field_name in unused_fields:
|
||||||
|
del model_form.fields[field_name]
|
||||||
|
|
||||||
restrict_form_fields(model_form, request.user)
|
restrict_form_fields(model_form, request.user)
|
||||||
|
|
||||||
if model_form.is_valid():
|
if model_form.is_valid():
|
||||||
obj = self._create_object(request, model_form)
|
if object_id:
|
||||||
new_objs.append(obj)
|
obj = self._save_obj(model_form, request)
|
||||||
|
updated_objects.append(obj)
|
||||||
|
else:
|
||||||
|
obj = self._create_object(request, model_form)
|
||||||
|
created_objects.append(obj)
|
||||||
else:
|
else:
|
||||||
# Replicate model form errors for display
|
# Replicate model form errors for display
|
||||||
for field, errors in model_form.errors.items():
|
for field, errors in model_form.errors.items():
|
||||||
@ -378,49 +417,7 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
|
|||||||
|
|
||||||
raise ValidationError("")
|
raise ValidationError("")
|
||||||
|
|
||||||
return new_objs
|
return [*created_objects, *updated_objects]
|
||||||
|
|
||||||
def _update_objects(self, form, request):
|
|
||||||
updated_objs = []
|
|
||||||
records = form.cleaned_data['data']
|
|
||||||
headers = form._csv_headers
|
|
||||||
|
|
||||||
ids = [int(record["id"]) for record in records]
|
|
||||||
qs = self.queryset.model.objects.filter(id__in=ids)
|
|
||||||
objs = {}
|
|
||||||
for obj in qs:
|
|
||||||
objs[obj.id] = obj
|
|
||||||
|
|
||||||
for row, data in enumerate(records, start=1):
|
|
||||||
if int(data["id"]) not in objs:
|
|
||||||
form.add_error('csv', f'Row {row} id: {data["id"]} Does not exist')
|
|
||||||
raise ValidationError("")
|
|
||||||
|
|
||||||
obj = objs[int(data["id"])]
|
|
||||||
obj_form = self.model_form(data, headers=headers, instance=obj)
|
|
||||||
|
|
||||||
# The form should only contain fields that are in the CSV
|
|
||||||
for name, field in list(obj_form.fields.items()):
|
|
||||||
if name not in headers:
|
|
||||||
del obj_form.fields[name]
|
|
||||||
|
|
||||||
restrict_form_fields(obj_form, request.user)
|
|
||||||
|
|
||||||
if obj_form.is_valid():
|
|
||||||
obj = self._save_obj(obj_form, request)
|
|
||||||
updated_objs.append(obj)
|
|
||||||
else:
|
|
||||||
for field, err in obj_form.errors.items():
|
|
||||||
form.add_error('data', f'Row {row} {field}: {err[0]}')
|
|
||||||
raise ValidationError("")
|
|
||||||
|
|
||||||
return updated_objs
|
|
||||||
|
|
||||||
def _save_obj(self, obj_form, request):
|
|
||||||
"""
|
|
||||||
Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data).
|
|
||||||
"""
|
|
||||||
return obj_form.save()
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Request handlers
|
# Request handlers
|
||||||
@ -448,13 +445,7 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
|
|||||||
try:
|
try:
|
||||||
# Iterate through data and bind each record to a new model form instance.
|
# Iterate through data and bind each record to a new model form instance.
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
if form.cleaned_data['format'] == 'csv':
|
new_objs = self.create_and_update_objects(form, request)
|
||||||
if 'id' in form._csv_headers:
|
|
||||||
new_objs = self._update_objects(form, request)
|
|
||||||
else:
|
|
||||||
new_objs = self._create_objects(form, request)
|
|
||||||
else:
|
|
||||||
new_objs = self._create_objects(form, request)
|
|
||||||
|
|
||||||
# Enforce object-level permissions
|
# Enforce object-level permissions
|
||||||
if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs):
|
if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs):
|
||||||
|
@ -123,15 +123,14 @@ class CSVModelForm(forms.ModelForm):
|
|||||||
"""
|
"""
|
||||||
ModelForm used for the import of objects in CSV format.
|
ModelForm used for the import of objects in CSV format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, headers=None, **kwargs):
|
def __init__(self, *args, headers=None, **kwargs):
|
||||||
|
headers = headers or {}
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
# Modify the model form to accommodate any customized to_field_name properties
|
# Modify the model form to accommodate any customized to_field_name properties
|
||||||
if headers:
|
for field, to_field in headers.items():
|
||||||
for field, to_field in headers.items():
|
if to_field is not None:
|
||||||
if to_field is not None:
|
self.fields[field].to_field_name = to_field
|
||||||
self.fields[field].to_field_name = to_field
|
|
||||||
|
|
||||||
|
|
||||||
class ImportForm(BootstrapMixin, forms.Form):
|
class ImportForm(BootstrapMixin, forms.Form):
|
||||||
|
Loading…
Reference in New Issue
Block a user