Add tests for missing FilterSet filters

This commit is contained in:
Jeremy Stretch 2024-03-05 17:14:42 -05:00
parent d6acc18c29
commit 6af12b1814
9 changed files with 126 additions and 16 deletions

View File

@ -330,6 +330,7 @@ class CircuitTestCase(TestCase, ChangeLoggedFilterSetTests):
class CircuitTerminationTestCase(TestCase, ChangeLoggedFilterSetTests): class CircuitTerminationTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = CircuitTermination.objects.all() queryset = CircuitTermination.objects.all()
filterset = CircuitTerminationFilterSet filterset = CircuitTerminationFilterSet
ignore_fields = ('cable',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -10,6 +10,7 @@ from ..models import *
class DataSourceTestCase(TestCase, ChangeLoggedFilterSetTests): class DataSourceTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = DataSource.objects.all() queryset = DataSource.objects.all()
filterset = DataSourceFilterSet filterset = DataSourceFilterSet
ignore_fields = ('ignore_rules', 'parameters')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -70,6 +71,7 @@ class DataSourceTestCase(TestCase, ChangeLoggedFilterSetTests):
class DataFileTestCase(TestCase, ChangeLoggedFilterSetTests): class DataFileTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = DataFile.objects.all() queryset = DataFile.objects.all()
filterset = DataFileFilterSet filterset = DataFileFilterSet
ignore_fields = ('data',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -196,6 +196,7 @@ class SiteGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
class SiteTestCase(TestCase, ChangeLoggedFilterSetTests): class SiteTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = Site.objects.all() queryset = Site.objects.all()
filterset = SiteFilterSet filterset = SiteFilterSet
ignore_fields = ('physical_address', 'shipping_address')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -467,6 +468,7 @@ class RackRoleTestCase(TestCase, ChangeLoggedFilterSetTests):
class RackTestCase(TestCase, ChangeLoggedFilterSetTests): class RackTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = Rack.objects.all() queryset = Rack.objects.all()
filterset = RackFilterSet filterset = RackFilterSet
ignore_fields = ('units',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -726,6 +728,7 @@ class RackTestCase(TestCase, ChangeLoggedFilterSetTests):
class RackReservationTestCase(TestCase, ChangeLoggedFilterSetTests): class RackReservationTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = RackReservation.objects.all() queryset = RackReservation.objects.all()
filterset = RackReservationFilterSet filterset = RackReservationFilterSet
ignore_fields = ('units',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -889,6 +892,7 @@ class ManufacturerTestCase(TestCase, ChangeLoggedFilterSetTests):
class DeviceTypeTestCase(TestCase, ChangeLoggedFilterSetTests): class DeviceTypeTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = DeviceType.objects.all() queryset = DeviceType.objects.all()
filterset = DeviceTypeFilterSet filterset = DeviceTypeFilterSet
ignore_fields = ('front_image', 'rear_image')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -1880,6 +1884,7 @@ class PlatformTestCase(TestCase, ChangeLoggedFilterSetTests):
class DeviceTestCase(TestCase, ChangeLoggedFilterSetTests): class DeviceTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = Device.objects.all() queryset = Device.objects.all()
filterset = DeviceFilterSet filterset = DeviceFilterSet
ignore_fields = ('primary_ip4', 'primary_ip6', 'oob_ip', 'local_context_data')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -2332,6 +2337,7 @@ class DeviceTestCase(TestCase, ChangeLoggedFilterSetTests):
class ModuleTestCase(TestCase, ChangeLoggedFilterSetTests): class ModuleTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = Module.objects.all() queryset = Module.objects.all()
filterset = ModuleFilterSet filterset = ModuleFilterSet
ignore_fields = ('local_context_data',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -3229,6 +3235,7 @@ class PowerOutletTestCase(TestCase, DeviceComponentFilterSetTests, ChangeLoggedF
class InterfaceTestCase(TestCase, DeviceComponentFilterSetTests, ChangeLoggedFilterSetTests): class InterfaceTestCase(TestCase, DeviceComponentFilterSetTests, ChangeLoggedFilterSetTests):
queryset = Interface.objects.all() queryset = Interface.objects.all()
filterset = InterfaceFilterSet filterset = InterfaceFilterSet
ignore_fields = ('untagged_vlan',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -5332,6 +5339,7 @@ class PowerFeedTestCase(TestCase, ChangeLoggedFilterSetTests):
class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests): class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = VirtualDeviceContext.objects.all() queryset = VirtualDeviceContext.objects.all()
filterset = VirtualDeviceContextFilterSet filterset = VirtualDeviceContextFilterSet
ignore_fields = ('primary_ip4', 'primary_ip6')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -23,9 +23,10 @@ from virtualization.models import Cluster, ClusterGroup, ClusterType
User = get_user_model() User = get_user_model()
class CustomFieldTestCase(TestCase, BaseFilterSetTests): class CustomFieldTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = CustomField.objects.all() queryset = CustomField.objects.all()
filterset = CustomFieldFilterSet filterset = CustomFieldFilterSet
ignore_fields = ('default',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -155,9 +156,10 @@ class CustomFieldTestCase(TestCase, BaseFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class CustomFieldChoiceSetTestCase(TestCase, BaseFilterSetTests): class CustomFieldChoiceSetTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = CustomFieldChoiceSet.objects.all() queryset = CustomFieldChoiceSet.objects.all()
filterset = CustomFieldChoiceSetFilterSet filterset = CustomFieldChoiceSetFilterSet
ignore_fields = ('extra_choices',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -188,6 +190,7 @@ class CustomFieldChoiceSetTestCase(TestCase, BaseFilterSetTests):
class WebhookTestCase(TestCase, BaseFilterSetTests): class WebhookTestCase(TestCase, BaseFilterSetTests):
queryset = Webhook.objects.all() queryset = Webhook.objects.all()
filterset = WebhookFilterSet filterset = WebhookFilterSet
ignore_fields = ('additional_headers', 'body_template')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -252,6 +255,7 @@ class WebhookTestCase(TestCase, BaseFilterSetTests):
class EventRuleTestCase(TestCase, BaseFilterSetTests): class EventRuleTestCase(TestCase, BaseFilterSetTests):
queryset = EventRule.objects.all() queryset = EventRule.objects.all()
filterset = EventRuleFilterSet filterset = EventRuleFilterSet
ignore_fields = ('action_data', 'conditions')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -405,7 +409,7 @@ class EventRuleTestCase(TestCase, BaseFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
class CustomLinkTestCase(TestCase, BaseFilterSetTests): class CustomLinkTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = CustomLink.objects.all() queryset = CustomLink.objects.all()
filterset = CustomLinkFilterSet filterset = CustomLinkFilterSet
@ -474,9 +478,10 @@ class CustomLinkTestCase(TestCase, BaseFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
class SavedFilterTestCase(TestCase, BaseFilterSetTests): class SavedFilterTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = SavedFilter.objects.all() queryset = SavedFilter.objects.all()
filterset = SavedFilterFilterSet filterset = SavedFilterFilterSet
ignore_fields = ('parameters',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -647,9 +652,10 @@ class BookmarkTestCase(TestCase, BaseFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
class ExportTemplateTestCase(TestCase, BaseFilterSetTests): class ExportTemplateTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = ExportTemplate.objects.all() queryset = ExportTemplate.objects.all()
filterset = ExportTemplateFilterSet filterset = ExportTemplateFilterSet
ignore_fields = ('template_code', 'data_path')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -683,9 +689,10 @@ class ExportTemplateTestCase(TestCase, BaseFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class ImageAttachmentTestCase(TestCase, BaseFilterSetTests): class ImageAttachmentTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = ImageAttachment.objects.all() queryset = ImageAttachment.objects.all()
filterset = ImageAttachmentFilterSet filterset = ImageAttachmentFilterSet
ignore_fields = ('image',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -760,12 +767,6 @@ class ImageAttachmentTestCase(TestCase, BaseFilterSetTests):
} }
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
def test_created(self):
pk_list = self.queryset.values_list('pk', flat=True)[:2]
self.queryset.filter(pk__in=pk_list).update(created=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc))
params = {'created': '2021-01-01T00:00:00'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class JournalEntryTestCase(TestCase, ChangeLoggedFilterSetTests): class JournalEntryTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = JournalEntry.objects.all() queryset = JournalEntry.objects.all()
@ -873,6 +874,7 @@ class JournalEntryTestCase(TestCase, ChangeLoggedFilterSetTests):
class ConfigContextTestCase(TestCase, ChangeLoggedFilterSetTests): class ConfigContextTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = ConfigContext.objects.all() queryset = ConfigContext.objects.all()
filterset = ConfigContextFilterSet filterset = ConfigContextFilterSet
ignore_fields = ('data', 'data_path')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -1096,9 +1098,10 @@ class ConfigContextTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class ConfigTemplateTestCase(TestCase, BaseFilterSetTests): class ConfigTemplateTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = ConfigTemplate.objects.all() queryset = ConfigTemplate.objects.all()
filterset = ConfigTemplateFilterSet filterset = ConfigTemplateFilterSet
ignore_fields = ('template_code', 'environment_params', 'data_path')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -1193,6 +1196,7 @@ class TagTestCase(TestCase, ChangeLoggedFilterSetTests):
class ObjectChangeTestCase(TestCase, BaseFilterSetTests): class ObjectChangeTestCase(TestCase, BaseFilterSetTests):
queryset = ObjectChange.objects.all() queryset = ObjectChange.objects.all()
filterset = ObjectChangeFilterSet filterset = ObjectChangeFilterSet
ignore_fields = ('prechange_data', 'postchange_data')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -1733,6 +1733,7 @@ class VLANTestCase(TestCase, ChangeLoggedFilterSetTests):
class ServiceTemplateTestCase(TestCase, ChangeLoggedFilterSetTests): class ServiceTemplateTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = ServiceTemplate.objects.all() queryset = ServiceTemplate.objects.all()
filterset = ServiceTemplateFilterSet filterset = ServiceTemplateFilterSet
ignore_fields = ('ports',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -1797,6 +1798,7 @@ class ServiceTemplateTestCase(TestCase, ChangeLoggedFilterSetTests):
class ServiceTestCase(TestCase, ChangeLoggedFilterSetTests): class ServiceTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = Service.objects.all() queryset = Service.objects.all()
filterset = ServiceFilterSet filterset = ServiceFilterSet
ignore_fields = ('ports',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -15,6 +15,7 @@ User = get_user_model()
class UserTestCase(TestCase, BaseFilterSetTests): class UserTestCase(TestCase, BaseFilterSetTests):
queryset = User.objects.all() queryset = User.objects.all()
filterset = filtersets.UserFilterSet filterset = filtersets.UserFilterSet
ignore_fields = ('password',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -132,6 +133,7 @@ class GroupTestCase(TestCase, BaseFilterSetTests):
class ObjectPermissionTestCase(TestCase, BaseFilterSetTests): class ObjectPermissionTestCase(TestCase, BaseFilterSetTests):
queryset = ObjectPermission.objects.all() queryset = ObjectPermission.objects.all()
filterset = filtersets.ObjectPermissionFilterSet filterset = filtersets.ObjectPermissionFilterSet
ignore_fields = ('actions', 'constraints')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -226,6 +228,7 @@ class ObjectPermissionTestCase(TestCase, BaseFilterSetTests):
class TokenTestCase(TestCase, BaseFilterSetTests): class TokenTestCase(TestCase, BaseFilterSetTests):
queryset = Token.objects.all() queryset = Token.objects.all()
filterset = filtersets.TokenFilterSet filterset = filtersets.TokenFilterSet
ignore_fields = ('allowed_ips',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -1,15 +1,47 @@
from datetime import date, datetime, timezone from datetime import datetime, timezone
from itertools import chain
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
from django.utils.module_loading import import_string
from taggit.managers import TaggableManager
from core.models import ObjectType
__all__ = ( __all__ = (
'BaseFilterSetTests', 'BaseFilterSetTests',
'ChangeLoggedFilterSetTests', 'ChangeLoggedFilterSetTests',
) )
IGNORE_MODELS = (
('core', 'AutoSyncRecord'),
('core', 'ManagedFile'),
('core', 'ObjectType'),
('dcim', 'CablePath'),
('extras', 'Branch'),
('extras', 'CachedValue'),
('extras', 'Dashboard'),
('extras', 'ScriptModule'),
('extras', 'StagedChange'),
('extras', 'TaggedItem'),
('users', 'UserConfig'),
)
IGNORE_FIELDS = (
'comments',
'custom_field_data',
'level', # MPTT
'lft', # MPTT
'rght', # MPTT
'tree_id', # MPTT
)
class BaseFilterSetTests: class BaseFilterSetTests:
queryset = None queryset = None
filterset = None filterset = None
ignore_fields = tuple()
def test_id(self): def test_id(self):
""" """
@ -19,6 +51,63 @@ class BaseFilterSetTests:
self.assertGreater(self.queryset.count(), 2) self.assertGreater(self.queryset.count(), 2)
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_missing_filters(self):
"""
Check for any model fields which do not have the required filter(s) defined.
"""
app_label = self.__class__.__module__.split('.')[0]
model = self.queryset.model
model_name = model.__name__
# Skip ignored models
if (app_label, model_name) in IGNORE_MODELS:
return
# Import the FilterSet class & sanity check it
filterset = import_string(f'{app_label}.filtersets.{model_name}FilterSet')
self.assertEqual(model, filterset.Meta.model, "FilterSet model does not match!")
filterset_fields = sorted(filterset.get_filters())
# Check for missing filters
for model_field in model._meta.get_fields():
# Skip private fields
if model_field.name.startswith('_'):
continue
# Skip ignored fields
if model_field.name in chain(self.ignore_fields, IGNORE_FIELDS):
continue
# One-to-one & one-to-many relationships
if issubclass(model_field.__class__, ForeignKey) or type(model_field) is OneToOneRel:
if model_field.related_model is ContentType:
# Relationships to ContentType (used as part of a GFK) do not need a filter
continue
elif model_field.related_model is ObjectType:
# Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix
filter_name = model_field.name
else:
filter_name = f'{model_field.name}_id'
self.assertIn(filter_name, filterset_fields, f'No filter found for {model_field.name}!')
# TODO: Many-to-one & many-to-many relationships
elif type(model_field) in (ManyToOneRel, ManyToManyField, ManyToManyRel):
continue
# TODO: Generic relationships
elif type(model_field) in (GenericForeignKey, GenericRelation):
continue
# Tags
elif type(model_field) is TaggableManager:
self.assertIn('tag', filterset_fields, f'No filter found for {model_field.name}!')
# All other fields
else:
self.assertIn(model_field.name, filterset_fields, f'No filter found for {model_field.name}!')
class ChangeLoggedFilterSetTests(BaseFilterSetTests): class ChangeLoggedFilterSetTests(BaseFilterSetTests):

View File

@ -522,6 +522,7 @@ class VirtualMachineTestCase(TestCase, ChangeLoggedFilterSetTests):
class VMInterfaceTestCase(TestCase, ChangeLoggedFilterSetTests): class VMInterfaceTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = VMInterface.objects.all() queryset = VMInterface.objects.all()
filterset = VMInterfaceFilterSet filterset = VMInterfaceFilterSet
ignore_fields = ('untagged_vlan',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -848,8 +848,8 @@ class L2VPNTerminationTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'l2vpn': [l2vpns[0].slug, l2vpns[1].slug]} params = {'l2vpn': [l2vpns[0].slug, l2vpns[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
def test_content_type(self): def test_termination_type(self):
params = {'assigned_object_type_id': ContentType.objects.get(model='vlan').pk} params = {'assigned_object_type': 'ipam.vlan'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
def test_interface(self): def test_interface(self):