From 766b5dff24b69a370140d44f1ac8cf941e75b484 Mon Sep 17 00:00:00 2001 From: kobayashi Date: Wed, 16 Oct 2019 00:32:54 -0400 Subject: [PATCH 1/2] allow null region filtering --- netbox/dcim/tests/test_api.py | 28 ++++++++++++++++++++++++- netbox/utilities/filters.py | 9 ++++++-- netbox/virtualization/tests/test_api.py | 25 +++++++++++++++------- 3 files changed, 51 insertions(+), 11 deletions(-) diff --git a/netbox/dcim/tests/test_api.py b/netbox/dcim/tests/test_api.py index 9c873c886..44d9cd414 100644 --- a/netbox/dcim/tests/test_api.py +++ b/netbox/dcim/tests/test_api.py @@ -174,6 +174,16 @@ class SiteTest(APITestCase): ['id', 'name', 'slug', 'url'] ) + def test_list_sites_null_region(self): + + Site.objects.create(name='Test Site Null Region1', slug='test-site-no-region1') + Site.objects.create(name='Test Site Null Region2', slug='test-site-no-region2') + + url = reverse('dcim-api:site-list') + response = self.client.get('{}?region=null'.format(url), **self.header) + + self.assertEqual(response.data['count'], 2) + def test_create_site(self): data = { @@ -1753,7 +1763,8 @@ class DeviceTest(APITestCase): super().setUp() - self.site1 = Site.objects.create(name='Test Site 1', slug='test-site-1') + region = Region.objects.create(name='Test Region', slug='test-region') + self.site1 = Site.objects.create(region=region, name='Test Site 1', slug='test-site-1') self.site2 = Site.objects.create(name='Test Site 2', slug='test-site-2') manufacturer = Manufacturer.objects.create(name='Test Manufacturer 1', slug='test-manufacturer-1') self.devicetype1 = DeviceType.objects.create( @@ -1828,6 +1839,21 @@ class DeviceTest(APITestCase): ['display_name', 'id', 'name', 'url'] ) + def test_list_device_null_region(self): + + Device.objects.create( + device_type=self.devicetype1, + device_role=self.devicerole1, + name='Test Device Null Region', + site=self.site2, + cluster=self.cluster1 + ) + + url = reverse('dcim-api:device-list') + response = self.client.get('{}?region=null'.format(url), **self.header) + + self.assertEqual(response.data['count'], 1) + def test_create_device(self): data = { diff --git a/netbox/utilities/filters.py b/netbox/utilities/filters.py index 7ba008c70..e13f095f9 100644 --- a/netbox/utilities/filters.py +++ b/netbox/utilities/filters.py @@ -1,9 +1,8 @@ import django_filters +from dcim.forms import MACAddressField from django import forms from django.conf import settings from django.db import models - -from dcim.forms import MACAddressField from extras.models import Tag @@ -62,7 +61,13 @@ class TreeNodeMultipleChoiceFilter(django_filters.ModelMultipleChoiceFilter): """ Filters for a set of Models, including all descendant models within a Tree. Example: [,] """ + def filter(self, qs, value): + if settings.FILTERS_NULL_CHOICE_VALUE in value: + # Filtering by null value. Example: region=null + qs = self.get_method(qs)(**{self.field_name.replace('in', 'isnull'): True}) + return qs.distinct() if self.distinct else qs + value = [node.get_descendants(include_self=True) for node in value] return super().filter(qs, value) diff --git a/netbox/virtualization/tests/test_api.py b/netbox/virtualization/tests/test_api.py index f1e372dd4..e81b14ebb 100644 --- a/netbox/virtualization/tests/test_api.py +++ b/netbox/virtualization/tests/test_api.py @@ -3,7 +3,7 @@ from netaddr import IPNetwork from rest_framework import status from dcim.constants import IFACE_TYPE_VIRTUAL, IFACE_MODE_TAGGED -from dcim.models import Interface +from dcim.models import Interface, Region, Site from ipam.models import IPAddress, VLAN from utilities.testing import APITestCase from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine @@ -330,9 +330,14 @@ class VirtualMachineTest(APITestCase): super().setUp() + region = Region.objects.create(name='Test Region 1', slug='test-region-1') + site1 = Site.objects.create(region=region, name='Test Site 1', slug='test-site-1') + site2 = Site.objects.create(name='Test Site 2', slug='test-site-2') + cluster_type = ClusterType.objects.create(name='Test Cluster Type 1', slug='test-cluster-type-1') cluster_group = ClusterGroup.objects.create(name='Test Cluster Group 1', slug='test-cluster-group-1') - self.cluster1 = Cluster.objects.create(name='Test Cluster 1', type=cluster_type, group=cluster_group) + self.cluster1 = Cluster.objects.create(name='Test Cluster 1', type=cluster_type, group=cluster_group, site=site1) + self.cluster2 = Cluster.objects.create(name='Test Cluster 2', type=cluster_type, group=cluster_group, site=site2) self.virtualmachine1 = VirtualMachine.objects.create(name='Test Virtual Machine 1', cluster=self.cluster1) self.virtualmachine2 = VirtualMachine.objects.create(name='Test Virtual Machine 2', cluster=self.cluster1) @@ -370,6 +375,15 @@ class VirtualMachineTest(APITestCase): ['id', 'name', 'url'] ) + def test_list_virtualmachines_null_region(self): + + VirtualMachine.objects.create(name='Test Virtual Machine Null Region', cluster=self.cluster2) + + url = reverse('virtualization-api:virtualmachine-list') + response = self.client.get('{}?region=null'.format(url), **self.header) + + self.assertEqual(response.data['count'], 1) + def test_create_virtualmachine(self): data = { @@ -430,14 +444,9 @@ class VirtualMachineTest(APITestCase): ip4_address = IPAddress.objects.create(address=IPNetwork('192.0.2.1/24'), interface=interface) ip6_address = IPAddress.objects.create(address=IPNetwork('2001:db8::1/64'), interface=interface) - cluster2 = Cluster.objects.create( - name='Test Cluster 2', - type=ClusterType.objects.first(), - group=ClusterGroup.objects.first() - ) data = { 'name': 'Test Virtual Machine X', - 'cluster': cluster2.pk, + 'cluster': self.cluster2.pk, 'primary_ip4': ip4_address.pk, 'primary_ip6': ip6_address.pk, } From d2aa9b8e79592e96cef8a94f60131cb17fe065df Mon Sep 17 00:00:00 2001 From: kobayashi Date: Mon, 28 Oct 2019 02:24:44 -0400 Subject: [PATCH 2/2] filtering multiple regions with null --- netbox/dcim/tests/test_api.py | 69 ++++++++++++++++--------- netbox/utilities/filters.py | 13 ++--- netbox/virtualization/tests/test_api.py | 25 +++++---- 3 files changed, 67 insertions(+), 40 deletions(-) diff --git a/netbox/dcim/tests/test_api.py b/netbox/dcim/tests/test_api.py index 44d9cd414..4b711efe4 100644 --- a/netbox/dcim/tests/test_api.py +++ b/netbox/dcim/tests/test_api.py @@ -127,7 +127,9 @@ class SiteTest(APITestCase): self.region2 = Region.objects.create(name='Test Region 2', slug='test-region-2') self.site1 = Site.objects.create(region=self.region1, name='Test Site 1', slug='test-site-1') self.site2 = Site.objects.create(region=self.region1, name='Test Site 2', slug='test-site-2') - self.site3 = Site.objects.create(region=self.region1, name='Test Site 3', slug='test-site-3') + self.site3 = Site.objects.create(region=self.region2, name='Test Site 3', slug='test-site-3') + self.site_non_region1 = Site.objects.create(name='Test Site Null Region1', slug='test-site-no-region1') + self.site_non_region2 = Site.objects.create(name='Test Site Null Region2', slug='test-site-no-region2') def test_get_site(self): @@ -162,7 +164,7 @@ class SiteTest(APITestCase): url = reverse('dcim-api:site-list') response = self.client.get(url, **self.header) - self.assertEqual(response.data['count'], 3) + self.assertEqual(response.data['count'], 5) def test_list_sites_brief(self): @@ -176,14 +178,18 @@ class SiteTest(APITestCase): def test_list_sites_null_region(self): - Site.objects.create(name='Test Site Null Region1', slug='test-site-no-region1') - Site.objects.create(name='Test Site Null Region2', slug='test-site-no-region2') - url = reverse('dcim-api:site-list') response = self.client.get('{}?region=null'.format(url), **self.header) self.assertEqual(response.data['count'], 2) + def test_list_sites_multiple_regions(self): + + url = reverse('dcim-api:site-list') + response = self.client.get('{}?region=null®ion=test-region-1'.format(url), **self.header) + + self.assertEqual(response.data['count'], 4) + def test_create_site(self): data = { @@ -197,7 +203,7 @@ class SiteTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Site.objects.count(), 4) + self.assertEqual(Site.objects.count(), 6) site4 = Site.objects.get(pk=response.data['id']) self.assertEqual(site4.name, data['name']) self.assertEqual(site4.slug, data['slug']) @@ -230,7 +236,7 @@ class SiteTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Site.objects.count(), 6) + self.assertEqual(Site.objects.count(), 8) self.assertEqual(response.data[0]['name'], data[0]['name']) self.assertEqual(response.data[1]['name'], data[1]['name']) self.assertEqual(response.data[2]['name'], data[2]['name']) @@ -247,7 +253,7 @@ class SiteTest(APITestCase): response = self.client.put(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(Site.objects.count(), 3) + self.assertEqual(Site.objects.count(), 5) site1 = Site.objects.get(pk=response.data['id']) self.assertEqual(site1.name, data['name']) self.assertEqual(site1.slug, data['slug']) @@ -259,7 +265,7 @@ class SiteTest(APITestCase): response = self.client.delete(url, **self.header) self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(Site.objects.count(), 2) + self.assertEqual(Site.objects.count(), 4) class RackGroupTest(APITestCase): @@ -1763,7 +1769,7 @@ class DeviceTest(APITestCase): super().setUp() - region = Region.objects.create(name='Test Region', slug='test-region') + region = Region.objects.create(name='Test Region', slug='test-region-1') self.site1 = Site.objects.create(region=region, name='Test Site 1', slug='test-site-1') self.site2 = Site.objects.create(name='Test Site 2', slug='test-site-2') manufacturer = Manufacturer.objects.create(name='Test Manufacturer 1', slug='test-manufacturer-1') @@ -1812,6 +1818,20 @@ class DeviceTest(APITestCase): 'B': 2 } ) + self.device_non_region1 = Device.objects.create( + device_type=self.devicetype1, + device_role=self.devicerole1, + name='Test Device Null Region1', + site=self.site2, + cluster=self.cluster1 + ) + self.device_non_region2 = Device.objects.create( + device_type=self.devicetype1, + device_role=self.devicerole1, + name='Test Device Null Region2', + site=self.site2, + cluster=self.cluster1 + ) def test_get_device(self): @@ -1827,7 +1847,7 @@ class DeviceTest(APITestCase): url = reverse('dcim-api:device-list') response = self.client.get(url, **self.header) - self.assertEqual(response.data['count'], 4) + self.assertEqual(response.data['count'], 6) def test_list_devices_brief(self): @@ -1839,20 +1859,19 @@ class DeviceTest(APITestCase): ['display_name', 'id', 'name', 'url'] ) - def test_list_device_null_region(self): - - Device.objects.create( - device_type=self.devicetype1, - device_role=self.devicerole1, - name='Test Device Null Region', - site=self.site2, - cluster=self.cluster1 - ) + def test_list_devices_null_region(self): url = reverse('dcim-api:device-list') response = self.client.get('{}?region=null'.format(url), **self.header) - self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data['count'], 2) + + def test_list_devices_multiple_regions(self): + + url = reverse('dcim-api:device-list') + response = self.client.get('{}?region=null®ion=test-region-1'.format(url), **self.header) + + self.assertEqual(response.data['count'], 6) def test_create_device(self): @@ -1868,7 +1887,7 @@ class DeviceTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Device.objects.count(), 5) + self.assertEqual(Device.objects.count(), 7) device4 = Device.objects.get(pk=response.data['id']) self.assertEqual(device4.device_type_id, data['device_type']) self.assertEqual(device4.device_role_id, data['device_role']) @@ -1903,7 +1922,7 @@ class DeviceTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Device.objects.count(), 7) + self.assertEqual(Device.objects.count(), 9) self.assertEqual(response.data[0]['name'], data[0]['name']) self.assertEqual(response.data[1]['name'], data[1]['name']) self.assertEqual(response.data[2]['name'], data[2]['name']) @@ -1927,7 +1946,7 @@ class DeviceTest(APITestCase): response = self.client.put(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(Device.objects.count(), 4) + self.assertEqual(Device.objects.count(), 6) device1 = Device.objects.get(pk=response.data['id']) self.assertEqual(device1.device_type_id, data['device_type']) self.assertEqual(device1.device_role_id, data['device_role']) @@ -1942,7 +1961,7 @@ class DeviceTest(APITestCase): response = self.client.delete(url, **self.header) self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(Device.objects.count(), 3) + self.assertEqual(Device.objects.count(), 5) def test_config_context_included_by_default_in_list_view(self): diff --git a/netbox/utilities/filters.py b/netbox/utilities/filters.py index e13f095f9..957020e40 100644 --- a/netbox/utilities/filters.py +++ b/netbox/utilities/filters.py @@ -62,13 +62,14 @@ class TreeNodeMultipleChoiceFilter(django_filters.ModelMultipleChoiceFilter): Filters for a set of Models, including all descendant models within a Tree. Example: [,] """ - def filter(self, qs, value): - if settings.FILTERS_NULL_CHOICE_VALUE in value: - # Filtering by null value. Example: region=null - qs = self.get_method(qs)(**{self.field_name.replace('in', 'isnull'): True}) - return qs.distinct() if self.distinct else qs + def get_filter_predicate(self, v): + # null value filtering + if v is None: + return {self.field_name.replace('in', 'isnull'): True} + return super().get_filter_predicate(v) - value = [node.get_descendants(include_self=True) for node in value] + def filter(self, qs, value): + value = [node.get_descendants(include_self=True) if not isinstance(node, str) else node for node in value] return super().filter(qs, value) diff --git a/netbox/virtualization/tests/test_api.py b/netbox/virtualization/tests/test_api.py index e81b14ebb..7bbeccbdd 100644 --- a/netbox/virtualization/tests/test_api.py +++ b/netbox/virtualization/tests/test_api.py @@ -350,6 +350,8 @@ class VirtualMachineTest(APITestCase): 'B': 2 } ) + self.virtualmachine_non_region1 = VirtualMachine.objects.create(name='Test Virtual Machine Null Region1', cluster=self.cluster2) + self.virtualmachine_non_region2 = VirtualMachine.objects.create(name='Test Virtual Machine Null Region2', cluster=self.cluster2) def test_get_virtualmachine(self): @@ -363,7 +365,7 @@ class VirtualMachineTest(APITestCase): url = reverse('virtualization-api:virtualmachine-list') response = self.client.get(url, **self.header) - self.assertEqual(response.data['count'], 4) + self.assertEqual(response.data['count'], 6) def test_list_virtualmachines_brief(self): @@ -377,12 +379,17 @@ class VirtualMachineTest(APITestCase): def test_list_virtualmachines_null_region(self): - VirtualMachine.objects.create(name='Test Virtual Machine Null Region', cluster=self.cluster2) - url = reverse('virtualization-api:virtualmachine-list') response = self.client.get('{}?region=null'.format(url), **self.header) - self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data['count'], 2) + + def test_list_virtualmachines_multiple_regions(self): + + url = reverse('virtualization-api:virtualmachine-list') + response = self.client.get('{}?region=null®ion=test-region-1'.format(url), **self.header) + + self.assertEqual(response.data['count'], 6) def test_create_virtualmachine(self): @@ -395,7 +402,7 @@ class VirtualMachineTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(VirtualMachine.objects.count(), 5) + self.assertEqual(VirtualMachine.objects.count(), 7) virtualmachine4 = VirtualMachine.objects.get(pk=response.data['id']) self.assertEqual(virtualmachine4.name, data['name']) self.assertEqual(virtualmachine4.cluster.pk, data['cluster']) @@ -410,7 +417,7 @@ class VirtualMachineTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) - self.assertEqual(VirtualMachine.objects.count(), 4) + self.assertEqual(VirtualMachine.objects.count(), 6) def test_create_virtualmachine_bulk(self): @@ -433,7 +440,7 @@ class VirtualMachineTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(VirtualMachine.objects.count(), 7) + self.assertEqual(VirtualMachine.objects.count(), 9) self.assertEqual(response.data[0]['name'], data[0]['name']) self.assertEqual(response.data[1]['name'], data[1]['name']) self.assertEqual(response.data[2]['name'], data[2]['name']) @@ -455,7 +462,7 @@ class VirtualMachineTest(APITestCase): response = self.client.put(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(VirtualMachine.objects.count(), 4) + self.assertEqual(VirtualMachine.objects.count(), 6) virtualmachine1 = VirtualMachine.objects.get(pk=response.data['id']) self.assertEqual(virtualmachine1.name, data['name']) self.assertEqual(virtualmachine1.cluster.pk, data['cluster']) @@ -468,7 +475,7 @@ class VirtualMachineTest(APITestCase): response = self.client.delete(url, **self.header) self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(VirtualMachine.objects.count(), 3) + self.assertEqual(VirtualMachine.objects.count(), 5) def test_config_context_included_by_default_in_list_view(self):