Move validate_custom_field_data() into assertInstanceEqual()

This commit is contained in:
Jeremy Stretch 2024-04-17 15:51:56 -04:00
parent a2432df603
commit 8417b28fb3
3 changed files with 9 additions and 12 deletions

View File

@ -10,10 +10,11 @@ from django.test import Client, TestCase as _TestCase
from netaddr import IPNetwork from netaddr import IPNetwork
from taggit.managers import TaggableManager from taggit.managers import TaggableManager
from netbox.models.features import CustomFieldsMixin
from users.models import ObjectPermission from users.models import ObjectPermission
from utilities.permissions import resolve_permission_ct from utilities.permissions import resolve_permission_ct
from utilities.utils import content_type_identifier from utilities.utils import content_type_identifier
from .utils import extract_form_failures from .utils import DUMMY_CF_DATA, extract_form_failures
__all__ = ( __all__ = (
'ModelTestCase', 'ModelTestCase',
@ -166,8 +167,12 @@ class ModelTestCase(TestCase):
model_dict = self.model_to_dict(instance, fields=fields, api=api) model_dict = self.model_to_dict(instance, fields=fields, api=api)
# Omit any dictionary keys which are not instance attributes or have been excluded # Omit any dictionary keys which are not instance attributes or have been excluded
relevant_data = { model_data = {
k: v for k, v in data.items() if hasattr(instance, k) and k not in exclude k: v for k, v in data.items() if hasattr(instance, k) and k not in exclude
} }
self.assertDictEqual(model_dict, relevant_data) self.assertDictEqual(model_dict, model_data)
# Validate any custom field data, if present
if getattr(instance, 'custom_field_data', None):
self.assertDictEqual(instance.custom_field_data, DUMMY_CF_DATA)

View File

@ -144,7 +144,3 @@ def add_custom_field_data(form_data, model):
f'cf_{k}': v if type(v) is str else json.dumps(v) f'cf_{k}': v if type(v) is str else json.dumps(v)
for k, v in DUMMY_CF_DATA.items() for k, v in DUMMY_CF_DATA.items()
}) })
def validate_custom_field_data(test_case, instance):
test_case.assertDictEqual(instance.cf, DUMMY_CF_DATA)

View File

@ -14,7 +14,7 @@ from netbox.models.features import ChangeLoggingMixin, CustomFieldsMixin
from users.models import ObjectPermission from users.models import ObjectPermission
from utilities.choices import CSVDelimiterChoices, ImportFormatChoices from utilities.choices import CSVDelimiterChoices, ImportFormatChoices
from .base import ModelTestCase from .base import ModelTestCase
from .utils import add_custom_field_data, disable_warnings, post_data, validate_custom_field_data from .utils import add_custom_field_data, disable_warnings, post_data
__all__ = ( __all__ = (
'ModelViewTestCase', 'ModelViewTestCase',
@ -179,8 +179,6 @@ class ViewTestCases:
self.assertEqual(initial_count + 1, self._get_queryset().count()) self.assertEqual(initial_count + 1, self._get_queryset().count())
instance = self._get_queryset().order_by('pk').last() instance = self._get_queryset().order_by('pk').last()
self.assertInstanceEqual(instance, self.form_data, exclude=self.validation_excluded_fields) self.assertInstanceEqual(instance, self.form_data, exclude=self.validation_excluded_fields)
if issubclass(self.model, CustomFieldsMixin):
validate_custom_field_data(self, instance)
# Verify ObjectChange creation # Verify ObjectChange creation
if issubclass(instance.__class__, ChangeLoggingMixin): if issubclass(instance.__class__, ChangeLoggingMixin):
@ -282,8 +280,6 @@ class ViewTestCases:
self.assertHttpStatus(self.client.post(**request), 302) self.assertHttpStatus(self.client.post(**request), 302)
instance = self._get_queryset().get(pk=instance.pk) instance = self._get_queryset().get(pk=instance.pk)
self.assertInstanceEqual(instance, self.form_data, exclude=self.validation_excluded_fields) self.assertInstanceEqual(instance, self.form_data, exclude=self.validation_excluded_fields)
if issubclass(self.model, CustomFieldsMixin):
validate_custom_field_data(self, instance)
# Verify ObjectChange creation # Verify ObjectChange creation
if issubclass(instance.__class__, ChangeLoggingMixin): if issubclass(instance.__class__, ChangeLoggingMixin):