diff --git a/netbox/extras/tests/test_api.py b/netbox/extras/tests/test_api.py index 255457f21..52c988ed3 100644 --- a/netbox/extras/tests/test_api.py +++ b/netbox/extras/tests/test_api.py @@ -2,16 +2,19 @@ import datetime from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType +from django.test import override_settings from django.urls import reverse from django.utils.timezone import make_aware from rest_framework import status +from circuits.api.serializers import ProviderSerializer from core.choices import ManagedFileRootPathChoices from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Rack, Location, RackRole, Site from extras.models import * from extras.reports import Report from extras.scripts import BooleanVar, IntegerVar, Script, StringVar -from utilities.testing import APITestCase, APIViewTestCases +from ipam.models import ASN, RIR +from utilities.testing import APITestCase, APIViewTestCases, create_tags User = get_user_model() @@ -829,3 +832,53 @@ class ContentTypeTest(APITestCase): url = reverse('extras-api:contenttype-detail', kwargs={'pk': contenttype.pk}) self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_200_OK) + + +class CustomValidationTest(APITestCase): + + @override_settings(CUSTOM_VALIDATORS={ + 'circuits.provider': [ + {'tags': {'required': True}} + ] + }) + def test_tags_validation(self): + """ + Check that custom validation rules work for tag assignment. + """ + data = { + 'name': 'Provider 1', + 'slug': 'provider-1', + } + serializer = ProviderSerializer(data=data) + self.assertFalse(serializer.is_valid()) + + tags = create_tags('Tag1', 'Tag2', 'Tag3') + data['tags'] = [tag.pk for tag in tags] + serializer = ProviderSerializer(data=data) + self.assertTrue(serializer.is_valid()) + + @override_settings(CUSTOM_VALIDATORS={ + 'circuits.provider': [ + {'asns': {'required': True}} + ] + }) + def test_m2m_validation(self): + """ + Check that custom validation rules work for many-to-many fields. + """ + data = { + 'name': 'Provider 1', + 'slug': 'provider-1', + } + serializer = ProviderSerializer(data=data) + self.assertFalse(serializer.is_valid()) + + rir = RIR.objects.create(name='RIR 1', slug='rir-1') + asns = ASN.objects.bulk_create(( + ASN(rir=rir, asn=65001), + ASN(rir=rir, asn=65002), + ASN(rir=rir, asn=65003), + )) + data['asns'] = [asn.pk for asn in asns] + serializer = ProviderSerializer(data=data) + self.assertTrue(serializer.is_valid()) diff --git a/netbox/extras/tests/test_forms.py b/netbox/extras/tests/test_forms.py index 9c22bf83c..ac94dba64 100644 --- a/netbox/extras/tests/test_forms.py +++ b/netbox/extras/tests/test_forms.py @@ -1,11 +1,14 @@ from django.contrib.contenttypes.models import ContentType -from django.test import TestCase +from django.test import TestCase, override_settings +from circuits.forms import ProviderForm from dcim.forms import SiteForm from dcim.models import Site from extras.choices import CustomFieldTypeChoices from extras.forms import SavedFilterForm from extras.models import CustomField, CustomFieldChoiceSet +from ipam.models import ASN, RIR +from utilities.testing import create_tags class CustomFieldModelFormTest(TestCase): @@ -109,3 +112,53 @@ class SavedFilterFormTest(TestCase): }) self.assertTrue(form.is_valid()) form.save() + + +class CustomValidationTest(TestCase): + + @override_settings(CUSTOM_VALIDATORS={ + 'circuits.provider': [ + {'tags': {'required': True}} + ] + }) + def test_tags_validation(self): + """ + Check that custom validation rules work for tag assignment. + """ + data = { + 'name': 'Provider 1', + 'slug': 'provider-1', + } + form = ProviderForm(data) + self.assertFalse(form.is_valid()) + + tags = create_tags('Tag1', 'Tag2', 'Tag3') + data['tags'] = [tag.pk for tag in tags] + form = ProviderForm(data) + self.assertTrue(form.is_valid()) + + @override_settings(CUSTOM_VALIDATORS={ + 'circuits.provider': [ + {'asns': {'required': True}} + ] + }) + def test_m2m_validation(self): + """ + Check that custom validation rules work for many-to-many fields. + """ + data = { + 'name': 'Provider 1', + 'slug': 'provider-1', + } + form = ProviderForm(data) + self.assertFalse(form.is_valid()) + + rir = RIR.objects.create(name='RIR 1', slug='rir-1') + asns = ASN.objects.bulk_create(( + ASN(rir=rir, asn=65001), + ASN(rir=rir, asn=65002), + ASN(rir=rir, asn=65003), + )) + data['asns'] = [asn.pk for asn in asns] + form = ProviderForm(data) + self.assertTrue(form.is_valid()) diff --git a/netbox/extras/validators.py b/netbox/extras/validators.py index 686c9b032..82e11f3f9 100644 --- a/netbox/extras/validators.py +++ b/netbox/extras/validators.py @@ -1,5 +1,6 @@ -from django.core.exceptions import ValidationError from django.core import validators +from django.core.exceptions import ValidationError +from django.utils.translation import gettext_lazy as _ # NOTE: As this module may be imported by configuration.py, we cannot import # anything from NetBox itself. @@ -66,8 +67,7 @@ class CustomValidator: def __call__(self, instance): # Validate instance attributes per validation rules for attr_name, rules in self.validation_rules.items(): - assert hasattr(instance, attr_name), f"Invalid attribute '{attr_name}' for {instance.__class__.__name__}" - attr = getattr(instance, attr_name) + attr = self._getattr(instance, attr_name) for descriptor, value in rules.items(): validator = self.get_validator(descriptor, value) try: @@ -79,6 +79,28 @@ class CustomValidator: # Execute custom validation logic (if any) self.validate(instance) + @staticmethod + def _getattr(instance, name): + # Attempt to resolve many-to-many fields to their stored values + m2m_fields = [f.name for f in instance._meta.local_many_to_many] + if name in m2m_fields: + if name in instance._m2m_values: + return instance._m2m_values[name] + elif instance.pk: + # TODO: Handle invalid attrs + return list(getattr(instance, name).all()) + else: + return [] + + # Raise a ValidationError for unknown attributes + elif not hasattr(instance, name): + raise ValidationError(_('Invalid attribute "{name}" for {model}').format( + name=name, + model=instance.__class__.__name__ + )) + + return getattr(instance, name) + def get_validator(self, descriptor, value): """ Instantiate and return the appropriate validator based on the descriptor given. For diff --git a/netbox/netbox/api/serializers/base.py b/netbox/netbox/api/serializers/base.py index 5ee74bf8c..b2809cf3d 100644 --- a/netbox/netbox/api/serializers/base.py +++ b/netbox/netbox/api/serializers/base.py @@ -24,15 +24,15 @@ class ValidatedModelSerializer(BaseModelSerializer): """ def validate(self, data): - # Remove custom fields data and tags (if any) prior to model validation + # Remove custom fields data (if any) prior to model validation attrs = data.copy() attrs.pop('custom_fields', None) - attrs.pop('tags', None) # Skip ManyToManyFields - for field in self.Meta.model._meta.get_fields(): - if isinstance(field, ManyToManyField): - attrs.pop(field.name, None) + m2m_values = {} + for field in self.Meta.model._meta.local_many_to_many: + if field.name in attrs: + m2m_values[field.name] = attrs.pop(field.name) # Run clean() on an instance of the model if self.instance is None: @@ -41,6 +41,7 @@ class ValidatedModelSerializer(BaseModelSerializer): instance = self.instance for k, v in attrs.items(): setattr(instance, k, v) + instance._m2m_values = m2m_values instance.full_clean() return data diff --git a/netbox/netbox/forms/base.py b/netbox/netbox/forms/base.py index 43d0850f0..4720162f1 100644 --- a/netbox/netbox/forms/base.py +++ b/netbox/netbox/forms/base.py @@ -57,6 +57,17 @@ class NetBoxModelForm(BootstrapMixin, CheckLastUpdatedMixin, CustomFieldsMixin, return super().clean() + def _post_clean(self): + """ + Override BaseModelForm's _post_clean() to store many-to-many field values on the model instance. + """ + self.instance._m2m_values = {} + for field in self.instance._meta.local_many_to_many: + if field.name in self.cleaned_data: + self.instance._m2m_values[field.name] = list(self.cleaned_data[field.name]) + + return super()._post_clean() + class NetBoxModelImportForm(CSVModelForm, NetBoxModelForm): """