From d6acc18c29b776145e045309b5a45e1e05a922ce Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Mon, 11 Mar 2024 12:32:06 -0400 Subject: [PATCH] Closes #15383: Standardize filtering logic for the parents of recursively-nested models --- netbox/dcim/filtersets.py | 40 ++++++- netbox/dcim/tests/test_filtersets.py | 131 ++++++++++++++++------- netbox/tenancy/filtersets.py | 30 +++++- netbox/tenancy/tests/test_filtersets.py | 94 ++++++++++------ netbox/wireless/filtersets.py | 11 ++ netbox/wireless/tests/test_filtersets.py | 46 +++++--- 6 files changed, 262 insertions(+), 90 deletions(-) diff --git a/netbox/dcim/filtersets.py b/netbox/dcim/filtersets.py index 6b1611694..082659b8f 100644 --- a/netbox/dcim/filtersets.py +++ b/netbox/dcim/filtersets.py @@ -89,6 +89,19 @@ class RegionFilterSet(OrganizationalModelFilterSet, ContactModelFilterSet): to_field_name='slug', label=_('Parent region (slug)'), ) + ancestor_id = TreeNodeMultipleChoiceFilter( + queryset=Region.objects.all(), + field_name='parent', + lookup_expr='in', + label=_('Region (ID)'), + ) + ancestor = TreeNodeMultipleChoiceFilter( + queryset=Region.objects.all(), + field_name='parent', + lookup_expr='in', + to_field_name='slug', + label=_('Region (slug)'), + ) class Meta: model = Region @@ -106,6 +119,19 @@ class SiteGroupFilterSet(OrganizationalModelFilterSet, ContactModelFilterSet): to_field_name='slug', label=_('Parent site group (slug)'), ) + ancestor_id = TreeNodeMultipleChoiceFilter( + queryset=SiteGroup.objects.all(), + field_name='parent', + lookup_expr='in', + label=_('Site group (ID)'), + ) + ancestor = TreeNodeMultipleChoiceFilter( + queryset=SiteGroup.objects.all(), + field_name='parent', + lookup_expr='in', + to_field_name='slug', + label=_('Site group (slug)'), + ) class Meta: model = SiteGroup @@ -214,13 +240,23 @@ class LocationFilterSet(TenancyFilterSet, ContactModelFilterSet, OrganizationalM to_field_name='slug', label=_('Site (slug)'), ) - parent_id = TreeNodeMultipleChoiceFilter( + parent_id = django_filters.ModelMultipleChoiceFilter( + queryset=Location.objects.all(), + label=_('Parent location (ID)'), + ) + parent = django_filters.ModelMultipleChoiceFilter( + field_name='parent__slug', + queryset=Location.objects.all(), + to_field_name='slug', + label=_('Parent location (slug)'), + ) + ancestor_id = TreeNodeMultipleChoiceFilter( queryset=Location.objects.all(), field_name='parent', lookup_expr='in', label=_('Location (ID)'), ) - parent = TreeNodeMultipleChoiceFilter( + ancestor = TreeNodeMultipleChoiceFilter( queryset=Location.objects.all(), field_name='parent', lookup_expr='in', diff --git a/netbox/dcim/tests/test_filtersets.py b/netbox/dcim/tests/test_filtersets.py index b255c283e..f1eeddbb5 100644 --- a/netbox/dcim/tests/test_filtersets.py +++ b/netbox/dcim/tests/test_filtersets.py @@ -64,21 +64,32 @@ class RegionTestCase(TestCase, ChangeLoggedFilterSetTests): @classmethod def setUpTestData(cls): - regions = ( + parent_regions = ( Region(name='Region 1', slug='region-1', description='foobar1'), Region(name='Region 2', slug='region-2', description='foobar2'), Region(name='Region 3', slug='region-3', description='foobar3'), ) + for region in parent_regions: + region.save() + + regions = ( + Region(name='Region 1A', slug='region-1a', parent=parent_regions[0]), + Region(name='Region 1B', slug='region-1b', parent=parent_regions[0]), + Region(name='Region 2A', slug='region-2a', parent=parent_regions[1]), + Region(name='Region 2B', slug='region-2b', parent=parent_regions[1]), + Region(name='Region 3A', slug='region-3a', parent=parent_regions[2]), + Region(name='Region 3B', slug='region-3b', parent=parent_regions[2]), + ) for region in regions: region.save() child_regions = ( - Region(name='Region 1A', slug='region-1a', parent=regions[0]), - Region(name='Region 1B', slug='region-1b', parent=regions[0]), - Region(name='Region 2A', slug='region-2a', parent=regions[1]), - Region(name='Region 2B', slug='region-2b', parent=regions[1]), - Region(name='Region 3A', slug='region-3a', parent=regions[2]), - Region(name='Region 3B', slug='region-3b', parent=regions[2]), + Region(name='Region 1A1', slug='region-1a1', parent=regions[0]), + Region(name='Region 1B1', slug='region-1b1', parent=regions[1]), + Region(name='Region 2A1', slug='region-2a1', parent=regions[2]), + Region(name='Region 2B1', slug='region-2b1', parent=regions[3]), + Region(name='Region 3A1', slug='region-3a1', parent=regions[4]), + Region(name='Region 3B1', slug='region-3b1', parent=regions[5]), ) for region in child_regions: region.save() @@ -100,12 +111,19 @@ class RegionTestCase(TestCase, ChangeLoggedFilterSetTests): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) def test_parent(self): - parent_regions = Region.objects.filter(parent__isnull=True)[:2] - params = {'parent_id': [parent_regions[0].pk, parent_regions[1].pk]} + regions = Region.objects.filter(parent__isnull=True)[:2] + params = {'parent_id': [regions[0].pk, regions[1].pk]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) - params = {'parent': [parent_regions[0].slug, parent_regions[1].slug]} + params = {'parent': [regions[0].slug, regions[1].slug]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + def test_ancestor(self): + regions = Region.objects.filter(parent__isnull=True)[:2] + params = {'ancestor_id': [regions[0].pk, regions[1].pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8) + params = {'ancestor': [regions[0].slug, regions[1].slug]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8) + class SiteGroupTestCase(TestCase, ChangeLoggedFilterSetTests): queryset = SiteGroup.objects.all() @@ -114,24 +132,35 @@ class SiteGroupTestCase(TestCase, ChangeLoggedFilterSetTests): @classmethod def setUpTestData(cls): - sitegroups = ( + parent_groups = ( SiteGroup(name='Site Group 1', slug='site-group-1', description='foobar1'), SiteGroup(name='Site Group 2', slug='site-group-2', description='foobar2'), SiteGroup(name='Site Group 3', slug='site-group-3', description='foobar3'), ) - for sitegroup in sitegroups: - sitegroup.save() + for site_group in parent_groups: + site_group.save() - child_sitegroups = ( - SiteGroup(name='Site Group 1A', slug='site-group-1a', parent=sitegroups[0]), - SiteGroup(name='Site Group 1B', slug='site-group-1b', parent=sitegroups[0]), - SiteGroup(name='Site Group 2A', slug='site-group-2a', parent=sitegroups[1]), - SiteGroup(name='Site Group 2B', slug='site-group-2b', parent=sitegroups[1]), - SiteGroup(name='Site Group 3A', slug='site-group-3a', parent=sitegroups[2]), - SiteGroup(name='Site Group 3B', slug='site-group-3b', parent=sitegroups[2]), + groups = ( + SiteGroup(name='Site Group 1A', slug='site-group-1a', parent=parent_groups[0]), + SiteGroup(name='Site Group 1B', slug='site-group-1b', parent=parent_groups[0]), + SiteGroup(name='Site Group 2A', slug='site-group-2a', parent=parent_groups[1]), + SiteGroup(name='Site Group 2B', slug='site-group-2b', parent=parent_groups[1]), + SiteGroup(name='Site Group 3A', slug='site-group-3a', parent=parent_groups[2]), + SiteGroup(name='Site Group 3B', slug='site-group-3b', parent=parent_groups[2]), ) - for sitegroup in child_sitegroups: - sitegroup.save() + for site_group in groups: + site_group.save() + + child_groups = ( + SiteGroup(name='Site Group 1A1', slug='site-group-1a1', parent=groups[0]), + SiteGroup(name='Site Group 1B1', slug='site-group-1b1', parent=groups[1]), + SiteGroup(name='Site Group 2A1', slug='site-group-2a1', parent=groups[2]), + SiteGroup(name='Site Group 2B1', slug='site-group-2b1', parent=groups[3]), + SiteGroup(name='Site Group 3A1', slug='site-group-3a1', parent=groups[4]), + SiteGroup(name='Site Group 3B1', slug='site-group-3b1', parent=groups[5]), + ) + for site_group in child_groups: + site_group.save() def test_q(self): params = {'q': 'foobar1'} @@ -150,12 +179,19 @@ class SiteGroupTestCase(TestCase, ChangeLoggedFilterSetTests): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) def test_parent(self): - parent_sitegroups = SiteGroup.objects.filter(parent__isnull=True)[:2] - params = {'parent_id': [parent_sitegroups[0].pk, parent_sitegroups[1].pk]} + site_groups = SiteGroup.objects.filter(parent__isnull=True)[:2] + params = {'parent_id': [site_groups[0].pk, site_groups[1].pk]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) - params = {'parent': [parent_sitegroups[0].slug, parent_sitegroups[1].slug]} + params = {'parent': [site_groups[0].slug, site_groups[1].slug]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + def test_ancestor(self): + site_groups = SiteGroup.objects.filter(parent__isnull=True)[:2] + params = {'ancestor_id': [site_groups[0].pk, site_groups[1].pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8) + params = {'ancestor': [site_groups[0].slug, site_groups[1].slug]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8) + class SiteTestCase(TestCase, ChangeLoggedFilterSetTests): queryset = Site.objects.all() @@ -314,21 +350,29 @@ class LocationTestCase(TestCase, ChangeLoggedFilterSetTests): Site.objects.bulk_create(sites) parent_locations = ( - Location(name='Parent Location 1', slug='parent-location-1', site=sites[0]), - Location(name='Parent Location 2', slug='parent-location-2', site=sites[1]), - Location(name='Parent Location 3', slug='parent-location-3', site=sites[2]), + Location(name='Location 1', slug='location-1', site=sites[0]), + Location(name='Location 2', slug='location-2', site=sites[1]), + Location(name='Location 3', slug='location-3', site=sites[2]), ) for location in parent_locations: location.save() locations = ( - Location(name='Location 1', slug='location-1', site=sites[0], parent=parent_locations[0], status=LocationStatusChoices.STATUS_PLANNED, description='foobar1'), - Location(name='Location 2', slug='location-2', site=sites[1], parent=parent_locations[1], status=LocationStatusChoices.STATUS_STAGING, description='foobar2'), - Location(name='Location 3', slug='location-3', site=sites[2], parent=parent_locations[2], status=LocationStatusChoices.STATUS_DECOMMISSIONING, description='foobar3'), + Location(name='Location 1A', slug='location-1a', site=sites[0], parent=parent_locations[0], status=LocationStatusChoices.STATUS_PLANNED, description='foobar1'), + Location(name='Location 2A', slug='location-2a', site=sites[1], parent=parent_locations[1], status=LocationStatusChoices.STATUS_STAGING, description='foobar2'), + Location(name='Location 3A', slug='location-3a', site=sites[2], parent=parent_locations[2], status=LocationStatusChoices.STATUS_DECOMMISSIONING, description='foobar3'), ) for location in locations: location.save() + child_locations = ( + Location(name='Location 1A1', slug='location-1a1', site=sites[0], parent=locations[0]), + Location(name='Location 2A1', slug='location-2a1', site=sites[1], parent=locations[1]), + Location(name='Location 3A1', slug='location-3a1', site=sites[2], parent=locations[2]), + ) + for location in child_locations: + location.save() + def test_q(self): params = {'q': 'foobar1'} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) @@ -352,31 +396,38 @@ class LocationTestCase(TestCase, ChangeLoggedFilterSetTests): def test_region(self): regions = Region.objects.all()[:2] params = {'region_id': [regions[0].pk, regions[1].pk]} - self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6) params = {'region': [regions[0].slug, regions[1].slug]} - self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6) def test_site_group(self): site_groups = SiteGroup.objects.all()[:2] params = {'site_group_id': [site_groups[0].pk, site_groups[1].pk]} - self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6) params = {'site_group': [site_groups[0].slug, site_groups[1].slug]} - self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6) def test_site(self): sites = Site.objects.all()[:2] params = {'site_id': [sites[0].pk, sites[1].pk]} - self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6) params = {'site': [sites[0].slug, sites[1].slug]} - self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6) def test_parent(self): - parent_groups = Location.objects.filter(name__startswith='Parent')[:2] - params = {'parent_id': [parent_groups[0].pk, parent_groups[1].pk]} + locations = Location.objects.filter(parent__isnull=True)[:2] + params = {'parent_id': [locations[0].pk, locations[1].pk]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) - params = {'parent': [parent_groups[0].slug, parent_groups[1].slug]} + params = {'parent': [locations[0].slug, locations[1].slug]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + def test_ancestor(self): + locations = Location.objects.filter(parent__isnull=True)[:2] + params = {'ancestor_id': [locations[0].pk, locations[1].pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + params = {'ancestor': [locations[0].slug, locations[1].slug]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + class RackRoleTestCase(TestCase, ChangeLoggedFilterSetTests): queryset = RackRole.objects.all() diff --git a/netbox/tenancy/filtersets.py b/netbox/tenancy/filtersets.py index 295d20774..7af3dc082 100644 --- a/netbox/tenancy/filtersets.py +++ b/netbox/tenancy/filtersets.py @@ -26,12 +26,25 @@ __all__ = ( class ContactGroupFilterSet(OrganizationalModelFilterSet): parent_id = django_filters.ModelMultipleChoiceFilter( queryset=ContactGroup.objects.all(), - label=_('Contact group (ID)'), + label=_('Parent contact group (ID)'), ) parent = django_filters.ModelMultipleChoiceFilter( field_name='parent__slug', queryset=ContactGroup.objects.all(), to_field_name='slug', + label=_('Parent contact group (slug)'), + ) + ancestor_id = TreeNodeMultipleChoiceFilter( + queryset=ContactGroup.objects.all(), + field_name='parent', + lookup_expr='in', + label=_('Contact group (ID)'), + ) + ancestor = TreeNodeMultipleChoiceFilter( + queryset=ContactGroup.objects.all(), + field_name='parent', + lookup_expr='in', + to_field_name='slug', label=_('Contact group (slug)'), ) @@ -155,12 +168,25 @@ class ContactModelFilterSet(django_filters.FilterSet): class TenantGroupFilterSet(OrganizationalModelFilterSet): parent_id = django_filters.ModelMultipleChoiceFilter( queryset=TenantGroup.objects.all(), - label=_('Tenant group (ID)'), + label=_('Parent tenant group (ID)'), ) parent = django_filters.ModelMultipleChoiceFilter( field_name='parent__slug', queryset=TenantGroup.objects.all(), to_field_name='slug', + label=_('Parent tenant group (slug)'), + ) + ancestor_id = TreeNodeMultipleChoiceFilter( + queryset=TenantGroup.objects.all(), + field_name='parent', + lookup_expr='in', + label=_('Tenant group (ID)'), + ) + ancestor = TreeNodeMultipleChoiceFilter( + queryset=TenantGroup.objects.all(), + field_name='parent', + lookup_expr='in', + to_field_name='slug', label=_('Tenant group (slug)'), ) diff --git a/netbox/tenancy/tests/test_filtersets.py b/netbox/tenancy/tests/test_filtersets.py index 3bcbddd4b..f6890a3d4 100644 --- a/netbox/tenancy/tests/test_filtersets.py +++ b/netbox/tenancy/tests/test_filtersets.py @@ -15,35 +15,43 @@ class TenantGroupTestCase(TestCase, ChangeLoggedFilterSetTests): def setUpTestData(cls): parent_tenant_groups = ( - TenantGroup(name='Parent Tenant Group 1', slug='parent-tenant-group-1'), - TenantGroup(name='Parent Tenant Group 2', slug='parent-tenant-group-2'), - TenantGroup(name='Parent Tenant Group 3', slug='parent-tenant-group-3'), + TenantGroup(name='Tenant Group 1', slug='tenant-group-1'), + TenantGroup(name='Tenant Group 2', slug='tenant-group-2'), + TenantGroup(name='Tenant Group 3', slug='tenant-group-3'), ) - for tenantgroup in parent_tenant_groups: - tenantgroup.save() + for tenant_group in parent_tenant_groups: + tenant_group.save() tenant_groups = ( TenantGroup( - name='Tenant Group 1', - slug='tenant-group-1', + name='Tenant Group 1A', + slug='tenant-group-1a', parent=parent_tenant_groups[0], description='foobar1' ), TenantGroup( - name='Tenant Group 2', - slug='tenant-group-2', + name='Tenant Group 2A', + slug='tenant-group-2a', parent=parent_tenant_groups[1], description='foobar2' ), TenantGroup( - name='Tenant Group 3', - slug='tenant-group-3', + name='Tenant Group 3A', + slug='tenant-group-3a', parent=parent_tenant_groups[2], description='foobar3' ), ) - for tenantgroup in tenant_groups: - tenantgroup.save() + for tenant_group in tenant_groups: + tenant_group.save() + + child_tenant_groups = ( + TenantGroup(name='Tenant Group 1A1', slug='tenant-group-1a1', parent=tenant_groups[0]), + TenantGroup(name='Tenant Group 2A1', slug='tenant-group-2a1', parent=tenant_groups[1]), + TenantGroup(name='Tenant Group 3A1', slug='tenant-group-3a1', parent=tenant_groups[2]), + ) + for tenant_group in child_tenant_groups: + tenant_group.save() def test_q(self): params = {'q': 'foobar1'} @@ -62,12 +70,19 @@ class TenantGroupTestCase(TestCase, ChangeLoggedFilterSetTests): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) def test_parent(self): - parent_groups = TenantGroup.objects.filter(name__startswith='Parent')[:2] - params = {'parent_id': [parent_groups[0].pk, parent_groups[1].pk]} + tenant_groups = TenantGroup.objects.filter(parent__isnull=True)[:2] + params = {'parent_id': [tenant_groups[0].pk, tenant_groups[1].pk]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) - params = {'parent': [parent_groups[0].slug, parent_groups[1].slug]} + params = {'parent': [tenant_groups[0].slug, tenant_groups[1].slug]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + def test_ancestor(self): + tenant_groups = TenantGroup.objects.filter(parent__isnull=True)[:2] + params = {'ancestor_id': [tenant_groups[0].pk, tenant_groups[1].pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + params = {'ancestor': [tenant_groups[0].slug, tenant_groups[1].slug]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + class TenantTestCase(TestCase, ChangeLoggedFilterSetTests): queryset = Tenant.objects.all() @@ -123,35 +138,43 @@ class ContactGroupTestCase(TestCase, ChangeLoggedFilterSetTests): def setUpTestData(cls): parent_contact_groups = ( - ContactGroup(name='Parent Contact Group 1', slug='parent-contact-group-1'), - ContactGroup(name='Parent Contact Group 2', slug='parent-contact-group-2'), - ContactGroup(name='Parent Contact Group 3', slug='parent-contact-group-3'), + ContactGroup(name='Contact Group 1', slug='contact-group-1'), + ContactGroup(name='Contact Group 2', slug='contact-group-2'), + ContactGroup(name='Contact Group 3', slug='contact-group-3'), ) - for contactgroup in parent_contact_groups: - contactgroup.save() + for contact_group in parent_contact_groups: + contact_group.save() contact_groups = ( ContactGroup( - name='Contact Group 1', - slug='contact-group-1', + name='Contact Group 1A', + slug='contact-group-1a', parent=parent_contact_groups[0], description='foobar1' ), ContactGroup( - name='Contact Group 2', - slug='contact-group-2', + name='Contact Group 2A', + slug='contact-group-2a', parent=parent_contact_groups[1], description='foobar2' ), ContactGroup( - name='Contact Group 3', - slug='contact-group-3', + name='Contact Group 3A', + slug='contact-group-3a', parent=parent_contact_groups[2], description='foobar3' ), ) - for contactgroup in contact_groups: - contactgroup.save() + for contact_group in contact_groups: + contact_group.save() + + child_contact_groups = ( + ContactGroup(name='Contact Group 1A1', slug='contact-group-1a1', parent=contact_groups[0]), + ContactGroup(name='Contact Group 2A1', slug='contact-group-2a1', parent=contact_groups[1]), + ContactGroup(name='Contact Group 3A1', slug='contact-group-3a1', parent=contact_groups[2]), + ) + for contact_group in child_contact_groups: + contact_group.save() def test_q(self): params = {'q': 'foobar1'} @@ -170,12 +193,19 @@ class ContactGroupTestCase(TestCase, ChangeLoggedFilterSetTests): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) def test_parent(self): - parent_groups = ContactGroup.objects.filter(parent__isnull=True)[:2] - params = {'parent_id': [parent_groups[0].pk, parent_groups[1].pk]} + contact_groups = ContactGroup.objects.filter(parent__isnull=True)[:2] + params = {'parent_id': [contact_groups[0].pk, contact_groups[1].pk]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) - params = {'parent': [parent_groups[0].slug, parent_groups[1].slug]} + params = {'parent': [contact_groups[0].slug, contact_groups[1].slug]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + def test_ancestor(self): + contact_groups = ContactGroup.objects.filter(parent__isnull=True)[:2] + params = {'ancestor_id': [contact_groups[0].pk, contact_groups[1].pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + params = {'ancestor': [contact_groups[0].slug, contact_groups[1].slug]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + class ContactRoleTestCase(TestCase, ChangeLoggedFilterSetTests): queryset = ContactRole.objects.all() diff --git a/netbox/wireless/filtersets.py b/netbox/wireless/filtersets.py index 6ffb9cb91..50b1f78b1 100644 --- a/netbox/wireless/filtersets.py +++ b/netbox/wireless/filtersets.py @@ -25,6 +25,17 @@ class WirelessLANGroupFilterSet(OrganizationalModelFilterSet): queryset=WirelessLANGroup.objects.all(), to_field_name='slug' ) + ancestor_id = TreeNodeMultipleChoiceFilter( + queryset=WirelessLANGroup.objects.all(), + field_name='parent', + lookup_expr='in' + ) + ancestor = TreeNodeMultipleChoiceFilter( + queryset=WirelessLANGroup.objects.all(), + field_name='parent', + lookup_expr='in', + to_field_name='slug' + ) class Meta: model = WirelessLANGroup diff --git a/netbox/wireless/tests/test_filtersets.py b/netbox/wireless/tests/test_filtersets.py index 4184d5392..78e50edb7 100644 --- a/netbox/wireless/tests/test_filtersets.py +++ b/netbox/wireless/tests/test_filtersets.py @@ -17,21 +17,32 @@ class WirelessLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests): @classmethod def setUpTestData(cls): - groups = ( + parent_groups = ( WirelessLANGroup(name='Wireless LAN Group 1', slug='wireless-lan-group-1', description='A'), WirelessLANGroup(name='Wireless LAN Group 2', slug='wireless-lan-group-2', description='B'), WirelessLANGroup(name='Wireless LAN Group 3', slug='wireless-lan-group-3', description='C'), ) + for group in parent_groups: + group.save() + + groups = ( + WirelessLANGroup(name='Wireless LAN Group 1A', slug='wireless-lan-group-1a', parent=parent_groups[0], description='foobar1'), + WirelessLANGroup(name='Wireless LAN Group 1B', slug='wireless-lan-group-1b', parent=parent_groups[0], description='foobar2'), + WirelessLANGroup(name='Wireless LAN Group 2A', slug='wireless-lan-group-2a', parent=parent_groups[1]), + WirelessLANGroup(name='Wireless LAN Group 2B', slug='wireless-lan-group-2b', parent=parent_groups[1]), + WirelessLANGroup(name='Wireless LAN Group 3A', slug='wireless-lan-group-3a', parent=parent_groups[2]), + WirelessLANGroup(name='Wireless LAN Group 3B', slug='wireless-lan-group-3b', parent=parent_groups[2]), + ) for group in groups: group.save() child_groups = ( - WirelessLANGroup(name='Wireless LAN Group 1A', slug='wireless-lan-group-1a', parent=groups[0], description='foobar1'), - WirelessLANGroup(name='Wireless LAN Group 1B', slug='wireless-lan-group-1b', parent=groups[0], description='foobar2'), - WirelessLANGroup(name='Wireless LAN Group 2A', slug='wireless-lan-group-2a', parent=groups[1]), - WirelessLANGroup(name='Wireless LAN Group 2B', slug='wireless-lan-group-2b', parent=groups[1]), - WirelessLANGroup(name='Wireless LAN Group 3A', slug='wireless-lan-group-3a', parent=groups[2]), - WirelessLANGroup(name='Wireless LAN Group 3B', slug='wireless-lan-group-3b', parent=groups[2]), + WirelessLANGroup(name='Wireless LAN Group 1A1', slug='wireless-lan-group-1a1', parent=groups[0]), + WirelessLANGroup(name='Wireless LAN Group 1B1', slug='wireless-lan-group-1b1', parent=groups[1]), + WirelessLANGroup(name='Wireless LAN Group 2A1', slug='wireless-lan-group-2a1', parent=groups[2]), + WirelessLANGroup(name='Wireless LAN Group 2B1', slug='wireless-lan-group-2b1', parent=groups[3]), + WirelessLANGroup(name='Wireless LAN Group 3A1', slug='wireless-lan-group-3a1', parent=groups[4]), + WirelessLANGroup(name='Wireless LAN Group 3B1', slug='wireless-lan-group-3b1', parent=groups[5]), ) for group in child_groups: group.save() @@ -48,17 +59,24 @@ class WirelessLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests): params = {'slug': ['wireless-lan-group-1', 'wireless-lan-group-2']} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) - def test_parent(self): - parent_groups = WirelessLANGroup.objects.filter(parent__isnull=True)[:2] - params = {'parent_id': [parent_groups[0].pk, parent_groups[1].pk]} - self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) - params = {'parent': [parent_groups[0].slug, parent_groups[1].slug]} - self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) - def test_description(self): params = {'description': ['foobar1', 'foobar2']} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + def test_parent(self): + groups = WirelessLANGroup.objects.filter(parent__isnull=True)[:2] + params = {'parent_id': [groups[0].pk, groups[1].pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + params = {'parent': [groups[0].slug, groups[1].slug]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) + + def test_ancestor(self): + groups = WirelessLANGroup.objects.filter(parent__isnull=True)[:2] + params = {'ancestor_id': [groups[0].pk, groups[1].pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8) + params = {'ancestor': [groups[0].slug, groups[1].slug]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8) + class WirelessLANTestCase(TestCase, ChangeLoggedFilterSetTests): queryset = WirelessLAN.objects.all()