Add missing filters for reverse many-to-many relationships

This commit is contained in:
Jeremy Stretch 2024-03-07 16:27:58 -05:00
parent 6085e0bb0b
commit b36a70d236
12 changed files with 373 additions and 30 deletions

View File

@ -1184,6 +1184,11 @@ class VirtualDeviceContextFilterSet(NetBoxModelFilterSet, TenancyFilterSet, Prim
queryset=Device.objects.all(),
label='Device model',
)
interface_id = django_filters.ModelMultipleChoiceFilter(
field_name='interfaces',
queryset=Interface.objects.all(),
label='Interface (ID)',
)
status = django_filters.MultipleChoiceFilter(
choices=VirtualDeviceContextStatusChoices
)

View File

@ -5409,15 +5409,22 @@ class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
VirtualDeviceContext.objects.bulk_create(vdcs)
interfaces = (
Interface(device=devices[0], name='Interface 1', type='virtual'),
Interface(device=devices[0], name='Interface 2', type='virtual'),
Interface(device=devices[0], name='Interface 1', type=InterfaceTypeChoices.TYPE_VIRTUAL),
Interface(device=devices[0], name='Interface 2', type=InterfaceTypeChoices.TYPE_VIRTUAL),
Interface(device=devices[1], name='Interface 3', type=InterfaceTypeChoices.TYPE_VIRTUAL),
Interface(device=devices[1], name='Interface 4', type=InterfaceTypeChoices.TYPE_VIRTUAL),
Interface(device=devices[2], name='Interface 5', type=InterfaceTypeChoices.TYPE_VIRTUAL),
Interface(device=devices[2], name='Interface 6', type=InterfaceTypeChoices.TYPE_VIRTUAL),
)
Interface.objects.bulk_create(interfaces)
interfaces[0].vdcs.set([vdcs[0]])
interfaces[1].vdcs.set([vdcs[1]])
interfaces[2].vdcs.set([vdcs[2]])
interfaces[3].vdcs.set([vdcs[3]])
interfaces[4].vdcs.set([vdcs[4]])
interfaces[5].vdcs.set([vdcs[5]])
addresses = (
ip_addresses = (
IPAddress(assigned_object=interfaces[0], address='10.1.1.1/24'),
IPAddress(assigned_object=interfaces[1], address='10.1.1.2/24'),
IPAddress(assigned_object=None, address='10.1.1.3/24'),
@ -5425,13 +5432,12 @@ class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
IPAddress(assigned_object=interfaces[1], address='2001:db8::2/64'),
IPAddress(assigned_object=None, address='2001:db8::3/64'),
)
IPAddress.objects.bulk_create(addresses)
vdcs[0].primary_ip4 = addresses[0]
vdcs[0].primary_ip6 = addresses[3]
IPAddress.objects.bulk_create(ip_addresses)
vdcs[0].primary_ip4 = ip_addresses[0]
vdcs[0].primary_ip6 = ip_addresses[3]
vdcs[0].save()
vdcs[1].primary_ip4 = addresses[1]
vdcs[1].primary_ip6 = addresses[4]
vdcs[1].primary_ip4 = ip_addresses[1]
vdcs[1].primary_ip6 = ip_addresses[4]
vdcs[1].save()
def test_q(self):
@ -5439,8 +5445,11 @@ class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
def test_device(self):
params = {'device': ['Device 1', 'Device 2']}
devices = Device.objects.filter(name__in=['Device 1', 'Device 2'])
params = {'device': [devices[0].name, devices[1].name]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
params = {'device_id': [devices[0].pk, devices[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_status(self):
params = {'status': ['active']}
@ -5450,10 +5459,10 @@ class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'description': ['foobar1', 'foobar2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_device_id(self):
devices = Device.objects.filter(name__in=['Device 1', 'Device 2'])
params = {'device_id': [devices[0].pk, devices[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_interface(self):
interfaces = Interface.objects.filter(name__in=['Interface 1', 'Interface 3'])
params = {'interface_id': [interfaces[0].pk, interfaces[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_has_primary_ip(self):
params = {'has_primary_ip': True}

View File

@ -1128,7 +1128,93 @@ class ConfigTemplateTestCase(TestCase, ChangeLoggedFilterSetTests):
class TagTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = Tag.objects.all()
filterset = TagFilterSet
ignore_fields = ('object_types',)
ignore_fields = (
'object_types',
# Reverse relationships (to tagged models) we can ignore
'aggregate',
'asn',
'asnrange',
'cable',
'circuit',
'circuittermination',
'circuittype',
'cluster',
'clustergroup',
'clustertype',
'configtemplate',
'consoleport',
'consoleserverport',
'contact',
'contactassignment',
'contactgroup',
'contactrole',
'datasource',
'device',
'devicebay',
'devicerole',
'devicetype',
'dummymodel', # From dummy_plugin
'eventrule',
'fhrpgroup',
'frontport',
'ikepolicy',
'ikeproposal',
'interface',
'inventoryitem',
'inventoryitemrole',
'ipaddress',
'iprange',
'ipsecpolicy',
'ipsecprofile',
'ipsecproposal',
'journalentry',
'l2vpn',
'l2vpntermination',
'location',
'manufacturer',
'module',
'modulebay',
'moduletype',
'platform',
'powerfeed',
'poweroutlet',
'powerpanel',
'powerport',
'prefix',
'provider',
'provideraccount',
'providernetwork',
'rack',
'rackreservation',
'rackrole',
'rearport',
'region',
'rir',
'role',
'routetarget',
'service',
'servicetemplate',
'site',
'sitegroup',
'tenant',
'tenantgroup',
'tunnel',
'tunnelgroup',
'tunneltermination',
'virtualchassis',
'virtualdevicecontext',
'virtualdisk',
'virtualmachine',
'vlan',
'vlangroup',
'vminterface',
'vrf',
'webhook',
'wirelesslan',
'wirelesslangroup',
'wirelesslink',
)
@classmethod
def setUpTestData(cls):

View File

@ -8,6 +8,7 @@ from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field
from netaddr.core import AddrFormatError
from circuits.models import Provider
from dcim.models import Device, Interface, Region, Site, SiteGroup
from netbox.filtersets import ChangeLoggedModelFilterSet, OrganizationalModelFilterSet, NetBoxModelFilterSet
from tenancy.filtersets import TenancyFilterSet
@ -101,6 +102,28 @@ class RouteTargetFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
to_field_name='rd',
label=_('Export VRF (RD)'),
)
importing_l2vpn_id = django_filters.ModelMultipleChoiceFilter(
field_name='importing_l2vpns',
queryset=L2VPN.objects.all(),
label=_('Importing L2VPN'),
)
importing_l2vpn = django_filters.ModelMultipleChoiceFilter(
field_name='importing_l2vpns__identifier',
queryset=L2VPN.objects.all(),
to_field_name='identifier',
label=_('Importing L2VPN (identifier)'),
)
exporting_l2vpn_id = django_filters.ModelMultipleChoiceFilter(
field_name='exporting_l2vpns',
queryset=L2VPN.objects.all(),
label=_('Exporting L2VPN'),
)
exporting_l2vpn = django_filters.ModelMultipleChoiceFilter(
field_name='exporting_l2vpns__identifier',
queryset=L2VPN.objects.all(),
to_field_name='identifier',
label=_('Exporting L2VPN (identifier)'),
)
def search(self, queryset, name, value):
if not value.strip():
@ -214,6 +237,17 @@ class ASNFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
to_field_name='slug',
label=_('Site (slug)'),
)
provider_id = django_filters.ModelMultipleChoiceFilter(
field_name='providers',
queryset=Provider.objects.all(),
label=_('Provider (ID)'),
)
provider = django_filters.ModelMultipleChoiceFilter(
field_name='providers__slug',
queryset=Provider.objects.all(),
to_field_name='slug',
label=_('Provider (slug)'),
)
class Meta:
model = ASN
@ -628,6 +662,11 @@ class IPAddressFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
role = django_filters.MultipleChoiceFilter(
choices=IPAddressRoleChoices
)
service_id = django_filters.ModelMultipleChoiceFilter(
field_name='services',
queryset=Service.objects.all(),
label=_('Service (ID)'),
)
class Meta:
model = IPAddress

View File

@ -2,6 +2,7 @@ from django.contrib.contenttypes.models import ContentType
from django.test import TestCase
from netaddr import IPNetwork
from circuits.models import Provider
from dcim.choices import InterfaceTypeChoices
from dcim.models import Device, DeviceRole, DeviceType, Interface, Location, Manufacturer, Rack, Region, Site, SiteGroup
from ipam.choices import *
@ -10,6 +11,8 @@ from ipam.models import *
from tenancy.models import Tenant, TenantGroup
from utilities.testing import ChangeLoggedFilterSetTests, create_test_device, create_test_virtualmachine
from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface
from vpn.choices import L2VPNTypeChoices
from vpn.models import L2VPN
class ASNRangeTestCase(TestCase, ChangeLoggedFilterSetTests):
@ -110,13 +113,6 @@ class ASNTestCase(TestCase, ChangeLoggedFilterSetTests):
]
RIR.objects.bulk_create(rirs)
sites = [
Site(name='Site 1', slug='site-1'),
Site(name='Site 2', slug='site-2'),
Site(name='Site 3', slug='site-3')
]
Site.objects.bulk_create(sites)
tenants = [
Tenant(name='Tenant 1', slug='tenant-1'),
Tenant(name='Tenant 2', slug='tenant-2'),
@ -136,6 +132,12 @@ class ASNTestCase(TestCase, ChangeLoggedFilterSetTests):
)
ASN.objects.bulk_create(asns)
sites = [
Site(name='Site 1', slug='site-1'),
Site(name='Site 2', slug='site-2'),
Site(name='Site 3', slug='site-3')
]
Site.objects.bulk_create(sites)
asns[0].sites.set([sites[0]])
asns[1].sites.set([sites[1]])
asns[2].sites.set([sites[2]])
@ -143,6 +145,16 @@ class ASNTestCase(TestCase, ChangeLoggedFilterSetTests):
asns[4].sites.set([sites[1]])
asns[5].sites.set([sites[2]])
providers = (
Provider(name='Provider 1', slug='provider-1'),
Provider(name='Provider 2', slug='provider-2'),
Provider(name='Provider 3', slug='provider-3'),
)
Provider.objects.bulk_create(providers)
providers[0].asns.add(asns[0])
providers[1].asns.add(asns[1])
providers[2].asns.add(asns[2])
def test_q(self):
params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -176,6 +188,11 @@ class ASNTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'description': ['foobar1', 'foobar2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_provider(self):
providers = Provider.objects.all()[:2]
params = {'provider_id': [providers[0].pk, providers[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class VRFTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = VRF.objects.all()
@ -188,7 +205,7 @@ class VRFTestCase(TestCase, ChangeLoggedFilterSetTests):
return 'import_target'
if field.name == 'export_targets':
return 'export_target'
return super().get_m2m_filter_name(field)
return ChangeLoggedFilterSetTests.get_m2m_filter_name(field)
@classmethod
def setUpTestData(cls):
@ -286,6 +303,19 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = RouteTarget.objects.all()
filterset = RouteTargetFilterSet
@staticmethod
def get_m2m_filter_name(field):
# Override filter names for import & export VRFs and L2VPNs
if field.name == 'importing_vrfs':
return 'importing_vrf'
if field.name == 'exporting_vrfs':
return 'exporting_vrf'
if field.name == 'importing_l2vpns':
return 'importing_l2vpn'
if field.name == 'exporting_l2vpns':
return 'exporting_l2vpn'
return ChangeLoggedFilterSetTests.get_m2m_filter_name(field)
@classmethod
def setUpTestData(cls):
@ -331,6 +361,17 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
vrfs[1].import_targets.add(route_targets[4], route_targets[5])
vrfs[1].export_targets.add(route_targets[6], route_targets[7])
l2vpns = (
L2VPN(name='L2VPN 1', slug='l2vpn-1', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=100),
L2VPN(name='L2VPN 2', slug='l2vpn-2', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=200),
L2VPN(name='L2VPN 3', slug='l2vpn-3', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=300),
)
L2VPN.objects.bulk_create(l2vpns)
l2vpns[0].import_targets.add(route_targets[0], route_targets[1])
l2vpns[0].export_targets.add(route_targets[2], route_targets[3])
l2vpns[1].import_targets.add(route_targets[4], route_targets[5])
l2vpns[1].export_targets.add(route_targets[6], route_targets[7])
def test_q(self):
params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -353,6 +394,20 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'exporting_vrf': [vrfs[0].rd, vrfs[1].rd]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_importing_l2vpn(self):
l2vpns = L2VPN.objects.all()[:2]
params = {'importing_l2vpn_id': [l2vpns[0].pk, l2vpns[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
params = {'importing_l2vpn': [l2vpns[0].identifier, l2vpns[1].identifier]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_exporting_l2vpn(self):
l2vpns = L2VPN.objects.all()[:2]
params = {'exporting_l2vpn_id': [l2vpns[0].pk, l2vpns[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
params = {'exporting_l2vpn': [l2vpns[0].identifier, l2vpns[1].identifier]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_tenant(self):
tenants = Tenant.objects.all()[:2]
params = {'tenant_id': [tenants[0].pk, tenants[1].pk]}
@ -1102,6 +1157,16 @@ class IPAddressTestCase(TestCase, ChangeLoggedFilterSetTests):
)
IPAddress.objects.bulk_create(ipaddresses)
services = (
Service(name='Service 1', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[1]),
Service(name='Service 2', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[1]),
Service(name='Service 3', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[1]),
)
Service.objects.bulk_create(services)
services[0].ipaddresses.add(ipaddresses[0])
services[1].ipaddresses.add(ipaddresses[1])
services[2].ipaddresses.add(ipaddresses[2])
def test_q(self):
params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -1241,6 +1306,11 @@ class IPAddressTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'tenant_group': [tenant_groups[0].slug, tenant_groups[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_service(self):
services = Service.objects.all()[:2]
params = {'service_id': [services[0].pk, services[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class FHRPGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = FHRPGroup.objects.all()
@ -1485,6 +1555,7 @@ class VLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
class VLANTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = VLAN.objects.all()
filterset = VLANFilterSet
ignore_fields = ('interfaces_as_tagged', 'vminterfaces_as_tagged')
@classmethod
def setUpTestData(cls):

View File

@ -20,6 +20,16 @@ class GroupFilterSet(BaseFilterSet):
method='search',
label=_('Search'),
)
user_id = django_filters.ModelMultipleChoiceFilter(
field_name='user',
queryset=get_user_model().objects.all(),
label=_('User (ID)'),
)
permission_id = django_filters.ModelMultipleChoiceFilter(
field_name='object_permissions',
queryset=ObjectPermission.objects.all(),
label=_('Permission (ID)'),
)
class Meta:
model = Group
@ -47,6 +57,11 @@ class UserFilterSet(BaseFilterSet):
to_field_name='name',
label=_('Group (name)'),
)
permission_id = django_filters.ModelMultipleChoiceFilter(
field_name='object_permissions',
queryset=ObjectPermission.objects.all(),
label=_('Permission (ID)'),
)
class Meta:
model = get_user_model()

View File

@ -67,6 +67,16 @@ class UserTestCase(TestCase, BaseFilterSetTests):
users[1].groups.set([groups[1]])
users[2].groups.set([groups[2]])
object_permissions = (
ObjectPermission(name='Permission 1', actions=['add']),
ObjectPermission(name='Permission 2', actions=['change']),
ObjectPermission(name='Permission 3', actions=['delete']),
)
ObjectPermission.objects.bulk_create(object_permissions)
object_permissions[0].users.add(users[0])
object_permissions[1].users.add(users[1])
object_permissions[2].users.add(users[2])
def test_q(self):
params = {'q': 'user1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -106,6 +116,11 @@ class UserTestCase(TestCase, BaseFilterSetTests):
params = {'group': [groups[0].name, groups[1].name]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_permission(self):
object_permissions = ObjectPermission.objects.all()[:2]
params = {'permission_id': [object_permissions[0].pk, object_permissions[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class GroupTestCase(TestCase, BaseFilterSetTests):
queryset = Group.objects.all()
@ -122,6 +137,26 @@ class GroupTestCase(TestCase, BaseFilterSetTests):
)
Group.objects.bulk_create(groups)
users = (
User(username='User 1'),
User(username='User 2'),
User(username='User 3'),
)
User.objects.bulk_create(users)
users[0].groups.set([groups[0]])
users[1].groups.set([groups[1]])
users[2].groups.set([groups[2]])
object_permissions = (
ObjectPermission(name='Permission 1', actions=['add']),
ObjectPermission(name='Permission 2', actions=['change']),
ObjectPermission(name='Permission 3', actions=['delete']),
)
ObjectPermission.objects.bulk_create(object_permissions)
object_permissions[0].groups.add(groups[0])
object_permissions[1].groups.add(groups[1])
object_permissions[2].groups.add(groups[2])
def test_q(self):
params = {'q': 'group 1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -130,6 +165,16 @@ class GroupTestCase(TestCase, BaseFilterSetTests):
params = {'name': ['Group 1', 'Group 2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_user(self):
users = User.objects.all()[:2]
params = {'user_id': [users[0].pk, users[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_permission(self):
object_permissions = ObjectPermission.objects.all()[:2]
params = {'permission_id': [object_permissions[0].pk, object_permissions[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class ObjectPermissionTestCase(TestCase, BaseFilterSetTests):
queryset = ObjectPermission.objects.all()

View File

@ -109,7 +109,7 @@ class BaseFilterSetTests:
f'No filter defined for {filter_name} ({model_field.name})!'
)
elif type(model_field) is ManyToManyField:
elif type(model_field) in (ManyToManyField, ManyToManyRel):
filter_name = self.get_m2m_filter_name(model_field)
filter_name = f'{filter_name}_id'
self.assertIn(
@ -118,10 +118,6 @@ class BaseFilterSetTests:
f'No filter defined for {filter_name} ({model_field.name})!'
)
# TODO: Many-to-many relationships
elif type(model_field) is ManyToManyRel:
continue
# TODO: Generic relationships
elif type(model_field) in (GenericForeignKey, GenericRelation):
continue

View File

@ -124,6 +124,17 @@ class TunnelTerminationFilterSet(NetBoxModelFilterSet):
class IKEProposalFilterSet(NetBoxModelFilterSet):
ike_policy_id = django_filters.ModelMultipleChoiceFilter(
field_name='ike_policies',
queryset=IKEPolicy.objects.all(),
label=_('IKE policy (ID)'),
)
ike_policy = django_filters.ModelMultipleChoiceFilter(
field_name='ike_policies__name',
queryset=IKEPolicy.objects.all(),
to_field_name='name',
label=_('IKE policy (name)'),
)
authentication_method = django_filters.MultipleChoiceFilter(
choices=AuthenticationMethodChoices
)
@ -184,6 +195,17 @@ class IKEPolicyFilterSet(NetBoxModelFilterSet):
class IPSecProposalFilterSet(NetBoxModelFilterSet):
ipsec_policy_id = django_filters.ModelMultipleChoiceFilter(
field_name='ipsec_policies',
queryset=IPSecPolicy.objects.all(),
label=_('IPSec policy (ID)'),
)
ipsec_policy = django_filters.ModelMultipleChoiceFilter(
field_name='ipsec_policies__name',
queryset=IPSecPolicy.objects.all(),
to_field_name='name',
label=_('IPSec policy (name)'),
)
encryption_algorithm = django_filters.MultipleChoiceFilter(
choices=EncryptionAlgorithmChoices
)

View File

@ -330,6 +330,16 @@ class IKEProposalTestCase(TestCase, ChangeLoggedFilterSetTests):
)
IKEProposal.objects.bulk_create(ike_proposals)
ike_policies = (
IKEPolicy(name='IKE Policy 1'),
IKEPolicy(name='IKE Policy 2'),
IKEPolicy(name='IKE Policy 3'),
)
IKEPolicy.objects.bulk_create(ike_policies)
ike_policies[0].proposals.add(ike_proposals[0])
ike_policies[1].proposals.add(ike_proposals[1])
ike_policies[2].proposals.add(ike_proposals[2])
def test_q(self):
params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -342,6 +352,13 @@ class IKEProposalTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'description': ['foobar1', 'foobar2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_ike_policy(self):
ike_policies = IKEPolicy.objects.all()[:2]
params = {'ike_policy_id': [ike_policies[0].pk, ike_policies[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'ike_policy': [ike_policies[0].name, ike_policies[1].name]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_authentication_method(self):
params = {'authentication_method': [
AuthenticationMethodChoices.PRESHARED_KEYS, AuthenticationMethodChoices.CERTIFICATES
@ -487,6 +504,16 @@ class IPSecProposalTestCase(TestCase, ChangeLoggedFilterSetTests):
)
IPSecProposal.objects.bulk_create(ipsec_proposals)
ipsec_policies = (
IPSecPolicy(name='IPSec Policy 1'),
IPSecPolicy(name='IPSec Policy 2'),
IPSecPolicy(name='IPSec Policy 3'),
)
IPSecPolicy.objects.bulk_create(ipsec_policies)
ipsec_policies[0].proposals.add(ipsec_proposals[0])
ipsec_policies[1].proposals.add(ipsec_proposals[1])
ipsec_policies[2].proposals.add(ipsec_proposals[2])
def test_q(self):
params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -499,6 +526,13 @@ class IPSecProposalTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'description': ['foobar1', 'foobar2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_ipsec_policy(self):
ipsec_policies = IPSecPolicy.objects.all()[:2]
params = {'ipsec_policy_id': [ipsec_policies[0].pk, ipsec_policies[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'ipsec_policy': [ipsec_policies[0].name, ipsec_policies[1].name]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_encryption_algorithm(self):
params = {'encryption_algorithm': [
EncryptionAlgorithmChoices.ENCRYPTION_AES128_CBC, EncryptionAlgorithmChoices.ENCRYPTION_AES192_CBC
@ -716,7 +750,7 @@ class L2VPNTestCase(TestCase, ChangeLoggedFilterSetTests):
return 'import_target'
if field.name == 'export_targets':
return 'export_target'
return super().get_m2m_filter_name(field)
return ChangeLoggedFilterSetTests.get_m2m_filter_name(field)
@classmethod
def setUpTestData(cls):

View File

@ -2,6 +2,7 @@ import django_filters
from django.db.models import Q
from dcim.choices import LinkStatusChoices
from dcim.models import Interface
from ipam.models import VLAN
from netbox.filtersets import OrganizationalModelFilterSet, NetBoxModelFilterSet
from tenancy.filtersets import TenancyFilterSet
@ -60,6 +61,10 @@ class WirelessLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
vlan_id = django_filters.ModelMultipleChoiceFilter(
queryset=VLAN.objects.all()
)
interface_id = django_filters.ModelMultipleChoiceFilter(
queryset=Interface.objects.all(),
field_name='interfaces'
)
auth_type = django_filters.MultipleChoiceFilter(
choices=WirelessAuthTypeChoices
)

View File

@ -153,6 +153,17 @@ class WirelessLANTestCase(TestCase, ChangeLoggedFilterSetTests):
)
WirelessLAN.objects.bulk_create(wireless_lans)
device = create_test_device('Device 1')
interfaces = (
Interface(device=device, name='Interface 1', type=InterfaceTypeChoices.TYPE_80211N),
Interface(device=device, name='Interface 2', type=InterfaceTypeChoices.TYPE_80211N),
Interface(device=device, name='Interface 3', type=InterfaceTypeChoices.TYPE_80211N),
)
Interface.objects.bulk_create(interfaces)
interfaces[0].wireless_lans.add(wireless_lans[0])
interfaces[1].wireless_lans.add(wireless_lans[1])
interfaces[2].wireless_lans.add(wireless_lans[2])
def test_q(self):
params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -200,6 +211,11 @@ class WirelessLANTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'tenant': [tenants[0].slug, tenants[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_interface(self):
interfaces = Interface.objects.all()[:2]
params = {'interface_id': [interfaces[0].pk, interfaces[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class WirelessLinkTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = WirelessLink.objects.all()