diff --git a/netbox/extras/api/customfields.py b/netbox/extras/api/customfields.py index a053642db..393e4545f 100644 --- a/netbox/extras/api/customfields.py +++ b/netbox/extras/api/customfields.py @@ -1,10 +1,9 @@ from datetime import datetime from django.contrib.contenttypes.models import ContentType -from django.core.exceptions import ObjectDoesNotExist from rest_framework import serializers from rest_framework.exceptions import ValidationError -from rest_framework.fields import CreateOnlyDefault +from rest_framework.fields import CreateOnlyDefault, Field from extras.choices import * from extras.models import CustomField @@ -46,12 +45,18 @@ class CustomFieldDefaultValues: return value -class CustomFieldsSerializer(serializers.BaseSerializer): +class CustomFieldsDataField(Field): def to_representation(self, obj): - return obj + content_type = ContentType.objects.get_for_model(self.parent.Meta.model) + custom_fields = CustomField.objects.filter(obj_type=content_type) + + return {cf.name: obj.get(cf.name) for cf in custom_fields} def to_internal_value(self, data): + # If updating an existing instance, start with existing custom_field_data + if self.parent.instance: + data = {**self.parent.instance.custom_field_data, **data} content_type = ContentType.objects.get_for_model(self.parent.Meta.model) custom_fields = { @@ -111,9 +116,8 @@ class CustomFieldModelSerializer(ValidatedModelSerializer): """ Extends ModelSerializer to render any CustomFields and their values associated with an object. """ - custom_fields = CustomFieldsSerializer( + custom_fields = CustomFieldsDataField( source='custom_field_data', - required=False, default=CreateOnlyDefault(CustomFieldDefaultValues()) ) diff --git a/netbox/extras/tests/test_customfields.py b/netbox/extras/tests/test_customfields.py index 675248a3b..31d3c2be9 100644 --- a/netbox/extras/tests/test_customfields.py +++ b/netbox/extras/tests/test_customfields.py @@ -393,19 +393,17 @@ class CustomFieldAPITest(APITestCase): # Validate response data response_cf = response.data['custom_fields'] - data_cf = data['custom_fields'] - self.assertEqual(response_cf['text_field'], data_cf['text_field']) - self.assertEqual(response_cf['number_field'], data_cf['number_field']) - # TODO: Non-updated fields are missing from the response data - # self.assertEqual(response_cf['boolean_field'], site2_original_cfvs['boolean_field']) - # self.assertEqual(response_cf['date_field'], site2_original_cfvs['date_field']) - # self.assertEqual(response_cf['url_field'], site2_original_cfvs['url_field']) - # self.assertEqual(response_cf['choice_field'], site2_original_cfvs['choice_field'].value) + self.assertEqual(response_cf['text_field'], data['custom_fields']['text_field']) + self.assertEqual(response_cf['number_field'], data['custom_fields']['number_field']) + self.assertEqual(response_cf['boolean_field'], original_cfvs['boolean_field']) + self.assertEqual(response_cf['date_field'], original_cfvs['date_field']) + self.assertEqual(response_cf['url_field'], original_cfvs['url_field']) + self.assertEqual(response_cf['choice_field'], original_cfvs['choice_field']) # Validate database data site.refresh_from_db() - self.assertEqual(site.custom_field_data['text_field'], data_cf['text_field']) - self.assertEqual(site.custom_field_data['number_field'], data_cf['number_field']) + self.assertEqual(site.custom_field_data['text_field'], data['custom_fields']['text_field']) + self.assertEqual(site.custom_field_data['number_field'], data['custom_fields']['number_field']) self.assertEqual(site.custom_field_data['boolean_field'], original_cfvs['boolean_field']) self.assertEqual(site.custom_field_data['date_field'], original_cfvs['date_field']) self.assertEqual(site.custom_field_data['url_field'], original_cfvs['url_field']) @@ -456,7 +454,7 @@ class CustomFieldImportTest(TestCase): self.assertEqual(site1.custom_field_data['boolean'], True) self.assertEqual(site1.custom_field_data['date'], '2020-01-01') self.assertEqual(site1.custom_field_data['url'], 'http://example.com/1') - self.assertEqual(site1.custom_field_data['select'].value, 'Choice A') + self.assertEqual(site1.custom_field_data['select'], 'Choice A') # Validate data for site 2 site2 = Site.objects.get(name='Site 2') @@ -466,7 +464,7 @@ class CustomFieldImportTest(TestCase): self.assertEqual(site2.custom_field_data['boolean'], False) self.assertEqual(site2.custom_field_data['date'], '2020-01-02') self.assertEqual(site2.custom_field_data['url'], 'http://example.com/2') - self.assertEqual(site2.custom_field_data['select'].value, 'Choice B') + self.assertEqual(site2.custom_field_data['select'], 'Choice B') # No CustomFieldValues should be created for site 3 site3 = Site.objects.get(name='Site 3')