From c85a45e5208d21c9b48dd705ecfa6a1e520b548c Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Mon, 24 Aug 2020 14:11:13 -0400 Subject: [PATCH] Further work on custom fields --- netbox/extras/api/customfields.py | 4 +- netbox/extras/forms.py | 19 +--- netbox/extras/models/customfields.py | 10 ++ netbox/extras/tests/test_customfields.py | 126 ++++++++++------------- 4 files changed, 67 insertions(+), 92 deletions(-) diff --git a/netbox/extras/api/customfields.py b/netbox/extras/api/customfields.py index a0238129b..df00d0c1d 100644 --- a/netbox/extras/api/customfields.py +++ b/netbox/extras/api/customfields.py @@ -148,10 +148,10 @@ class CustomFieldModelSerializer(ValidatedModelSerializer): fields = CustomField.objects.filter(obj_type=content_type) # Populate CustomFieldValues for each instance from database - try: + if type(self.instance) in (list, tuple): for obj in self.instance: self._populate_custom_fields(obj, fields) - except TypeError: + else: self._populate_custom_fields(self.instance, fields) def _populate_custom_fields(self, instance, custom_fields): diff --git a/netbox/extras/forms.py b/netbox/extras/forms.py index 40c675c4d..96290ef0a 100644 --- a/netbox/extras/forms.py +++ b/netbox/extras/forms.py @@ -57,26 +57,13 @@ class CustomFieldModelForm(forms.ModelForm): # Annotate the field in the list of CustomField form fields self.custom_fields.append(field_name) - def _save_custom_fields(self): - - for field_name in self.custom_fields: - self.instance.custom_field_data[field_name[3:]] = self.cleaned_data[field_name] - def save(self, commit=True): - # Cache custom field values on object prior to save to ensure change logging + # Save custom field data on instance for cf_name in self.custom_fields: - self.instance._cf[cf_name[3:]] = self.cleaned_data.get(cf_name) + self.instance.custom_field_data[cf_name[3:]] = self.cleaned_data.get(cf_name) - obj = super().save(commit) - - # Handle custom fields the same way we do M2M fields - if commit: - self._save_custom_fields() - else: - obj.save_custom_fields = self._save_custom_fields - - return obj + return super().save(commit) class CustomFieldModelCSVForm(CSVModelForm, CustomFieldModelForm): diff --git a/netbox/extras/models/customfields.py b/netbox/extras/models/customfields.py index b0ea76cef..166ef5708 100644 --- a/netbox/extras/models/customfields.py +++ b/netbox/extras/models/customfields.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from datetime import date from django import forms @@ -34,6 +35,15 @@ class CustomFieldModel(models.Model): """ return self.custom_field_data + def get_custom_fields(self): + """ + Return a dictionary of custom fields for a single object in the form {: value}. + """ + fields = CustomField.objects.get_for_model(self) + return OrderedDict([ + (field, self.custom_field_data.get(field.name)) for field in fields + ]) + class CustomFieldManager(models.Manager): use_in_migrations = True diff --git a/netbox/extras/tests/test_customfields.py b/netbox/extras/tests/test_customfields.py index 74c0e7c3b..71254ac05 100644 --- a/netbox/extras/tests/test_customfields.py +++ b/netbox/extras/tests/test_customfields.py @@ -174,7 +174,7 @@ class CustomFieldAPITest(APITestCase): } cls.sites[1].save() - def test_get_single_object_without_custom_field_values(self): + def test_get_single_object_without_custom_field_data(self): """ Validate that custom fields are present on an object even if it has no values defined. """ @@ -192,13 +192,11 @@ class CustomFieldAPITest(APITestCase): 'choice_field': None, }) - def test_get_single_object_with_custom_field_values(self): + def test_get_single_object_with_custom_field_data(self): """ Validate that custom fields are present and correctly set for an object with values defined. """ - site2_cfvs = { - cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all() - } + site2_cfvs = self.sites[1].custom_field_data url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[1].pk}) self.add_permissions('dcim.view_site') @@ -236,15 +234,12 @@ class CustomFieldAPITest(APITestCase): # Validate database data site = Site.objects.get(pk=response.data['id']) - cfvs = { - cfv.field.name: cfv.value for cfv in site.custom_field_values.all() - } - self.assertEqual(cfvs['text_field'], self.cf_text.default) - self.assertEqual(cfvs['number_field'], self.cf_integer.default) - self.assertEqual(cfvs['boolean_field'], self.cf_boolean.default) - self.assertEqual(str(cfvs['date_field']), self.cf_date.default) - self.assertEqual(cfvs['url_field'], self.cf_url.default) - self.assertEqual(cfvs['choice_field'].pk, self.cf_select_choice1.pk) + self.assertEqual(site.custom_field_data['text_field'], self.cf_text.default) + self.assertEqual(site.custom_field_data['number_field'], self.cf_integer.default) + self.assertEqual(site.custom_field_data['boolean_field'], self.cf_boolean.default) + self.assertEqual(str(site.custom_field_data['date_field']), self.cf_date.default) + self.assertEqual(site.custom_field_data['url_field'], self.cf_url.default) + self.assertEqual(site.custom_field_data['choice_field'].pk, self.cf_select_choice1.pk) def test_create_single_object_with_values(self): """ @@ -280,15 +275,12 @@ class CustomFieldAPITest(APITestCase): # Validate database data site = Site.objects.get(pk=response.data['id']) - cfvs = { - cfv.field.name: cfv.value for cfv in site.custom_field_values.all() - } - self.assertEqual(cfvs['text_field'], data_cf['text_field']) - self.assertEqual(cfvs['number_field'], data_cf['number_field']) - self.assertEqual(cfvs['boolean_field'], data_cf['boolean_field']) - self.assertEqual(str(cfvs['date_field']), data_cf['date_field']) - self.assertEqual(cfvs['url_field'], data_cf['url_field']) - self.assertEqual(cfvs['choice_field'].pk, data_cf['choice_field']) + self.assertEqual(site.custom_field_data['text_field'], data_cf['text_field']) + self.assertEqual(site.custom_field_data['number_field'], data_cf['number_field']) + self.assertEqual(site.custom_field_data['boolean_field'], data_cf['boolean_field']) + self.assertEqual(str(site.custom_field_data['date_field']), data_cf['date_field']) + self.assertEqual(site.custom_field_data['url_field'], data_cf['url_field']) + self.assertEqual(site.custom_field_data['choice_field'].pk, data_cf['choice_field']) def test_create_multiple_objects_with_defaults(self): """ @@ -329,15 +321,12 @@ class CustomFieldAPITest(APITestCase): # Validate database data site = Site.objects.get(pk=response.data[i]['id']) - cfvs = { - cfv.field.name: cfv.value for cfv in site.custom_field_values.all() - } - self.assertEqual(cfvs['text_field'], self.cf_text.default) - self.assertEqual(cfvs['number_field'], self.cf_integer.default) - self.assertEqual(cfvs['boolean_field'], self.cf_boolean.default) - self.assertEqual(str(cfvs['date_field']), self.cf_date.default) - self.assertEqual(cfvs['url_field'], self.cf_url.default) - self.assertEqual(cfvs['choice_field'].pk, self.cf_select_choice1.pk) + self.assertEqual(site.custom_field_data['text_field'], self.cf_text.default) + self.assertEqual(site.custom_field_data['number_field'], self.cf_integer.default) + self.assertEqual(site.custom_field_data['boolean_field'], self.cf_boolean.default) + self.assertEqual(str(site.custom_field_data['date_field']), self.cf_date.default) + self.assertEqual(site.custom_field_data['url_field'], self.cf_url.default) + self.assertEqual(site.custom_field_data['choice_field'].pk, self.cf_select_choice1.pk) def test_create_multiple_objects_with_values(self): """ @@ -388,24 +377,20 @@ class CustomFieldAPITest(APITestCase): # Validate database data site = Site.objects.get(pk=response.data[i]['id']) - cfvs = { - cfv.field.name: cfv.value for cfv in site.custom_field_values.all() - } - self.assertEqual(cfvs['text_field'], custom_field_data['text_field']) - self.assertEqual(cfvs['number_field'], custom_field_data['number_field']) - self.assertEqual(cfvs['boolean_field'], custom_field_data['boolean_field']) - self.assertEqual(str(cfvs['date_field']), custom_field_data['date_field']) - self.assertEqual(cfvs['url_field'], custom_field_data['url_field']) - self.assertEqual(cfvs['choice_field'].pk, custom_field_data['choice_field']) + self.assertEqual(site.custom_field_data['text_field'], custom_field_data['text_field']) + self.assertEqual(site.custom_field_data['number_field'], custom_field_data['number_field']) + self.assertEqual(site.custom_field_data['boolean_field'], custom_field_data['boolean_field']) + self.assertEqual(str(site.custom_field_data['date_field']), custom_field_data['date_field']) + self.assertEqual(site.custom_field_data['url_field'], custom_field_data['url_field']) + self.assertEqual(site.custom_field_data['choice_field'].pk, custom_field_data['choice_field']) def test_update_single_object_with_values(self): """ Update an object with existing custom field values. Ensure that only the updated custom field values are modified. """ - site2_original_cfvs = { - cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all() - } + site = self.sites[1] + original_cfvs = {**site.custom_field_data} data = { 'custom_fields': { 'text_field': 'ABCD', @@ -430,15 +415,13 @@ class CustomFieldAPITest(APITestCase): # self.assertEqual(response_cf['choice_field']['label'], site2_original_cfvs['choice_field'].value) # Validate database data - site2_updated_cfvs = { - cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all() - } - self.assertEqual(site2_updated_cfvs['text_field'], data_cf['text_field']) - self.assertEqual(site2_updated_cfvs['number_field'], data_cf['number_field']) - self.assertEqual(site2_updated_cfvs['boolean_field'], site2_original_cfvs['boolean_field']) - self.assertEqual(site2_updated_cfvs['date_field'], site2_original_cfvs['date_field']) - self.assertEqual(site2_updated_cfvs['url_field'], site2_original_cfvs['url_field']) - self.assertEqual(site2_updated_cfvs['choice_field'], site2_original_cfvs['choice_field']) + site.refresh_from_db() + self.assertEqual(site.custom_field_data['text_field'], data_cf['text_field']) + self.assertEqual(site.custom_field_data['number_field'], data_cf['number_field']) + self.assertEqual(site.custom_field_data['boolean_field'], original_cfvs['boolean_field']) + self.assertEqual(site.custom_field_data['date_field'], original_cfvs['date_field']) + self.assertEqual(site.custom_field_data['url_field'], original_cfvs['url_field']) + self.assertEqual(site.custom_field_data['choice_field'], original_cfvs['choice_field']) class CustomFieldChoiceAPITest(APITestCase): @@ -514,31 +497,26 @@ class CustomFieldImportTest(TestCase): self.assertEqual(response.status_code, 200) # Validate data for site 1 - custom_field_values = { - cf.name: value for cf, value in Site.objects.get(name='Site 1').custom_field_data - } - self.assertEqual(len(custom_field_values), 6) - self.assertEqual(custom_field_values['text'], 'ABC') - self.assertEqual(custom_field_values['integer'], 123) - self.assertEqual(custom_field_values['boolean'], True) - self.assertEqual(custom_field_values['date'], date(2020, 1, 1)) - self.assertEqual(custom_field_values['url'], 'http://example.com/1') - self.assertEqual(custom_field_values['select'].value, 'Choice A') + site1 = Site.objects.get(name='Site 1') + self.assertEqual(len(site1.custom_field_data), 6) + self.assertEqual(site1.custom_field_data['text'], 'ABC') + self.assertEqual(site1.custom_field_data['integer'], 123) + self.assertEqual(site1.custom_field_data['boolean'], True) + self.assertEqual(site1.custom_field_data['date'], date(2020, 1, 1)) + self.assertEqual(site1.custom_field_data['url'], 'http://example.com/1') + self.assertEqual(site1.custom_field_data['select'].value, 'Choice A') # Validate data for site 2 - custom_field_values = { - cf.name: value for cf, value in Site.objects.get(name='Site 2').custom_field_data - } - self.assertEqual(len(custom_field_values), 6) - self.assertEqual(custom_field_values['text'], 'DEF') - self.assertEqual(custom_field_values['integer'], 456) - self.assertEqual(custom_field_values['boolean'], False) - self.assertEqual(custom_field_values['date'], date(2020, 1, 2)) - self.assertEqual(custom_field_values['url'], 'http://example.com/2') - self.assertEqual(custom_field_values['select'].value, 'Choice B') + site2 = Site.objects.get(name='Site 2') + self.assertEqual(len(site2.custom_field_data), 6) + self.assertEqual(site2.custom_field_data['text'], 'DEF') + self.assertEqual(site2.custom_field_data['integer'], 456) + self.assertEqual(site2.custom_field_data['boolean'], False) + self.assertEqual(site2.custom_field_data['date'], date(2020, 1, 2)) + self.assertEqual(site2.custom_field_data['url'], 'http://example.com/2') + self.assertEqual(site2.custom_field_data['select'].value, 'Choice B') # No CustomFieldValues should be created for site 3 - obj_type = ContentType.objects.get_for_model(Site) site3 = Site.objects.get(name='Site 3') self.assertEqual(site3.custom_field_data, {})