Implement contains_vid filter

This commit is contained in:
Jeremy Stretch 2024-07-13 12:06:04 -04:00
parent 674907ce0f
commit 45183e4066
3 changed files with 72 additions and 14 deletions

View File

@ -911,8 +911,8 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet):
cluster = django_filters.NumberFilter(
method='filter_scope'
)
vlan_id = django_filters.NumberFilter(
method='filter_vlan_id'
contains_vid = django_filters.NumberFilter(
method='filter_contains_vid'
)
class Meta:
@ -935,9 +935,19 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet):
scope_id=value
)
def filter_vlan_id(self, queryset, name, value):
def filter_contains_vid(self, queryset, name, value):
"""
Return all VLANGroups which contain the given VLAN ID.
"""
table_name = VLANGroup._meta.db_table
# TODO: See if this can be optimized without compromising queryset integrity
# Expand VLAN ID ranges to query by integer
groups = VLANGroup.objects.raw(
f'SELECT id FROM {table_name}, unnest(vlan_id_ranges) vid_range WHERE %s <@ vid_range',
params=(value,)
)
return queryset.filter(
vid_range__contained_by=value
pk__in=[g.id for g in groups]
)

View File

@ -413,7 +413,7 @@ class VLANGroupFilterForm(NetBoxModelFilterSetForm):
FieldSet('q', 'filter_id', 'tag'),
FieldSet('region', 'sitegroup', 'site', 'location', 'rack', name=_('Location')),
FieldSet('cluster_group', 'cluster', name=_('Cluster')),
FieldSet('vlan_id', name=_('VLAN ID')),
FieldSet('contains_vid', name=_('VLANs')),
)
model = VLANGroup
region = DynamicModelMultipleChoiceField(
@ -451,7 +451,7 @@ class VLANGroupFilterForm(NetBoxModelFilterSetForm):
required=False,
label=_('Cluster group')
)
vlan_id = forms.IntegerField(
contains_vid = forms.IntegerField(
min_value=0,
required=False,
label=_('Contains VLAN ID')

View File

@ -1,4 +1,5 @@
from django.contrib.contenttypes.models import ContentType
from django.db.backends.postgresql.psycopg_any import NumericRange
from django.test import TestCase
from netaddr import IPNetwork
@ -1495,14 +1496,55 @@ class VLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
cluster.save()
vlan_groups = (
VLANGroup(name='VLAN Group 1', slug='vlan-group-1', scope=region, description='foobar1'),
VLANGroup(name='VLAN Group 2', slug='vlan-group-2', scope=sitegroup, description='foobar2'),
VLANGroup(name='VLAN Group 3', slug='vlan-group-3', scope=site, description='foobar3'),
VLANGroup(name='VLAN Group 4', slug='vlan-group-4', scope=location),
VLANGroup(name='VLAN Group 5', slug='vlan-group-5', scope=rack),
VLANGroup(name='VLAN Group 6', slug='vlan-group-6', scope=clustergroup),
VLANGroup(name='VLAN Group 7', slug='vlan-group-7', scope=cluster),
VLANGroup(name='VLAN Group 8', slug='vlan-group-8'),
VLANGroup(
name='VLAN Group 1',
slug='vlan-group-1',
vlan_id_ranges=[NumericRange(1, 11), NumericRange(100, 200)],
scope=region,
description='foobar1'
),
VLANGroup(
name='VLAN Group 2',
slug='vlan-group-2',
vlan_id_ranges=[NumericRange(1, 11), NumericRange(200, 300)],
scope=sitegroup,
description='foobar2'
),
VLANGroup(
name='VLAN Group 3',
slug='vlan-group-3',
vlan_id_ranges=[NumericRange(1, 11), NumericRange(300, 400)],
scope=site,
description='foobar3'
),
VLANGroup(
name='VLAN Group 4',
slug='vlan-group-4',
vlan_id_ranges=[NumericRange(1, 11), NumericRange(400, 500)],
scope=location
),
VLANGroup(
name='VLAN Group 5',
slug='vlan-group-5',
vlan_id_ranges=[NumericRange(1, 11), NumericRange(500, 600)],
scope=rack
),
VLANGroup(
name='VLAN Group 6',
slug='vlan-group-6',
vlan_id_ranges=[NumericRange(1, 11), NumericRange(600, 700)],
scope=clustergroup
),
VLANGroup(
name='VLAN Group 7',
slug='vlan-group-7',
vlan_id_ranges=[NumericRange(1, 11), NumericRange(700, 800)],
scope=cluster
),
VLANGroup(
name='VLAN Group 8',
slug='vlan-group-8'
),
)
VLANGroup.objects.bulk_create(vlan_groups)
@ -1522,6 +1564,12 @@ class VLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'description': ['foobar1', 'foobar2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def text_contains_vid(self):
params = {'contains_vid': 123}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'contains_vid': 1}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
def test_region(self):
params = {'region': Region.objects.first().pk}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)