Add missing filters for reverse many-to-many relationships

This commit is contained in:
Jeremy Stretch 2024-03-07 16:27:58 -05:00
parent 6085e0bb0b
commit b36a70d236
12 changed files with 373 additions and 30 deletions

View File

@ -1184,6 +1184,11 @@ class VirtualDeviceContextFilterSet(NetBoxModelFilterSet, TenancyFilterSet, Prim
queryset=Device.objects.all(), queryset=Device.objects.all(),
label='Device model', label='Device model',
) )
interface_id = django_filters.ModelMultipleChoiceFilter(
field_name='interfaces',
queryset=Interface.objects.all(),
label='Interface (ID)',
)
status = django_filters.MultipleChoiceFilter( status = django_filters.MultipleChoiceFilter(
choices=VirtualDeviceContextStatusChoices choices=VirtualDeviceContextStatusChoices
) )

View File

@ -5409,15 +5409,22 @@ class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
VirtualDeviceContext.objects.bulk_create(vdcs) VirtualDeviceContext.objects.bulk_create(vdcs)
interfaces = ( interfaces = (
Interface(device=devices[0], name='Interface 1', type='virtual'), Interface(device=devices[0], name='Interface 1', type=InterfaceTypeChoices.TYPE_VIRTUAL),
Interface(device=devices[0], name='Interface 2', type='virtual'), Interface(device=devices[0], name='Interface 2', type=InterfaceTypeChoices.TYPE_VIRTUAL),
Interface(device=devices[1], name='Interface 3', type=InterfaceTypeChoices.TYPE_VIRTUAL),
Interface(device=devices[1], name='Interface 4', type=InterfaceTypeChoices.TYPE_VIRTUAL),
Interface(device=devices[2], name='Interface 5', type=InterfaceTypeChoices.TYPE_VIRTUAL),
Interface(device=devices[2], name='Interface 6', type=InterfaceTypeChoices.TYPE_VIRTUAL),
) )
Interface.objects.bulk_create(interfaces) Interface.objects.bulk_create(interfaces)
interfaces[0].vdcs.set([vdcs[0]]) interfaces[0].vdcs.set([vdcs[0]])
interfaces[1].vdcs.set([vdcs[1]]) interfaces[1].vdcs.set([vdcs[1]])
interfaces[2].vdcs.set([vdcs[2]])
interfaces[3].vdcs.set([vdcs[3]])
interfaces[4].vdcs.set([vdcs[4]])
interfaces[5].vdcs.set([vdcs[5]])
addresses = ( ip_addresses = (
IPAddress(assigned_object=interfaces[0], address='10.1.1.1/24'), IPAddress(assigned_object=interfaces[0], address='10.1.1.1/24'),
IPAddress(assigned_object=interfaces[1], address='10.1.1.2/24'), IPAddress(assigned_object=interfaces[1], address='10.1.1.2/24'),
IPAddress(assigned_object=None, address='10.1.1.3/24'), IPAddress(assigned_object=None, address='10.1.1.3/24'),
@ -5425,13 +5432,12 @@ class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
IPAddress(assigned_object=interfaces[1], address='2001:db8::2/64'), IPAddress(assigned_object=interfaces[1], address='2001:db8::2/64'),
IPAddress(assigned_object=None, address='2001:db8::3/64'), IPAddress(assigned_object=None, address='2001:db8::3/64'),
) )
IPAddress.objects.bulk_create(addresses) IPAddress.objects.bulk_create(ip_addresses)
vdcs[0].primary_ip4 = ip_addresses[0]
vdcs[0].primary_ip4 = addresses[0] vdcs[0].primary_ip6 = ip_addresses[3]
vdcs[0].primary_ip6 = addresses[3]
vdcs[0].save() vdcs[0].save()
vdcs[1].primary_ip4 = addresses[1] vdcs[1].primary_ip4 = ip_addresses[1]
vdcs[1].primary_ip6 = addresses[4] vdcs[1].primary_ip6 = ip_addresses[4]
vdcs[1].save() vdcs[1].save()
def test_q(self): def test_q(self):
@ -5439,8 +5445,11 @@ class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
def test_device(self): def test_device(self):
params = {'device': ['Device 1', 'Device 2']} devices = Device.objects.filter(name__in=['Device 1', 'Device 2'])
params = {'device': [devices[0].name, devices[1].name]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
params = {'device_id': [devices[0].pk, devices[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_status(self): def test_status(self):
params = {'status': ['active']} params = {'status': ['active']}
@ -5450,10 +5459,10 @@ class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'description': ['foobar1', 'foobar2']} params = {'description': ['foobar1', 'foobar2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_device_id(self): def test_interface(self):
devices = Device.objects.filter(name__in=['Device 1', 'Device 2']) interfaces = Interface.objects.filter(name__in=['Interface 1', 'Interface 3'])
params = {'device_id': [devices[0].pk, devices[1].pk]} params = {'interface_id': [interfaces[0].pk, interfaces[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_has_primary_ip(self): def test_has_primary_ip(self):
params = {'has_primary_ip': True} params = {'has_primary_ip': True}

View File

@ -1128,7 +1128,93 @@ class ConfigTemplateTestCase(TestCase, ChangeLoggedFilterSetTests):
class TagTestCase(TestCase, ChangeLoggedFilterSetTests): class TagTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = Tag.objects.all() queryset = Tag.objects.all()
filterset = TagFilterSet filterset = TagFilterSet
ignore_fields = ('object_types',) ignore_fields = (
'object_types',
# Reverse relationships (to tagged models) we can ignore
'aggregate',
'asn',
'asnrange',
'cable',
'circuit',
'circuittermination',
'circuittype',
'cluster',
'clustergroup',
'clustertype',
'configtemplate',
'consoleport',
'consoleserverport',
'contact',
'contactassignment',
'contactgroup',
'contactrole',
'datasource',
'device',
'devicebay',
'devicerole',
'devicetype',
'dummymodel', # From dummy_plugin
'eventrule',
'fhrpgroup',
'frontport',
'ikepolicy',
'ikeproposal',
'interface',
'inventoryitem',
'inventoryitemrole',
'ipaddress',
'iprange',
'ipsecpolicy',
'ipsecprofile',
'ipsecproposal',
'journalentry',
'l2vpn',
'l2vpntermination',
'location',
'manufacturer',
'module',
'modulebay',
'moduletype',
'platform',
'powerfeed',
'poweroutlet',
'powerpanel',
'powerport',
'prefix',
'provider',
'provideraccount',
'providernetwork',
'rack',
'rackreservation',
'rackrole',
'rearport',
'region',
'rir',
'role',
'routetarget',
'service',
'servicetemplate',
'site',
'sitegroup',
'tenant',
'tenantgroup',
'tunnel',
'tunnelgroup',
'tunneltermination',
'virtualchassis',
'virtualdevicecontext',
'virtualdisk',
'virtualmachine',
'vlan',
'vlangroup',
'vminterface',
'vrf',
'webhook',
'wirelesslan',
'wirelesslangroup',
'wirelesslink',
)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -8,6 +8,7 @@ from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field from drf_spectacular.utils import extend_schema_field
from netaddr.core import AddrFormatError from netaddr.core import AddrFormatError
from circuits.models import Provider
from dcim.models import Device, Interface, Region, Site, SiteGroup from dcim.models import Device, Interface, Region, Site, SiteGroup
from netbox.filtersets import ChangeLoggedModelFilterSet, OrganizationalModelFilterSet, NetBoxModelFilterSet from netbox.filtersets import ChangeLoggedModelFilterSet, OrganizationalModelFilterSet, NetBoxModelFilterSet
from tenancy.filtersets import TenancyFilterSet from tenancy.filtersets import TenancyFilterSet
@ -101,6 +102,28 @@ class RouteTargetFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
to_field_name='rd', to_field_name='rd',
label=_('Export VRF (RD)'), label=_('Export VRF (RD)'),
) )
importing_l2vpn_id = django_filters.ModelMultipleChoiceFilter(
field_name='importing_l2vpns',
queryset=L2VPN.objects.all(),
label=_('Importing L2VPN'),
)
importing_l2vpn = django_filters.ModelMultipleChoiceFilter(
field_name='importing_l2vpns__identifier',
queryset=L2VPN.objects.all(),
to_field_name='identifier',
label=_('Importing L2VPN (identifier)'),
)
exporting_l2vpn_id = django_filters.ModelMultipleChoiceFilter(
field_name='exporting_l2vpns',
queryset=L2VPN.objects.all(),
label=_('Exporting L2VPN'),
)
exporting_l2vpn = django_filters.ModelMultipleChoiceFilter(
field_name='exporting_l2vpns__identifier',
queryset=L2VPN.objects.all(),
to_field_name='identifier',
label=_('Exporting L2VPN (identifier)'),
)
def search(self, queryset, name, value): def search(self, queryset, name, value):
if not value.strip(): if not value.strip():
@ -214,6 +237,17 @@ class ASNFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
to_field_name='slug', to_field_name='slug',
label=_('Site (slug)'), label=_('Site (slug)'),
) )
provider_id = django_filters.ModelMultipleChoiceFilter(
field_name='providers',
queryset=Provider.objects.all(),
label=_('Provider (ID)'),
)
provider = django_filters.ModelMultipleChoiceFilter(
field_name='providers__slug',
queryset=Provider.objects.all(),
to_field_name='slug',
label=_('Provider (slug)'),
)
class Meta: class Meta:
model = ASN model = ASN
@ -628,6 +662,11 @@ class IPAddressFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
role = django_filters.MultipleChoiceFilter( role = django_filters.MultipleChoiceFilter(
choices=IPAddressRoleChoices choices=IPAddressRoleChoices
) )
service_id = django_filters.ModelMultipleChoiceFilter(
field_name='services',
queryset=Service.objects.all(),
label=_('Service (ID)'),
)
class Meta: class Meta:
model = IPAddress model = IPAddress

View File

@ -2,6 +2,7 @@ from django.contrib.contenttypes.models import ContentType
from django.test import TestCase from django.test import TestCase
from netaddr import IPNetwork from netaddr import IPNetwork
from circuits.models import Provider
from dcim.choices import InterfaceTypeChoices from dcim.choices import InterfaceTypeChoices
from dcim.models import Device, DeviceRole, DeviceType, Interface, Location, Manufacturer, Rack, Region, Site, SiteGroup from dcim.models import Device, DeviceRole, DeviceType, Interface, Location, Manufacturer, Rack, Region, Site, SiteGroup
from ipam.choices import * from ipam.choices import *
@ -10,6 +11,8 @@ from ipam.models import *
from tenancy.models import Tenant, TenantGroup from tenancy.models import Tenant, TenantGroup
from utilities.testing import ChangeLoggedFilterSetTests, create_test_device, create_test_virtualmachine from utilities.testing import ChangeLoggedFilterSetTests, create_test_device, create_test_virtualmachine
from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface
from vpn.choices import L2VPNTypeChoices
from vpn.models import L2VPN
class ASNRangeTestCase(TestCase, ChangeLoggedFilterSetTests): class ASNRangeTestCase(TestCase, ChangeLoggedFilterSetTests):
@ -110,13 +113,6 @@ class ASNTestCase(TestCase, ChangeLoggedFilterSetTests):
] ]
RIR.objects.bulk_create(rirs) RIR.objects.bulk_create(rirs)
sites = [
Site(name='Site 1', slug='site-1'),
Site(name='Site 2', slug='site-2'),
Site(name='Site 3', slug='site-3')
]
Site.objects.bulk_create(sites)
tenants = [ tenants = [
Tenant(name='Tenant 1', slug='tenant-1'), Tenant(name='Tenant 1', slug='tenant-1'),
Tenant(name='Tenant 2', slug='tenant-2'), Tenant(name='Tenant 2', slug='tenant-2'),
@ -136,6 +132,12 @@ class ASNTestCase(TestCase, ChangeLoggedFilterSetTests):
) )
ASN.objects.bulk_create(asns) ASN.objects.bulk_create(asns)
sites = [
Site(name='Site 1', slug='site-1'),
Site(name='Site 2', slug='site-2'),
Site(name='Site 3', slug='site-3')
]
Site.objects.bulk_create(sites)
asns[0].sites.set([sites[0]]) asns[0].sites.set([sites[0]])
asns[1].sites.set([sites[1]]) asns[1].sites.set([sites[1]])
asns[2].sites.set([sites[2]]) asns[2].sites.set([sites[2]])
@ -143,6 +145,16 @@ class ASNTestCase(TestCase, ChangeLoggedFilterSetTests):
asns[4].sites.set([sites[1]]) asns[4].sites.set([sites[1]])
asns[5].sites.set([sites[2]]) asns[5].sites.set([sites[2]])
providers = (
Provider(name='Provider 1', slug='provider-1'),
Provider(name='Provider 2', slug='provider-2'),
Provider(name='Provider 3', slug='provider-3'),
)
Provider.objects.bulk_create(providers)
providers[0].asns.add(asns[0])
providers[1].asns.add(asns[1])
providers[2].asns.add(asns[2])
def test_q(self): def test_q(self):
params = {'q': 'foobar1'} params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -176,6 +188,11 @@ class ASNTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'description': ['foobar1', 'foobar2']} params = {'description': ['foobar1', 'foobar2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_provider(self):
providers = Provider.objects.all()[:2]
params = {'provider_id': [providers[0].pk, providers[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class VRFTestCase(TestCase, ChangeLoggedFilterSetTests): class VRFTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = VRF.objects.all() queryset = VRF.objects.all()
@ -188,7 +205,7 @@ class VRFTestCase(TestCase, ChangeLoggedFilterSetTests):
return 'import_target' return 'import_target'
if field.name == 'export_targets': if field.name == 'export_targets':
return 'export_target' return 'export_target'
return super().get_m2m_filter_name(field) return ChangeLoggedFilterSetTests.get_m2m_filter_name(field)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -286,6 +303,19 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = RouteTarget.objects.all() queryset = RouteTarget.objects.all()
filterset = RouteTargetFilterSet filterset = RouteTargetFilterSet
@staticmethod
def get_m2m_filter_name(field):
# Override filter names for import & export VRFs and L2VPNs
if field.name == 'importing_vrfs':
return 'importing_vrf'
if field.name == 'exporting_vrfs':
return 'exporting_vrf'
if field.name == 'importing_l2vpns':
return 'importing_l2vpn'
if field.name == 'exporting_l2vpns':
return 'exporting_l2vpn'
return ChangeLoggedFilterSetTests.get_m2m_filter_name(field)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -331,6 +361,17 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
vrfs[1].import_targets.add(route_targets[4], route_targets[5]) vrfs[1].import_targets.add(route_targets[4], route_targets[5])
vrfs[1].export_targets.add(route_targets[6], route_targets[7]) vrfs[1].export_targets.add(route_targets[6], route_targets[7])
l2vpns = (
L2VPN(name='L2VPN 1', slug='l2vpn-1', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=100),
L2VPN(name='L2VPN 2', slug='l2vpn-2', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=200),
L2VPN(name='L2VPN 3', slug='l2vpn-3', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=300),
)
L2VPN.objects.bulk_create(l2vpns)
l2vpns[0].import_targets.add(route_targets[0], route_targets[1])
l2vpns[0].export_targets.add(route_targets[2], route_targets[3])
l2vpns[1].import_targets.add(route_targets[4], route_targets[5])
l2vpns[1].export_targets.add(route_targets[6], route_targets[7])
def test_q(self): def test_q(self):
params = {'q': 'foobar1'} params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -353,6 +394,20 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'exporting_vrf': [vrfs[0].rd, vrfs[1].rd]} params = {'exporting_vrf': [vrfs[0].rd, vrfs[1].rd]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_importing_l2vpn(self):
l2vpns = L2VPN.objects.all()[:2]
params = {'importing_l2vpn_id': [l2vpns[0].pk, l2vpns[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
params = {'importing_l2vpn': [l2vpns[0].identifier, l2vpns[1].identifier]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_exporting_l2vpn(self):
l2vpns = L2VPN.objects.all()[:2]
params = {'exporting_l2vpn_id': [l2vpns[0].pk, l2vpns[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
params = {'exporting_l2vpn': [l2vpns[0].identifier, l2vpns[1].identifier]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_tenant(self): def test_tenant(self):
tenants = Tenant.objects.all()[:2] tenants = Tenant.objects.all()[:2]
params = {'tenant_id': [tenants[0].pk, tenants[1].pk]} params = {'tenant_id': [tenants[0].pk, tenants[1].pk]}
@ -1102,6 +1157,16 @@ class IPAddressTestCase(TestCase, ChangeLoggedFilterSetTests):
) )
IPAddress.objects.bulk_create(ipaddresses) IPAddress.objects.bulk_create(ipaddresses)
services = (
Service(name='Service 1', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[1]),
Service(name='Service 2', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[1]),
Service(name='Service 3', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[1]),
)
Service.objects.bulk_create(services)
services[0].ipaddresses.add(ipaddresses[0])
services[1].ipaddresses.add(ipaddresses[1])
services[2].ipaddresses.add(ipaddresses[2])
def test_q(self): def test_q(self):
params = {'q': 'foobar1'} params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -1241,6 +1306,11 @@ class IPAddressTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'tenant_group': [tenant_groups[0].slug, tenant_groups[1].slug]} params = {'tenant_group': [tenant_groups[0].slug, tenant_groups[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_service(self):
services = Service.objects.all()[:2]
params = {'service_id': [services[0].pk, services[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class FHRPGroupTestCase(TestCase, ChangeLoggedFilterSetTests): class FHRPGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = FHRPGroup.objects.all() queryset = FHRPGroup.objects.all()
@ -1485,6 +1555,7 @@ class VLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
class VLANTestCase(TestCase, ChangeLoggedFilterSetTests): class VLANTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = VLAN.objects.all() queryset = VLAN.objects.all()
filterset = VLANFilterSet filterset = VLANFilterSet
ignore_fields = ('interfaces_as_tagged', 'vminterfaces_as_tagged')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -20,6 +20,16 @@ class GroupFilterSet(BaseFilterSet):
method='search', method='search',
label=_('Search'), label=_('Search'),
) )
user_id = django_filters.ModelMultipleChoiceFilter(
field_name='user',
queryset=get_user_model().objects.all(),
label=_('User (ID)'),
)
permission_id = django_filters.ModelMultipleChoiceFilter(
field_name='object_permissions',
queryset=ObjectPermission.objects.all(),
label=_('Permission (ID)'),
)
class Meta: class Meta:
model = Group model = Group
@ -47,6 +57,11 @@ class UserFilterSet(BaseFilterSet):
to_field_name='name', to_field_name='name',
label=_('Group (name)'), label=_('Group (name)'),
) )
permission_id = django_filters.ModelMultipleChoiceFilter(
field_name='object_permissions',
queryset=ObjectPermission.objects.all(),
label=_('Permission (ID)'),
)
class Meta: class Meta:
model = get_user_model() model = get_user_model()

View File

@ -67,6 +67,16 @@ class UserTestCase(TestCase, BaseFilterSetTests):
users[1].groups.set([groups[1]]) users[1].groups.set([groups[1]])
users[2].groups.set([groups[2]]) users[2].groups.set([groups[2]])
object_permissions = (
ObjectPermission(name='Permission 1', actions=['add']),
ObjectPermission(name='Permission 2', actions=['change']),
ObjectPermission(name='Permission 3', actions=['delete']),
)
ObjectPermission.objects.bulk_create(object_permissions)
object_permissions[0].users.add(users[0])
object_permissions[1].users.add(users[1])
object_permissions[2].users.add(users[2])
def test_q(self): def test_q(self):
params = {'q': 'user1'} params = {'q': 'user1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -106,6 +116,11 @@ class UserTestCase(TestCase, BaseFilterSetTests):
params = {'group': [groups[0].name, groups[1].name]} params = {'group': [groups[0].name, groups[1].name]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_permission(self):
object_permissions = ObjectPermission.objects.all()[:2]
params = {'permission_id': [object_permissions[0].pk, object_permissions[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class GroupTestCase(TestCase, BaseFilterSetTests): class GroupTestCase(TestCase, BaseFilterSetTests):
queryset = Group.objects.all() queryset = Group.objects.all()
@ -122,6 +137,26 @@ class GroupTestCase(TestCase, BaseFilterSetTests):
) )
Group.objects.bulk_create(groups) Group.objects.bulk_create(groups)
users = (
User(username='User 1'),
User(username='User 2'),
User(username='User 3'),
)
User.objects.bulk_create(users)
users[0].groups.set([groups[0]])
users[1].groups.set([groups[1]])
users[2].groups.set([groups[2]])
object_permissions = (
ObjectPermission(name='Permission 1', actions=['add']),
ObjectPermission(name='Permission 2', actions=['change']),
ObjectPermission(name='Permission 3', actions=['delete']),
)
ObjectPermission.objects.bulk_create(object_permissions)
object_permissions[0].groups.add(groups[0])
object_permissions[1].groups.add(groups[1])
object_permissions[2].groups.add(groups[2])
def test_q(self): def test_q(self):
params = {'q': 'group 1'} params = {'q': 'group 1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -130,6 +165,16 @@ class GroupTestCase(TestCase, BaseFilterSetTests):
params = {'name': ['Group 1', 'Group 2']} params = {'name': ['Group 1', 'Group 2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_user(self):
users = User.objects.all()[:2]
params = {'user_id': [users[0].pk, users[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_permission(self):
object_permissions = ObjectPermission.objects.all()[:2]
params = {'permission_id': [object_permissions[0].pk, object_permissions[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class ObjectPermissionTestCase(TestCase, BaseFilterSetTests): class ObjectPermissionTestCase(TestCase, BaseFilterSetTests):
queryset = ObjectPermission.objects.all() queryset = ObjectPermission.objects.all()

View File

@ -109,7 +109,7 @@ class BaseFilterSetTests:
f'No filter defined for {filter_name} ({model_field.name})!' f'No filter defined for {filter_name} ({model_field.name})!'
) )
elif type(model_field) is ManyToManyField: elif type(model_field) in (ManyToManyField, ManyToManyRel):
filter_name = self.get_m2m_filter_name(model_field) filter_name = self.get_m2m_filter_name(model_field)
filter_name = f'{filter_name}_id' filter_name = f'{filter_name}_id'
self.assertIn( self.assertIn(
@ -118,10 +118,6 @@ class BaseFilterSetTests:
f'No filter defined for {filter_name} ({model_field.name})!' f'No filter defined for {filter_name} ({model_field.name})!'
) )
# TODO: Many-to-many relationships
elif type(model_field) is ManyToManyRel:
continue
# TODO: Generic relationships # TODO: Generic relationships
elif type(model_field) in (GenericForeignKey, GenericRelation): elif type(model_field) in (GenericForeignKey, GenericRelation):
continue continue

View File

@ -124,6 +124,17 @@ class TunnelTerminationFilterSet(NetBoxModelFilterSet):
class IKEProposalFilterSet(NetBoxModelFilterSet): class IKEProposalFilterSet(NetBoxModelFilterSet):
ike_policy_id = django_filters.ModelMultipleChoiceFilter(
field_name='ike_policies',
queryset=IKEPolicy.objects.all(),
label=_('IKE policy (ID)'),
)
ike_policy = django_filters.ModelMultipleChoiceFilter(
field_name='ike_policies__name',
queryset=IKEPolicy.objects.all(),
to_field_name='name',
label=_('IKE policy (name)'),
)
authentication_method = django_filters.MultipleChoiceFilter( authentication_method = django_filters.MultipleChoiceFilter(
choices=AuthenticationMethodChoices choices=AuthenticationMethodChoices
) )
@ -184,6 +195,17 @@ class IKEPolicyFilterSet(NetBoxModelFilterSet):
class IPSecProposalFilterSet(NetBoxModelFilterSet): class IPSecProposalFilterSet(NetBoxModelFilterSet):
ipsec_policy_id = django_filters.ModelMultipleChoiceFilter(
field_name='ipsec_policies',
queryset=IPSecPolicy.objects.all(),
label=_('IPSec policy (ID)'),
)
ipsec_policy = django_filters.ModelMultipleChoiceFilter(
field_name='ipsec_policies__name',
queryset=IPSecPolicy.objects.all(),
to_field_name='name',
label=_('IPSec policy (name)'),
)
encryption_algorithm = django_filters.MultipleChoiceFilter( encryption_algorithm = django_filters.MultipleChoiceFilter(
choices=EncryptionAlgorithmChoices choices=EncryptionAlgorithmChoices
) )

View File

@ -330,6 +330,16 @@ class IKEProposalTestCase(TestCase, ChangeLoggedFilterSetTests):
) )
IKEProposal.objects.bulk_create(ike_proposals) IKEProposal.objects.bulk_create(ike_proposals)
ike_policies = (
IKEPolicy(name='IKE Policy 1'),
IKEPolicy(name='IKE Policy 2'),
IKEPolicy(name='IKE Policy 3'),
)
IKEPolicy.objects.bulk_create(ike_policies)
ike_policies[0].proposals.add(ike_proposals[0])
ike_policies[1].proposals.add(ike_proposals[1])
ike_policies[2].proposals.add(ike_proposals[2])
def test_q(self): def test_q(self):
params = {'q': 'foobar1'} params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -342,6 +352,13 @@ class IKEProposalTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'description': ['foobar1', 'foobar2']} params = {'description': ['foobar1', 'foobar2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_ike_policy(self):
ike_policies = IKEPolicy.objects.all()[:2]
params = {'ike_policy_id': [ike_policies[0].pk, ike_policies[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'ike_policy': [ike_policies[0].name, ike_policies[1].name]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_authentication_method(self): def test_authentication_method(self):
params = {'authentication_method': [ params = {'authentication_method': [
AuthenticationMethodChoices.PRESHARED_KEYS, AuthenticationMethodChoices.CERTIFICATES AuthenticationMethodChoices.PRESHARED_KEYS, AuthenticationMethodChoices.CERTIFICATES
@ -487,6 +504,16 @@ class IPSecProposalTestCase(TestCase, ChangeLoggedFilterSetTests):
) )
IPSecProposal.objects.bulk_create(ipsec_proposals) IPSecProposal.objects.bulk_create(ipsec_proposals)
ipsec_policies = (
IPSecPolicy(name='IPSec Policy 1'),
IPSecPolicy(name='IPSec Policy 2'),
IPSecPolicy(name='IPSec Policy 3'),
)
IPSecPolicy.objects.bulk_create(ipsec_policies)
ipsec_policies[0].proposals.add(ipsec_proposals[0])
ipsec_policies[1].proposals.add(ipsec_proposals[1])
ipsec_policies[2].proposals.add(ipsec_proposals[2])
def test_q(self): def test_q(self):
params = {'q': 'foobar1'} params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -499,6 +526,13 @@ class IPSecProposalTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'description': ['foobar1', 'foobar2']} params = {'description': ['foobar1', 'foobar2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_ipsec_policy(self):
ipsec_policies = IPSecPolicy.objects.all()[:2]
params = {'ipsec_policy_id': [ipsec_policies[0].pk, ipsec_policies[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'ipsec_policy': [ipsec_policies[0].name, ipsec_policies[1].name]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_encryption_algorithm(self): def test_encryption_algorithm(self):
params = {'encryption_algorithm': [ params = {'encryption_algorithm': [
EncryptionAlgorithmChoices.ENCRYPTION_AES128_CBC, EncryptionAlgorithmChoices.ENCRYPTION_AES192_CBC EncryptionAlgorithmChoices.ENCRYPTION_AES128_CBC, EncryptionAlgorithmChoices.ENCRYPTION_AES192_CBC
@ -716,7 +750,7 @@ class L2VPNTestCase(TestCase, ChangeLoggedFilterSetTests):
return 'import_target' return 'import_target'
if field.name == 'export_targets': if field.name == 'export_targets':
return 'export_target' return 'export_target'
return super().get_m2m_filter_name(field) return ChangeLoggedFilterSetTests.get_m2m_filter_name(field)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):

View File

@ -2,6 +2,7 @@ import django_filters
from django.db.models import Q from django.db.models import Q
from dcim.choices import LinkStatusChoices from dcim.choices import LinkStatusChoices
from dcim.models import Interface
from ipam.models import VLAN from ipam.models import VLAN
from netbox.filtersets import OrganizationalModelFilterSet, NetBoxModelFilterSet from netbox.filtersets import OrganizationalModelFilterSet, NetBoxModelFilterSet
from tenancy.filtersets import TenancyFilterSet from tenancy.filtersets import TenancyFilterSet
@ -60,6 +61,10 @@ class WirelessLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
vlan_id = django_filters.ModelMultipleChoiceFilter( vlan_id = django_filters.ModelMultipleChoiceFilter(
queryset=VLAN.objects.all() queryset=VLAN.objects.all()
) )
interface_id = django_filters.ModelMultipleChoiceFilter(
queryset=Interface.objects.all(),
field_name='interfaces'
)
auth_type = django_filters.MultipleChoiceFilter( auth_type = django_filters.MultipleChoiceFilter(
choices=WirelessAuthTypeChoices choices=WirelessAuthTypeChoices
) )

View File

@ -153,6 +153,17 @@ class WirelessLANTestCase(TestCase, ChangeLoggedFilterSetTests):
) )
WirelessLAN.objects.bulk_create(wireless_lans) WirelessLAN.objects.bulk_create(wireless_lans)
device = create_test_device('Device 1')
interfaces = (
Interface(device=device, name='Interface 1', type=InterfaceTypeChoices.TYPE_80211N),
Interface(device=device, name='Interface 2', type=InterfaceTypeChoices.TYPE_80211N),
Interface(device=device, name='Interface 3', type=InterfaceTypeChoices.TYPE_80211N),
)
Interface.objects.bulk_create(interfaces)
interfaces[0].wireless_lans.add(wireless_lans[0])
interfaces[1].wireless_lans.add(wireless_lans[1])
interfaces[2].wireless_lans.add(wireless_lans[2])
def test_q(self): def test_q(self):
params = {'q': 'foobar1'} params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -200,6 +211,11 @@ class WirelessLANTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'tenant': [tenants[0].slug, tenants[1].slug]} params = {'tenant': [tenants[0].slug, tenants[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_interface(self):
interfaces = Interface.objects.all()[:2]
params = {'interface_id': [interfaces[0].pk, interfaces[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
class WirelessLinkTestCase(TestCase, ChangeLoggedFilterSetTests): class WirelessLinkTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = WirelessLink.objects.all() queryset = WirelessLink.objects.all()