Test for missing ManyToManyField filters

This commit is contained in:
Jeremy Stretch 2024-03-07 14:59:41 -05:00
parent 0a0dae3d35
commit 6085e0bb0b
11 changed files with 101 additions and 31 deletions

View File

@ -23,6 +23,7 @@ from utilities.filters import (
from virtualization.models import Cluster from virtualization.models import Cluster
from vpn.models import L2VPN from vpn.models import L2VPN
from wireless.choices import WirelessRoleChoices, WirelessChannelChoices from wireless.choices import WirelessRoleChoices, WirelessChannelChoices
from wireless.models import WirelessLAN, WirelessLink
from .choices import * from .choices import *
from .constants import * from .constants import *
from .models import * from .models import *
@ -1637,13 +1638,22 @@ class InterfaceFilterSet(
to_field_name='name', to_field_name='name',
label='Virtual Device Context', label='Virtual Device Context',
) )
wireless_lan_id = django_filters.ModelMultipleChoiceFilter(
field_name='wireless_lans',
queryset=WirelessLAN.objects.all(),
label='Wireless LAN',
)
wireless_link_id = django_filters.ModelMultipleChoiceFilter(
queryset=WirelessLink.objects.all(),
label='Wireless link',
)
class Meta: class Meta:
model = Interface model = Interface
fields = ( fields = (
'id', 'name', 'label', 'type', 'enabled', 'mtu', 'mgmt_only', 'poe_mode', 'poe_type', 'mode', 'rf_role', 'id', 'name', 'label', 'type', 'enabled', 'mtu', 'mgmt_only', 'poe_mode', 'poe_type', 'mode', 'rf_role',
'rf_channel', 'rf_channel_frequency', 'rf_channel_width', 'tx_power', 'description', 'mark_connected', 'rf_channel', 'rf_channel_frequency', 'rf_channel_width', 'tx_power', 'description', 'mark_connected',
'cable_id', 'cable_end', 'wireless_link_id', 'cable_id', 'cable_end',
) )
def filter_virtual_chassis_member(self, queryset, name, value): def filter_virtual_chassis_member(self, queryset, name, value):

View File

@ -3235,7 +3235,7 @@ class PowerOutletTestCase(TestCase, DeviceComponentFilterSetTests, ChangeLoggedF
class InterfaceTestCase(TestCase, DeviceComponentFilterSetTests, ChangeLoggedFilterSetTests): class InterfaceTestCase(TestCase, DeviceComponentFilterSetTests, ChangeLoggedFilterSetTests):
queryset = Interface.objects.all() queryset = Interface.objects.all()
filterset = InterfaceFilterSet filterset = InterfaceFilterSet
ignore_fields = ('untagged_vlan',) ignore_fields = ('untagged_vlan', 'vdcs')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -491,12 +491,12 @@ class ConfigContextFilterSet(ChangeLoggedModelFilterSet):
queryset=DeviceType.objects.all(), queryset=DeviceType.objects.all(),
label=_('Device type'), label=_('Device type'),
) )
role_id = django_filters.ModelMultipleChoiceFilter( device_role_id = django_filters.ModelMultipleChoiceFilter(
field_name='roles', field_name='roles',
queryset=DeviceRole.objects.all(), queryset=DeviceRole.objects.all(),
label=_('Role'), label=_('Role'),
) )
role = django_filters.ModelMultipleChoiceFilter( device_role = django_filters.ModelMultipleChoiceFilter(
field_name='roles__slug', field_name='roles__slug',
queryset=DeviceRole.objects.all(), queryset=DeviceRole.objects.all(),
to_field_name='slug', to_field_name='slug',
@ -582,6 +582,10 @@ class ConfigContextFilterSet(ChangeLoggedModelFilterSet):
label=_('Data file (ID)'), label=_('Data file (ID)'),
) )
# TODO: Remove in v4.1
role = device_role
role_id = device_role_id
class Meta: class Meta:
model = ConfigContext model = ConfigContext
fields = ('id', 'name', 'is_active', 'description', 'weight', 'auto_sync_enabled', 'data_synced') fields = ('id', 'name', 'is_active', 'description', 'weight', 'auto_sync_enabled', 'data_synced')

