Closes #15383: Standardize filtering logic for the parents of recursively-nested models

This commit is contained in:
Jeremy Stretch 2024-03-11 12:32:06 -04:00
parent 21de3f954f
commit d6acc18c29
6 changed files with 262 additions and 90 deletions

View File

@ -89,6 +89,19 @@ class RegionFilterSet(OrganizationalModelFilterSet, ContactModelFilterSet):
to_field_name='slug', to_field_name='slug',
label=_('Parent region (slug)'), label=_('Parent region (slug)'),
) )
ancestor_id = TreeNodeMultipleChoiceFilter(
queryset=Region.objects.all(),
field_name='parent',
lookup_expr='in',
label=_('Region (ID)'),
)
ancestor = TreeNodeMultipleChoiceFilter(
queryset=Region.objects.all(),
field_name='parent',
lookup_expr='in',
to_field_name='slug',
label=_('Region (slug)'),
)
class Meta: class Meta:
model = Region model = Region
@ -106,6 +119,19 @@ class SiteGroupFilterSet(OrganizationalModelFilterSet, ContactModelFilterSet):
to_field_name='slug', to_field_name='slug',
label=_('Parent site group (slug)'), label=_('Parent site group (slug)'),
) )
ancestor_id = TreeNodeMultipleChoiceFilter(
queryset=SiteGroup.objects.all(),
field_name='parent',
lookup_expr='in',
label=_('Site group (ID)'),
)
ancestor = TreeNodeMultipleChoiceFilter(
queryset=SiteGroup.objects.all(),
field_name='parent',
lookup_expr='in',
to_field_name='slug',
label=_('Site group (slug)'),
)
class Meta: class Meta:
model = SiteGroup model = SiteGroup
@ -214,13 +240,23 @@ class LocationFilterSet(TenancyFilterSet, ContactModelFilterSet, OrganizationalM
to_field_name='slug', to_field_name='slug',
label=_('Site (slug)'), label=_('Site (slug)'),
) )
parent_id = TreeNodeMultipleChoiceFilter( parent_id = django_filters.ModelMultipleChoiceFilter(
queryset=Location.objects.all(),
label=_('Parent location (ID)'),
)
parent = django_filters.ModelMultipleChoiceFilter(
field_name='parent__slug',
queryset=Location.objects.all(),
to_field_name='slug',
label=_('Parent location (slug)'),
)
ancestor_id = TreeNodeMultipleChoiceFilter(
queryset=Location.objects.all(), queryset=Location.objects.all(),
field_name='parent', field_name='parent',
lookup_expr='in', lookup_expr='in',
label=_('Location (ID)'), label=_('Location (ID)'),
) )
parent = TreeNodeMultipleChoiceFilter( ancestor = TreeNodeMultipleChoiceFilter(
queryset=Location.objects.all(), queryset=Location.objects.all(),
field_name='parent', field_name='parent',
lookup_expr='in', lookup_expr='in',

View File

@ -64,21 +64,32 @@ class RegionTestCase(TestCase, ChangeLoggedFilterSetTests):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
regions = ( parent_regions = (
Region(name='Region 1', slug='region-1', description='foobar1'), Region(name='Region 1', slug='region-1', description='foobar1'),
Region(name='Region 2', slug='region-2', description='foobar2'), Region(name='Region 2', slug='region-2', description='foobar2'),
Region(name='Region 3', slug='region-3', description='foobar3'), Region(name='Region 3', slug='region-3', description='foobar3'),
) )
for region in parent_regions:
region.save()
regions = (
Region(name='Region 1A', slug='region-1a', parent=parent_regions[0]),
Region(name='Region 1B', slug='region-1b', parent=parent_regions[0]),
Region(name='Region 2A', slug='region-2a', parent=parent_regions[1]),
Region(name='Region 2B', slug='region-2b', parent=parent_regions[1]),
Region(name='Region 3A', slug='region-3a', parent=parent_regions[2]),
Region(name='Region 3B', slug='region-3b', parent=parent_regions[2]),
)
for region in regions: for region in regions:
region.save() region.save()
child_regions = ( child_regions = (
Region(name='Region 1A', slug='region-1a', parent=regions[0]), Region(name='Region 1A1', slug='region-1a1', parent=regions[0]),
Region(name='Region 1B', slug='region-1b', parent=regions[0]), Region(name='Region 1B1', slug='region-1b1', parent=regions[1]),
Region(name='Region 2A', slug='region-2a', parent=regions[1]), Region(name='Region 2A1', slug='region-2a1', parent=regions[2]),
Region(name='Region 2B', slug='region-2b', parent=regions[1]), Region(name='Region 2B1', slug='region-2b1', parent=regions[3]),
Region(name='Region 3A', slug='region-3a', parent=regions[2]), Region(name='Region 3A1', slug='region-3a1', parent=regions[4]),
Region(name='Region 3B', slug='region-3b', parent=regions[2]), Region(name='Region 3B1', slug='region-3b1', parent=regions[5]),
) )
for region in child_regions: for region in child_regions:
region.save() region.save()
@ -100,12 +111,19 @@ class RegionTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_parent(self): def test_parent(self):
parent_regions = Region.objects.filter(parent__isnull=True)[:2] regions = Region.objects.filter(parent__isnull=True)[:2]
params = {'parent_id': [parent_regions[0].pk, parent_regions[1].pk]} params = {'parent_id': [regions[0].pk, regions[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
params = {'parent': [parent_regions[0].slug, parent_regions[1].slug]} params = {'parent': [regions[0].slug, regions[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_ancestor(self):
regions = Region.objects.filter(parent__isnull=True)[:2]
params = {'ancestor_id': [regions[0].pk, regions[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
params = {'ancestor': [regions[0].slug, regions[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
class SiteGroupTestCase(TestCase, ChangeLoggedFilterSetTests): class SiteGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = SiteGroup.objects.all() queryset = SiteGroup.objects.all()
@ -114,24 +132,35 @@ class SiteGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
sitegroups = ( parent_groups = (
SiteGroup(name='Site Group 1', slug='site-group-1', description='foobar1'), SiteGroup(name='Site Group 1', slug='site-group-1', description='foobar1'),
SiteGroup(name='Site Group 2', slug='site-group-2', description='foobar2'), SiteGroup(name='Site Group 2', slug='site-group-2', description='foobar2'),
SiteGroup(name='Site Group 3', slug='site-group-3', description='foobar3'), SiteGroup(name='Site Group 3', slug='site-group-3', description='foobar3'),
) )
for sitegroup in sitegroups: for site_group in parent_groups:
sitegroup.save() site_group.save()
child_sitegroups = ( groups = (
SiteGroup(name='Site Group 1A', slug='site-group-1a', parent=sitegroups[0]), SiteGroup(name='Site Group 1A', slug='site-group-1a', parent=parent_groups[0]),
SiteGroup(name='Site Group 1B', slug='site-group-1b', parent=sitegroups[0]), SiteGroup(name='Site Group 1B', slug='site-group-1b', parent=parent_groups[0]),
SiteGroup(name='Site Group 2A', slug='site-group-2a', parent=sitegroups[1]), SiteGroup(name='Site Group 2A', slug='site-group-2a', parent=parent_groups[1]),
SiteGroup(name='Site Group 2B', slug='site-group-2b', parent=sitegroups[1]), SiteGroup(name='Site Group 2B', slug='site-group-2b', parent=parent_groups[1]),
SiteGroup(name='Site Group 3A', slug='site-group-3a', parent=sitegroups[2]), SiteGroup(name='Site Group 3A', slug='site-group-3a', parent=parent_groups[2]),
SiteGroup(name='Site Group 3B', slug='site-group-3b', parent=sitegroups[2]), SiteGroup(name='Site Group 3B', slug='site-group-3b', parent=parent_groups[2]),
) )
for sitegroup in child_sitegroups: for site_group in groups:
sitegroup.save() site_group.save()
child_groups = (
SiteGroup(name='Site Group 1A1', slug='site-group-1a1', parent=groups[0]),
SiteGroup(name='Site Group 1B1', slug='site-group-1b1', parent=groups[1]),
SiteGroup(name='Site Group 2A1', slug='site-group-2a1', parent=groups[2]),
SiteGroup(name='Site Group 2B1', slug='site-group-2b1', parent=groups[3]),
SiteGroup(name='Site Group 3A1', slug='site-group-3a1', parent=groups[4]),
SiteGroup(name='Site Group 3B1', slug='site-group-3b1', parent=groups[5]),
)
for site_group in child_groups:
site_group.save()
def test_q(self): def test_q(self):
params = {'q': 'foobar1'} params = {'q': 'foobar1'}
@ -150,12 +179,19 @@ class SiteGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_parent(self): def test_parent(self):
parent_sitegroups = SiteGroup.objects.filter(parent__isnull=True)[:2] site_groups = SiteGroup.objects.filter(parent__isnull=True)[:2]
params = {'parent_id': [parent_sitegroups[0].pk, parent_sitegroups[1].pk]} params = {'parent_id': [site_groups[0].pk, site_groups[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
params = {'parent': [parent_sitegroups[0].slug, parent_sitegroups[1].slug]} params = {'parent': [site_groups[0].slug, site_groups[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_ancestor(self):
site_groups = SiteGroup.objects.filter(parent__isnull=True)[:2]
params = {'ancestor_id': [site_groups[0].pk, site_groups[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
params = {'ancestor': [site_groups[0].slug, site_groups[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
class SiteTestCase(TestCase, ChangeLoggedFilterSetTests): class SiteTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = Site.objects.all() queryset = Site.objects.all()
@ -314,21 +350,29 @@ class LocationTestCase(TestCase, ChangeLoggedFilterSetTests):
Site.objects.bulk_create(sites) Site.objects.bulk_create(sites)
parent_locations = ( parent_locations = (
Location(name='Parent Location 1', slug='parent-location-1', site=sites[0]), Location(name='Location 1', slug='location-1', site=sites[0]),
Location(name='Parent Location 2', slug='parent-location-2', site=sites[1]), Location(name='Location 2', slug='location-2', site=sites[1]),
Location(name='Parent Location 3', slug='parent-location-3', site=sites[2]), Location(name='Location 3', slug='location-3', site=sites[2]),
) )
for location in parent_locations: for location in parent_locations:
location.save() location.save()
locations = ( locations = (
Location(name='Location 1', slug='location-1', site=sites[0], parent=parent_locations[0], status=LocationStatusChoices.STATUS_PLANNED, description='foobar1'), Location(name='Location 1A', slug='location-1a', site=sites[0], parent=parent_locations[0], status=LocationStatusChoices.STATUS_PLANNED, description='foobar1'),
Location(name='Location 2', slug='location-2', site=sites[1], parent=parent_locations[1], status=LocationStatusChoices.STATUS_STAGING, description='foobar2'), Location(name='Location 2A', slug='location-2a', site=sites[1], parent=parent_locations[1], status=LocationStatusChoices.STATUS_STAGING, description='foobar2'),
Location(name='Location 3', slug='location-3', site=sites[2], parent=parent_locations[2], status=LocationStatusChoices.STATUS_DECOMMISSIONING, description='foobar3'), Location(name='Location 3A', slug='location-3a', site=sites[2], parent=parent_locations[2], status=LocationStatusChoices.STATUS_DECOMMISSIONING, description='foobar3'),
) )
for location in locations: for location in locations:
location.save() location.save()
child_locations = (
Location(name='Location 1A1', slug='location-1a1', site=sites[0], parent=locations[0]),
Location(name='Location 2A1', slug='location-2a1', site=sites[1], parent=locations[1]),
Location(name='Location 3A1', slug='location-3a1', site=sites[2], parent=locations[2]),
)
for location in child_locations:
location.save()
def test_q(self): def test_q(self):
params = {'q': 'foobar1'} params = {'q': 'foobar1'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@ -352,31 +396,38 @@ class LocationTestCase(TestCase, ChangeLoggedFilterSetTests):
def test_region(self): def test_region(self):
regions = Region.objects.all()[:2] regions = Region.objects.all()[:2]
params = {'region_id': [regions[0].pk, regions[1].pk]} params = {'region_id': [regions[0].pk, regions[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
params = {'region': [regions[0].slug, regions[1].slug]} params = {'region': [regions[0].slug, regions[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
def test_site_group(self): def test_site_group(self):
site_groups = SiteGroup.objects.all()[:2] site_groups = SiteGroup.objects.all()[:2]
params = {'site_group_id': [site_groups[0].pk, site_groups[1].pk]} params = {'site_group_id': [site_groups[0].pk, site_groups[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
params = {'site_group': [site_groups[0].slug, site_groups[1].slug]} params = {'site_group': [site_groups[0].slug, site_groups[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
def test_site(self): def test_site(self):
sites = Site.objects.all()[:2] sites = Site.objects.all()[:2]
params = {'site_id': [sites[0].pk, sites[1].pk]} params = {'site_id': [sites[0].pk, sites[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
params = {'site': [sites[0].slug, sites[1].slug]} params = {'site': [sites[0].slug, sites[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
def test_parent(self): def test_parent(self):
parent_groups = Location.objects.filter(name__startswith='Parent')[:2] locations = Location.objects.filter(parent__isnull=True)[:2]
params = {'parent_id': [parent_groups[0].pk, parent_groups[1].pk]} params = {'parent_id': [locations[0].pk, locations[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'parent': [parent_groups[0].slug, parent_groups[1].slug]} params = {'parent': [locations[0].slug, locations[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_ancestor(self):
locations = Location.objects.filter(parent__isnull=True)[:2]
params = {'ancestor_id': [locations[0].pk, locations[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
params = {'ancestor': [locations[0].slug, locations[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
class RackRoleTestCase(TestCase, ChangeLoggedFilterSetTests): class RackRoleTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = RackRole.objects.all() queryset = RackRole.objects.all()

View File

@ -26,12 +26,25 @@ __all__ = (
class ContactGroupFilterSet(OrganizationalModelFilterSet): class ContactGroupFilterSet(OrganizationalModelFilterSet):
parent_id = django_filters.ModelMultipleChoiceFilter( parent_id = django_filters.ModelMultipleChoiceFilter(
queryset=ContactGroup.objects.all(), queryset=ContactGroup.objects.all(),
label=_('Contact group (ID)'), label=_('Parent contact group (ID)'),
) )
parent = django_filters.ModelMultipleChoiceFilter( parent = django_filters.ModelMultipleChoiceFilter(
field_name='parent__slug', field_name='parent__slug',
queryset=ContactGroup.objects.all(), queryset=ContactGroup.objects.all(),
to_field_name='slug', to_field_name='slug',
label=_('Parent contact group (slug)'),
)
ancestor_id = TreeNodeMultipleChoiceFilter(
queryset=ContactGroup.objects.all(),
field_name='parent',
lookup_expr='in',
label=_('Contact group (ID)'),
)
ancestor = TreeNodeMultipleChoiceFilter(
queryset=ContactGroup.objects.all(),
field_name='parent',
lookup_expr='in',
to_field_name='slug',
label=_('Contact group (slug)'), label=_('Contact group (slug)'),
) )
@ -155,12 +168,25 @@ class ContactModelFilterSet(django_filters.FilterSet):
class TenantGroupFilterSet(OrganizationalModelFilterSet): class TenantGroupFilterSet(OrganizationalModelFilterSet):
parent_id = django_filters.ModelMultipleChoiceFilter( parent_id = django_filters.ModelMultipleChoiceFilter(
queryset=TenantGroup.objects.all(), queryset=TenantGroup.objects.all(),
label=_('Tenant group (ID)'), label=_('Parent tenant group (ID)'),
) )
parent = django_filters.ModelMultipleChoiceFilter( parent = django_filters.ModelMultipleChoiceFilter(
field_name='parent__slug', field_name='parent__slug',
queryset=TenantGroup.objects.all(), queryset=TenantGroup.objects.all(),
to_field_name='slug', to_field_name='slug',
label=_('Parent tenant group (slug)'),
)
ancestor_id = TreeNodeMultipleChoiceFilter(
queryset=TenantGroup.objects.all(),
field_name='parent',
lookup_expr='in',
label=_('Tenant group (ID)'),
)
ancestor = TreeNodeMultipleChoiceFilter(
queryset=TenantGroup.objects.all(),
field_name='parent',
lookup_expr='in',
to_field_name='slug',
label=_('Tenant group (slug)'), label=_('Tenant group (slug)'),
) )

View File

@ -15,35 +15,43 @@ class TenantGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
def setUpTestData(cls): def setUpTestData(cls):
parent_tenant_groups = ( parent_tenant_groups = (
TenantGroup(name='Parent Tenant Group 1', slug='parent-tenant-group-1'), TenantGroup(name='Tenant Group 1', slug='tenant-group-1'),
TenantGroup(name='Parent Tenant Group 2', slug='parent-tenant-group-2'), TenantGroup(name='Tenant Group 2', slug='tenant-group-2'),
TenantGroup(name='Parent Tenant Group 3', slug='parent-tenant-group-3'), TenantGroup(name='Tenant Group 3', slug='tenant-group-3'),
) )
for tenantgroup in parent_tenant_groups: for tenant_group in parent_tenant_groups:
tenantgroup.save() tenant_group.save()
tenant_groups = ( tenant_groups = (
TenantGroup( TenantGroup(
name='Tenant Group 1', name='Tenant Group 1A',
slug='tenant-group-1', slug='tenant-group-1a',
parent=parent_tenant_groups[0], parent=parent_tenant_groups[0],
description='foobar1' description='foobar1'
), ),
TenantGroup( TenantGroup(
name='Tenant Group 2', name='Tenant Group 2A',
slug='tenant-group-2', slug='tenant-group-2a',
parent=parent_tenant_groups[1], parent=parent_tenant_groups[1],
description='foobar2' description='foobar2'
), ),
TenantGroup( TenantGroup(
name='Tenant Group 3', name='Tenant Group 3A',
slug='tenant-group-3', slug='tenant-group-3a',
parent=parent_tenant_groups[2], parent=parent_tenant_groups[2],
description='foobar3' description='foobar3'
), ),
) )
for tenantgroup in tenant_groups: for tenant_group in tenant_groups:
tenantgroup.save() tenant_group.save()
child_tenant_groups = (
TenantGroup(name='Tenant Group 1A1', slug='tenant-group-1a1', parent=tenant_groups[0]),
TenantGroup(name='Tenant Group 2A1', slug='tenant-group-2a1', parent=tenant_groups[1]),
TenantGroup(name='Tenant Group 3A1', slug='tenant-group-3a1', parent=tenant_groups[2]),
)
for tenant_group in child_tenant_groups:
tenant_group.save()
def test_q(self): def test_q(self):
params = {'q': 'foobar1'} params = {'q': 'foobar1'}
@ -62,12 +70,19 @@ class TenantGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_parent(self): def test_parent(self):
parent_groups = TenantGroup.objects.filter(name__startswith='Parent')[:2] tenant_groups = TenantGroup.objects.filter(parent__isnull=True)[:2]
params = {'parent_id': [parent_groups[0].pk, parent_groups[1].pk]} params = {'parent_id': [tenant_groups[0].pk, tenant_groups[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'parent': [parent_groups[0].slug, parent_groups[1].slug]} params = {'parent': [tenant_groups[0].slug, tenant_groups[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_ancestor(self):
tenant_groups = TenantGroup.objects.filter(parent__isnull=True)[:2]
params = {'ancestor_id': [tenant_groups[0].pk, tenant_groups[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
params = {'ancestor': [tenant_groups[0].slug, tenant_groups[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
class TenantTestCase(TestCase, ChangeLoggedFilterSetTests): class TenantTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = Tenant.objects.all() queryset = Tenant.objects.all()
@ -123,35 +138,43 @@ class ContactGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
def setUpTestData(cls): def setUpTestData(cls):
parent_contact_groups = ( parent_contact_groups = (
ContactGroup(name='Parent Contact Group 1', slug='parent-contact-group-1'), ContactGroup(name='Contact Group 1', slug='contact-group-1'),
ContactGroup(name='Parent Contact Group 2', slug='parent-contact-group-2'), ContactGroup(name='Contact Group 2', slug='contact-group-2'),
ContactGroup(name='Parent Contact Group 3', slug='parent-contact-group-3'), ContactGroup(name='Contact Group 3', slug='contact-group-3'),
) )
for contactgroup in parent_contact_groups: for contact_group in parent_contact_groups:
contactgroup.save() contact_group.save()
contact_groups = ( contact_groups = (
ContactGroup( ContactGroup(
name='Contact Group 1', name='Contact Group 1A',
slug='contact-group-1', slug='contact-group-1a',
parent=parent_contact_groups[0], parent=parent_contact_groups[0],
description='foobar1' description='foobar1'
), ),
ContactGroup( ContactGroup(
name='Contact Group 2', name='Contact Group 2A',
slug='contact-group-2', slug='contact-group-2a',
parent=parent_contact_groups[1], parent=parent_contact_groups[1],
description='foobar2' description='foobar2'
), ),
ContactGroup( ContactGroup(
name='Contact Group 3', name='Contact Group 3A',
slug='contact-group-3', slug='contact-group-3a',
parent=parent_contact_groups[2], parent=parent_contact_groups[2],
description='foobar3' description='foobar3'
), ),
) )
for contactgroup in contact_groups: for contact_group in contact_groups:
contactgroup.save() contact_group.save()
child_contact_groups = (
ContactGroup(name='Contact Group 1A1', slug='contact-group-1a1', parent=contact_groups[0]),
ContactGroup(name='Contact Group 2A1', slug='contact-group-2a1', parent=contact_groups[1]),
ContactGroup(name='Contact Group 3A1', slug='contact-group-3a1', parent=contact_groups[2]),
)
for contact_group in child_contact_groups:
contact_group.save()
def test_q(self): def test_q(self):
params = {'q': 'foobar1'} params = {'q': 'foobar1'}
@ -170,12 +193,19 @@ class ContactGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_parent(self): def test_parent(self):
parent_groups = ContactGroup.objects.filter(parent__isnull=True)[:2] contact_groups = ContactGroup.objects.filter(parent__isnull=True)[:2]
params = {'parent_id': [parent_groups[0].pk, parent_groups[1].pk]} params = {'parent_id': [contact_groups[0].pk, contact_groups[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'parent': [parent_groups[0].slug, parent_groups[1].slug]} params = {'parent': [contact_groups[0].slug, contact_groups[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_ancestor(self):
contact_groups = ContactGroup.objects.filter(parent__isnull=True)[:2]
params = {'ancestor_id': [contact_groups[0].pk, contact_groups[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
params = {'ancestor': [contact_groups[0].slug, contact_groups[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
class ContactRoleTestCase(TestCase, ChangeLoggedFilterSetTests): class ContactRoleTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = ContactRole.objects.all() queryset = ContactRole.objects.all()

View File

@ -25,6 +25,17 @@ class WirelessLANGroupFilterSet(OrganizationalModelFilterSet):
queryset=WirelessLANGroup.objects.all(), queryset=WirelessLANGroup.objects.all(),
to_field_name='slug' to_field_name='slug'
) )
ancestor_id = TreeNodeMultipleChoiceFilter(
queryset=WirelessLANGroup.objects.all(),
field_name='parent',
lookup_expr='in'
)
ancestor = TreeNodeMultipleChoiceFilter(
queryset=WirelessLANGroup.objects.all(),
field_name='parent',
lookup_expr='in',
to_field_name='slug'
)
class Meta: class Meta:
model = WirelessLANGroup model = WirelessLANGroup

View File

@ -17,21 +17,32 @@ class WirelessLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
groups = ( parent_groups = (
WirelessLANGroup(name='Wireless LAN Group 1', slug='wireless-lan-group-1', description='A'), WirelessLANGroup(name='Wireless LAN Group 1', slug='wireless-lan-group-1', description='A'),
WirelessLANGroup(name='Wireless LAN Group 2', slug='wireless-lan-group-2', description='B'), WirelessLANGroup(name='Wireless LAN Group 2', slug='wireless-lan-group-2', description='B'),
WirelessLANGroup(name='Wireless LAN Group 3', slug='wireless-lan-group-3', description='C'), WirelessLANGroup(name='Wireless LAN Group 3', slug='wireless-lan-group-3', description='C'),
) )
for group in parent_groups:
group.save()
groups = (
WirelessLANGroup(name='Wireless LAN Group 1A', slug='wireless-lan-group-1a', parent=parent_groups[0], description='foobar1'),
WirelessLANGroup(name='Wireless LAN Group 1B', slug='wireless-lan-group-1b', parent=parent_groups[0], description='foobar2'),
WirelessLANGroup(name='Wireless LAN Group 2A', slug='wireless-lan-group-2a', parent=parent_groups[1]),
WirelessLANGroup(name='Wireless LAN Group 2B', slug='wireless-lan-group-2b', parent=parent_groups[1]),
WirelessLANGroup(name='Wireless LAN Group 3A', slug='wireless-lan-group-3a', parent=parent_groups[2]),
WirelessLANGroup(name='Wireless LAN Group 3B', slug='wireless-lan-group-3b', parent=parent_groups[2]),
)
for group in groups: for group in groups:
group.save() group.save()
child_groups = ( child_groups = (
WirelessLANGroup(name='Wireless LAN Group 1A', slug='wireless-lan-group-1a', parent=groups[0], description='foobar1'), WirelessLANGroup(name='Wireless LAN Group 1A1', slug='wireless-lan-group-1a1', parent=groups[0]),
WirelessLANGroup(name='Wireless LAN Group 1B', slug='wireless-lan-group-1b', parent=groups[0], description='foobar2'), WirelessLANGroup(name='Wireless LAN Group 1B1', slug='wireless-lan-group-1b1', parent=groups[1]),
WirelessLANGroup(name='Wireless LAN Group 2A', slug='wireless-lan-group-2a', parent=groups[1]), WirelessLANGroup(name='Wireless LAN Group 2A1', slug='wireless-lan-group-2a1', parent=groups[2]),
WirelessLANGroup(name='Wireless LAN Group 2B', slug='wireless-lan-group-2b', parent=groups[1]), WirelessLANGroup(name='Wireless LAN Group 2B1', slug='wireless-lan-group-2b1', parent=groups[3]),
WirelessLANGroup(name='Wireless LAN Group 3A', slug='wireless-lan-group-3a', parent=groups[2]), WirelessLANGroup(name='Wireless LAN Group 3A1', slug='wireless-lan-group-3a1', parent=groups[4]),
WirelessLANGroup(name='Wireless LAN Group 3B', slug='wireless-lan-group-3b', parent=groups[2]), WirelessLANGroup(name='Wireless LAN Group 3B1', slug='wireless-lan-group-3b1', parent=groups[5]),
) )
for group in child_groups: for group in child_groups:
group.save() group.save()
@ -48,17 +59,24 @@ class WirelessLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
params = {'slug': ['wireless-lan-group-1', 'wireless-lan-group-2']} params = {'slug': ['wireless-lan-group-1', 'wireless-lan-group-2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_parent(self):
parent_groups = WirelessLANGroup.objects.filter(parent__isnull=True)[:2]
params = {'parent_id': [parent_groups[0].pk, parent_groups[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
params = {'parent': [parent_groups[0].slug, parent_groups[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_description(self): def test_description(self):
params = {'description': ['foobar1', 'foobar2']} params = {'description': ['foobar1', 'foobar2']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_parent(self):
groups = WirelessLANGroup.objects.filter(parent__isnull=True)[:2]
params = {'parent_id': [groups[0].pk, groups[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
params = {'parent': [groups[0].slug, groups[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_ancestor(self):
groups = WirelessLANGroup.objects.filter(parent__isnull=True)[:2]
params = {'ancestor_id': [groups[0].pk, groups[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
params = {'ancestor': [groups[0].slug, groups[1].slug]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
class WirelessLANTestCase(TestCase, ChangeLoggedFilterSetTests): class WirelessLANTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = WirelessLAN.objects.all() queryset = WirelessLAN.objects.all()