This commit is contained in:
Jeremy Stretch 2023-12-13 15:27:37 -05:00
parent b93735861d
commit 0dcead40e8
5 changed files with 150 additions and 10 deletions

View File

@ -2,16 +2,19 @@ import datetime
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.test import override_settings
from django.urls import reverse from django.urls import reverse
from django.utils.timezone import make_aware from django.utils.timezone import make_aware
from rest_framework import status from rest_framework import status
from circuits.api.serializers import ProviderSerializer
from core.choices import ManagedFileRootPathChoices from core.choices import ManagedFileRootPathChoices
from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Rack, Location, RackRole, Site from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Rack, Location, RackRole, Site
from extras.models import * from extras.models import *
from extras.reports import Report from extras.reports import Report
from extras.scripts import BooleanVar, IntegerVar, Script, StringVar from extras.scripts import BooleanVar, IntegerVar, Script, StringVar
from utilities.testing import APITestCase, APIViewTestCases from ipam.models import ASN, RIR
from utilities.testing import APITestCase, APIViewTestCases, create_tags
User = get_user_model() User = get_user_model()
@ -829,3 +832,53 @@ class ContentTypeTest(APITestCase):
url = reverse('extras-api:contenttype-detail', kwargs={'pk': contenttype.pk}) url = reverse('extras-api:contenttype-detail', kwargs={'pk': contenttype.pk})
self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_200_OK) self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_200_OK)
class CustomValidationTest(APITestCase):
@override_settings(CUSTOM_VALIDATORS={
'circuits.provider': [
{'tags': {'required': True}}
]
})
def test_tags_validation(self):
"""
Check that custom validation rules work for tag assignment.
"""
data = {
'name': 'Provider 1',
'slug': 'provider-1',
}
serializer = ProviderSerializer(data=data)
self.assertFalse(serializer.is_valid())
tags = create_tags('Tag1', 'Tag2', 'Tag3')
data['tags'] = [tag.pk for tag in tags]
serializer = ProviderSerializer(data=data)
self.assertTrue(serializer.is_valid())
@override_settings(CUSTOM_VALIDATORS={
'circuits.provider': [
{'asns': {'required': True}}
]
})
def test_m2m_validation(self):
"""
Check that custom validation rules work for many-to-many fields.
"""
data = {
'name': 'Provider 1',
'slug': 'provider-1',
}
serializer = ProviderSerializer(data=data)
self.assertFalse(serializer.is_valid())
rir = RIR.objects.create(name='RIR 1', slug='rir-1')
asns = ASN.objects.bulk_create((
ASN(rir=rir, asn=65001),
ASN(rir=rir, asn=65002),
ASN(rir=rir, asn=65003),
))
data['asns'] = [asn.pk for asn in asns]
serializer = ProviderSerializer(data=data)
self.assertTrue(serializer.is_valid())

View File

@ -1,11 +1,14 @@
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.test import TestCase from django.test import TestCase, override_settings
from circuits.forms import ProviderForm
from dcim.forms import SiteForm from dcim.forms import SiteForm
from dcim.models import Site from dcim.models import Site
from extras.choices import CustomFieldTypeChoices from extras.choices import CustomFieldTypeChoices
from extras.forms import SavedFilterForm from extras.forms import SavedFilterForm
from extras.models import CustomField, CustomFieldChoiceSet from extras.models import CustomField, CustomFieldChoiceSet
from ipam.models import ASN, RIR
from utilities.testing import create_tags
class CustomFieldModelFormTest(TestCase): class CustomFieldModelFormTest(TestCase):
@ -109,3 +112,53 @@ class SavedFilterFormTest(TestCase):
}) })
self.assertTrue(form.is_valid()) self.assertTrue(form.is_valid())
form.save() form.save()
class CustomValidationTest(TestCase):
@override_settings(CUSTOM_VALIDATORS={
'circuits.provider': [
{'tags': {'required': True}}
]
})
def test_tags_validation(self):
"""
Check that custom validation rules work for tag assignment.
"""
data = {
'name': 'Provider 1',
'slug': 'provider-1',
}
form = ProviderForm(data)
self.assertFalse(form.is_valid())
tags = create_tags('Tag1', 'Tag2', 'Tag3')
data['tags'] = [tag.pk for tag in tags]
form = ProviderForm(data)
self.assertTrue(form.is_valid())
@override_settings(CUSTOM_VALIDATORS={
'circuits.provider': [
{'asns': {'required': True}}
]
})
def test_m2m_validation(self):
"""
Check that custom validation rules work for many-to-many fields.
"""
data = {
'name': 'Provider 1',
'slug': 'provider-1',
}
form = ProviderForm(data)
self.assertFalse(form.is_valid())
rir = RIR.objects.create(name='RIR 1', slug='rir-1')
asns = ASN.objects.bulk_create((
ASN(rir=rir, asn=65001),
ASN(rir=rir, asn=65002),
ASN(rir=rir, asn=65003),
))
data['asns'] = [asn.pk for asn in asns]
form = ProviderForm(data)
self.assertTrue(form.is_valid())