View File

@ -1043,11 +1043,11 @@ class ConfigContextTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'device_type_id': [device_types[0].pk, device_types[1].pk]} params = {'device_type_id': [device_types[0].pk, device_types[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_role(self): def test_device_role(self):
device_roles = DeviceRole.objects.all()[:2] device_roles = DeviceRole.objects.all()[:2]
params = {'role_id': [device_roles[0].pk, device_roles[1].pk]} params = {'device_role_id': [device_roles[0].pk, device_roles[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'role': [device_roles[0].slug, device_roles[1].slug]} params = {'device_role': [device_roles[0].slug, device_roles[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_platform(self): def test_platform(self):
@ -1128,6 +1128,7 @@ class ConfigTemplateTestCase(TestCase, ChangeLoggedFilterSetTests):
class TagTestCase(TestCase, ChangeLoggedFilterSetTests): class TagTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = Tag.objects.all() queryset = Tag.objects.all()
filterset = TagFilterSet filterset = TagFilterSet
ignore_fields = ('object_types',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -1046,12 +1046,12 @@ class ServiceFilterSet(NetBoxModelFilterSet):
to_field_name='name', to_field_name='name',
label=_('Virtual machine (name)'), label=_('Virtual machine (name)'),
) )
ipaddress_id = django_filters.ModelMultipleChoiceFilter( ip_address_id = django_filters.ModelMultipleChoiceFilter(
field_name='ipaddresses', field_name='ipaddresses',
queryset=IPAddress.objects.all(), queryset=IPAddress.objects.all(),
label=_('IP address (ID)'), label=_('IP address (ID)'),
) )
ipaddress = django_filters.ModelMultipleChoiceFilter( ip_address = django_filters.ModelMultipleChoiceFilter(
field_name='ipaddresses__address', field_name='ipaddresses__address',
queryset=IPAddress.objects.all(), queryset=IPAddress.objects.all(),
to_field_name='address', to_field_name='address',
@ -1062,6 +1062,10 @@ class ServiceFilterSet(NetBoxModelFilterSet):
lookup_expr='contains' lookup_expr='contains'
) )
# TODO: Remove in v4.1
ipaddress = ip_address
ipaddress_id = ip_address_id
class Meta: class Meta:
model = Service model = Service
fields = ('id', 'name', 'protocol', 'description') fields = ('id', 'name', 'protocol', 'description')

View File

@ -181,6 +181,15 @@ class VRFTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = VRF.objects.all() queryset = VRF.objects.all()
filterset = VRFFilterSet filterset = VRFFilterSet
@staticmethod
def get_m2m_filter_name(field):
# Override filter names for import & export RouteTargets
if field.name == 'import_targets':
return 'import_target'
if field.name == 'export_targets':
return 'export_target'
return super().get_m2m_filter_name(field)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -1886,9 +1895,9 @@ class ServiceTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'virtual_machine': [vms[0].name, vms[1].name]} params = {'virtual_machine': [vms[0].name, vms[1].name]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_ipaddress(self): def test_ip_address(self):
ips = IPAddress.objects.all()[:2] ips = IPAddress.objects.all()[:2]
params = {'ipaddress_id': [ips[0].pk, ips[1].pk]} params = {'ip_address_id': [ips[0].pk, ips[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'ipaddress': [str(ips[0].address), str(ips[1].address)]} params = {'ip_address': [str(ips[0].address), str(ips[1].address)]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)

View File

@ -5,6 +5,7 @@ from django.utils.translation import gettext as _
from netbox.filtersets import BaseFilterSet from netbox.filtersets import BaseFilterSet
from users.models import Group, ObjectPermission, Token from users.models import Group, ObjectPermission, Token
from utilities.filters import ContentTypeFilter, MultiValueNumberFilter
__all__ = ( __all__ = (
'GroupFilterSet', 'GroupFilterSet',
@ -118,6 +119,12 @@ class ObjectPermissionFilterSet(BaseFilterSet):
method='search', method='search',
label=_('Search'), label=_('Search'),
) )
object_type_id = MultiValueNumberFilter(
field_name='object_types__id'
)
object_type = ContentTypeFilter(
field_name='object_types'
)
can_view = django_filters.BooleanFilter( can_view = django_filters.BooleanFilter(
method='_check_action' method='_check_action'
) )

View File

@ -15,7 +15,7 @@ User = get_user_model()
class UserTestCase(TestCase, BaseFilterSetTests): class UserTestCase(TestCase, BaseFilterSetTests):
queryset = User.objects.all() queryset = User.objects.all()
filterset = filtersets.UserFilterSet filterset = filtersets.UserFilterSet
ignore_fields = ('config', 'dashboard', 'password') ignore_fields = ('config', 'dashboard', 'password', 'user_permissions')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -110,6 +110,7 @@ class UserTestCase(TestCase, BaseFilterSetTests):
class GroupTestCase(TestCase, BaseFilterSetTests): class GroupTestCase(TestCase, BaseFilterSetTests):
queryset = Group.objects.all() queryset = Group.objects.all()
filterset = filtersets.GroupFilterSet filterset = filtersets.GroupFilterSet
ignore_fields = ('permissions',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -43,6 +43,15 @@ class BaseFilterSetTests:
filterset = None filterset = None
ignore_fields = tuple() ignore_fields = tuple()
@staticmethod
def get_m2m_filter_name(field):
"""
Given a ManyToManyField, determine the correct name for its corresponding Filter. Individual test
cases may override this method to prescribe deviations for specific fields.
"""
related_model_name = field.related_model._meta.verbose_name
return related_model_name.lower().replace(' ', '_')
def test_id(self): def test_id(self):
""" """
Test filtering for two PKs from a set of >2 objects. Test filtering for two PKs from a set of >2 objects.
@ -94,13 +103,22 @@ class BaseFilterSetTests:
filter_name = model_field.name filter_name = model_field.name
else: else:
filter_name = f'{model_field.name}_id' filter_name = f'{model_field.name}_id'
self.assertIn(filter_name, filterset_fields, f'No filter found for {filter_name}!') self.assertIn(
filter_name,
filterset_fields,
f'No filter defined for {filter_name} ({model_field.name})!'
)
elif type(model_field) is ManyToManyField:
filter_name = self.get_m2m_filter_name(model_field)
filter_name = f'{filter_name}_id'
self.assertIn(
filter_name,
filterset_fields,
f'No filter defined for {filter_name} ({model_field.name})!'
)
# TODO: Many-to-many relationships # TODO: Many-to-many relationships
elif type(model_field) is ManyToManyField:
related_model = model_field.related_model._meta.model_name
filter_name = f'{related_model}_id'
self.assertIn(filter_name, filterset_fields, f'M2M: No filter found for {filter_name}!')
elif type(model_field) is ManyToManyRel: elif type(model_field) is ManyToManyRel:
continue continue
@ -110,14 +128,14 @@ class BaseFilterSetTests:
# Tags # Tags
elif type(model_field) is TaggableManager: elif type(model_field) is TaggableManager:
self.assertIn('tag', filterset_fields, f'No filter found for {model_field.name}!') self.assertIn('tag', filterset_fields, f'No filter defined for {model_field.name}!')
# All other fields # All other fields
else: else:
self.assertIn( self.assertIn(
model_field.name, model_field.name,
filterset_fields, filterset_fields,
f'No filter found for {model_field.name} ({type(model_field)})!' f'No defined found for {model_field.name} ({type(model_field)})!'
) )

View File

@ -158,13 +158,17 @@ class IKEPolicyFilterSet(NetBoxModelFilterSet):
mode = django_filters.MultipleChoiceFilter( mode = django_filters.MultipleChoiceFilter(
choices=IKEModeChoices choices=IKEModeChoices
) )
proposal_id = MultiValueNumberFilter( ike_proposal_id = MultiValueNumberFilter(
field_name='proposals__id' field_name='proposals__id'
) )
proposal = MultiValueCharFilter( ike_proposal = MultiValueCharFilter(
field_name='proposals__name' field_name='proposals__name'
) )
# TODO: Remove in v4.1
proposal = ike_proposal
proposal_id = ike_proposal_id
class Meta: class Meta:
model = IKEPolicy model = IKEPolicy
fields = ['id', 'name', 'preshared_key', 'description'] fields = ['id', 'name', 'preshared_key', 'description']
@ -205,13 +209,17 @@ class IPSecPolicyFilterSet(NetBoxModelFilterSet):
pfs_group = django_filters.MultipleChoiceFilter( pfs_group = django_filters.MultipleChoiceFilter(
choices=DHGroupChoices choices=DHGroupChoices
) )
proposal_id = MultiValueNumberFilter( ipsec_proposal_id = MultiValueNumberFilter(
field_name='proposals__id' field_name='proposals__id'
) )
proposal = MultiValueCharFilter( ipsec_proposal = MultiValueCharFilter(
field_name='proposals__name' field_name='proposals__name'
) )
# TODO: Remove in v4.1
proposal = ipsec_proposal
proposal_id = ipsec_proposal_id
class Meta: class Meta:
model = IPSecPolicy model = IPSecPolicy
fields = ['id', 'name', 'description'] fields = ['id', 'name', 'description']

View File

@ -1,4 +1,3 @@
from django.contrib.contenttypes.models import ContentType
from django.test import TestCase from django.test import TestCase
from dcim.choices import InterfaceTypeChoices from dcim.choices import InterfaceTypeChoices
@ -446,11 +445,11 @@ class IKEPolicyTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'mode': [IKEModeChoices.MAIN]} params = {'mode': [IKEModeChoices.MAIN]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_proposal(self): def test_ike_proposal(self):
proposals = IKEProposal.objects.all()[:2] proposals = IKEProposal.objects.all()[:2]
params = {'proposal_id': [proposals[0].pk, proposals[1].pk]} params = {'ike_proposal_id': [proposals[0].pk, proposals[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'proposal': [proposals[0].name, proposals[1].name]} params = {'ike_proposal': [proposals[0].name, proposals[1].name]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
@ -584,11 +583,11 @@ class IPSecPolicyTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'pfs_group': [DHGroupChoices.GROUP_1, DHGroupChoices.GROUP_2]} params = {'pfs_group': [DHGroupChoices.GROUP_1, DHGroupChoices.GROUP_2]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_proposal(self): def test_ipsec_proposal(self):
proposals = IPSecProposal.objects.all()[:2] proposals = IPSecProposal.objects.all()[:2]
params = {'proposal_id': [proposals[0].pk, proposals[1].pk]} params = {'ipsec_proposal_id': [proposals[0].pk, proposals[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'proposal': [proposals[0].name, proposals[1].name]} params = {'ipsec_proposal': [proposals[0].name, proposals[1].name]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
@ -710,6 +709,15 @@ class L2VPNTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = L2VPN.objects.all() queryset = L2VPN.objects.all()
filterset = L2VPNFilterSet filterset = L2VPNFilterSet
@staticmethod
def get_m2m_filter_name(field):
# Override filter names for import & export RouteTargets
if field.name == 'import_targets':
return 'import_target'
if field.name == 'export_targets':
return 'export_target'
return super().get_m2m_filter_name(field)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):