View File

@ -1,5 +1,6 @@
from django.core.exceptions import ValidationError
from django.core import validators from django.core import validators
from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _
# NOTE: As this module may be imported by configuration.py, we cannot import # NOTE: As this module may be imported by configuration.py, we cannot import
# anything from NetBox itself. # anything from NetBox itself.
@ -66,8 +67,7 @@ class CustomValidator:
def __call__(self, instance): def __call__(self, instance):
# Validate instance attributes per validation rules # Validate instance attributes per validation rules
for attr_name, rules in self.validation_rules.items(): for attr_name, rules in self.validation_rules.items():
assert hasattr(instance, attr_name), f"Invalid attribute '{attr_name}' for {instance.__class__.__name__}" attr = self._getattr(instance, attr_name)
attr = getattr(instance, attr_name)
for descriptor, value in rules.items(): for descriptor, value in rules.items():
validator = self.get_validator(descriptor, value) validator = self.get_validator(descriptor, value)
try: try:
@ -79,6 +79,28 @@ class CustomValidator:
# Execute custom validation logic (if any) # Execute custom validation logic (if any)
self.validate(instance) self.validate(instance)
@staticmethod
def _getattr(instance, name):
# Attempt to resolve many-to-many fields to their stored values
m2m_fields = [f.name for f in instance._meta.local_many_to_many]
if name in m2m_fields:
if name in instance._m2m_values:
return instance._m2m_values[name]
elif instance.pk:
# TODO: Handle invalid attrs
return list(getattr(instance, name).all())
else:
return []
# Raise a ValidationError for unknown attributes
elif not hasattr(instance, name):
raise ValidationError(_('Invalid attribute "{name}" for {model}').format(
name=name,
model=instance.__class__.__name__
))
return getattr(instance, name)
def get_validator(self, descriptor, value): def get_validator(self, descriptor, value):
""" """
Instantiate and return the appropriate validator based on the descriptor given. For Instantiate and return the appropriate validator based on the descriptor given. For

View File

@ -24,15 +24,15 @@ class ValidatedModelSerializer(BaseModelSerializer):
""" """
def validate(self, data): def validate(self, data):
# Remove custom fields data and tags (if any) prior to model validation # Remove custom fields data (if any) prior to model validation
attrs = data.copy() attrs = data.copy()
attrs.pop('custom_fields', None) attrs.pop('custom_fields', None)
attrs.pop('tags', None)
# Skip ManyToManyFields # Skip ManyToManyFields
for field in self.Meta.model._meta.get_fields(): m2m_values = {}
if isinstance(field, ManyToManyField): for field in self.Meta.model._meta.local_many_to_many:
attrs.pop(field.name, None) if field.name in attrs:
m2m_values[field.name] = attrs.pop(field.name)
# Run clean() on an instance of the model # Run clean() on an instance of the model
if self.instance is None: if self.instance is None:
@ -41,6 +41,7 @@ class ValidatedModelSerializer(BaseModelSerializer):
instance = self.instance instance = self.instance
for k, v in attrs.items(): for k, v in attrs.items():
setattr(instance, k, v) setattr(instance, k, v)
instance._m2m_values = m2m_values
instance.full_clean() instance.full_clean()
return data return data

View File

@ -57,6 +57,17 @@ class NetBoxModelForm(BootstrapMixin, CheckLastUpdatedMixin, CustomFieldsMixin,
return super().clean() return super().clean()
def _post_clean(self):
"""
Override BaseModelForm's _post_clean() to store many-to-many field values on the model instance.
"""
self.instance._m2m_values = {}
for field in self.instance._meta.local_many_to_many:
if field.name in self.cleaned_data:
self.instance._m2m_values[field.name] = list(self.cleaned_data[field.name])
return super()._post_clean()
class NetBoxModelImportForm(CSVModelForm, NetBoxModelForm): class NetBoxModelImportForm(CSVModelForm, NetBoxModelForm):
""" """