From a47a100cb7a761e0d1064ff5f0f9f8e877e4b24b Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Mon, 29 Jun 2020 13:30:41 -0400 Subject: [PATCH] Fix unrestricted evaluations of RestrictedQuerySet --- netbox/dcim/api/views.py | 2 +- netbox/dcim/models/__init__.py | 49 ++++++++++++++--------- netbox/dcim/tests/test_api.py | 72 +++++++++++++++++----------------- 3 files changed, 66 insertions(+), 57 deletions(-) diff --git a/netbox/dcim/api/views.py b/netbox/dcim/api/views.py index eb48d40d9..2a3619a2f 100644 --- a/netbox/dcim/api/views.py +++ b/netbox/dcim/api/views.py @@ -226,7 +226,7 @@ class ManufacturerViewSet(ModelViewSet): # class DeviceTypeViewSet(CustomFieldModelViewSet): - queryset = DeviceType.objects.prefetch_related('manufacturer').prefetch_related('tags').annotate( + queryset = DeviceType.objects.prefetch_related('manufacturer', 'tags').annotate( device_count=Count('instances') ) serializer_class = serializers.DeviceTypeSerializer diff --git a/netbox/dcim/models/__init__.py b/netbox/dcim/models/__init__.py index f930ae02d..13e80c60a 100644 --- a/netbox/dcim/models/__init__.py +++ b/netbox/dcim/models/__init__.py @@ -580,7 +580,11 @@ class Rack(ChangeLoggedModel, CustomFieldModel): if self.pk: # Validate that Rack is tall enough to house the installed Devices - top_device = Device.objects.filter(rack=self).exclude(position__isnull=True).order_by('-position').first() + top_device = Device.objects.unrestricted().filter( + rack=self + ).exclude( + position__isnull=True + ).order_by('-position').first() if top_device: min_height = top_device.position + top_device.device_type.u_height - 1 if self.u_height < min_height: @@ -601,13 +605,13 @@ class Rack(ChangeLoggedModel, CustomFieldModel): # Record the original site assignment for this rack. _site_id = None if self.pk: - _site_id = Rack.objects.get(pk=self.pk).site_id + _site_id = Rack.objects.unrestricted().get(pk=self.pk).site_id super().save(*args, **kwargs) # Update racked devices if the assigned Site has been changed. if _site_id is not None and self.site_id != _site_id: - devices = Device.objects.filter(rack=self) + devices = Device.objects.unrestricted().filter(rack=self) for device in devices: device.site = self.site device.save() @@ -1125,7 +1129,7 @@ class DeviceType(ChangeLoggedModel, CustomFieldModel): # room to expand within their racks. This validation will impose a very high performance penalty when there are # many instances to check, but increasing the u_height of a DeviceType should be a very rare occurrence. if self.pk and self.u_height > self._original_u_height: - for d in Device.objects.filter(device_type=self, position__isnull=False): + for d in Device.objects.unrestricted().filter(device_type=self, position__isnull=False): face_required = None if self.is_full_depth else d.face u_available = d.rack.get_available_units( u_height=self.u_height, @@ -1140,7 +1144,10 @@ class DeviceType(ChangeLoggedModel, CustomFieldModel): # If modifying the height of an existing DeviceType to 0U, check for any instances assigned to a rack position. elif self.pk and self._original_u_height > 0 and self.u_height == 0: - racked_instance_count = Device.objects.filter(device_type=self, position__isnull=False).count() + racked_instance_count = Device.objects.unrestricted().filter( + device_type=self, + position__isnull=False + ).count() if racked_instance_count: url = f"{reverse('dcim:device_list')}?manufactuer_id={self.manufacturer_id}&device_type_id={self.pk}" raise ValidationError({ @@ -1493,7 +1500,11 @@ class Device(ChangeLoggedModel, ConfigContextModel, CustomFieldModel): # because Django does not consider two NULL fields to be equal, and thus will not trigger a violation # of the uniqueness constraint without manual intervention. if self.name and self.tenant is None: - if Device.objects.exclude(pk=self.pk).filter(name=self.name, site=self.site, tenant__isnull=True): + if Device.objects.unrestricted().exclude(pk=self.pk).filter( + name=self.name, + site=self.site, + tenant__isnull=True + ): raise ValidationError({ 'name': 'A device with this name already exists.' }) @@ -1623,32 +1634,32 @@ class Device(ChangeLoggedModel, ConfigContextModel, CustomFieldModel): # If this is a new Device, instantiate all of the related components per the DeviceType definition if is_new: ConsolePort.objects.bulk_create( - [x.instantiate(self) for x in self.device_type.consoleport_templates.all()] + [x.instantiate(self) for x in self.device_type.consoleport_templates.unrestricted()] ) ConsoleServerPort.objects.bulk_create( - [x.instantiate(self) for x in self.device_type.consoleserverport_templates.all()] + [x.instantiate(self) for x in self.device_type.consoleserverport_templates.unrestricted()] ) PowerPort.objects.bulk_create( - [x.instantiate(self) for x in self.device_type.powerport_templates.all()] + [x.instantiate(self) for x in self.device_type.powerport_templates.unrestricted()] ) PowerOutlet.objects.bulk_create( - [x.instantiate(self) for x in self.device_type.poweroutlet_templates.all()] + [x.instantiate(self) for x in self.device_type.poweroutlet_templates.unrestricted()] ) Interface.objects.bulk_create( - [x.instantiate(self) for x in self.device_type.interface_templates.all()] + [x.instantiate(self) for x in self.device_type.interface_templates.unrestricted()] ) RearPort.objects.bulk_create( - [x.instantiate(self) for x in self.device_type.rearport_templates.all()] + [x.instantiate(self) for x in self.device_type.rearport_templates.unrestricted()] ) FrontPort.objects.bulk_create( - [x.instantiate(self) for x in self.device_type.frontport_templates.all()] + [x.instantiate(self) for x in self.device_type.frontport_templates.unrestricted()] ) DeviceBay.objects.bulk_create( - [x.instantiate(self) for x in self.device_type.device_bay_templates.all()] + [x.instantiate(self) for x in self.device_type.device_bay_templates.unrestricted()] ) # Update Site and Rack assignment for any child Devices - devices = Device.objects.filter(parent_bay__device=self) + devices = Device.objects.unrestricted().filter(parent_bay__device=self) for device in devices: device.site = self.site device.rack = self.rack @@ -1739,7 +1750,7 @@ class Device(ChangeLoggedModel, ConfigContextModel, CustomFieldModel): """ Return the set of child Devices installed in DeviceBays within this Device. """ - return Device.objects.filter(parent_bay__device=self.pk) + return Device.objects.unrestricted().filter(parent_bay__device=self.pk) def get_status_class(self): return self.STATUS_CLASS_MAP.get(self.status) @@ -1796,7 +1807,7 @@ class VirtualChassis(ChangeLoggedModel): def delete(self, *args, **kwargs): # Check for LAG interfaces split across member chassis - interfaces = Interface.objects.filter( + interfaces = Interface.objects.unrestricted().filter( device__in=self.members.all(), lag__isnull=False ).exclude( @@ -2169,7 +2180,7 @@ class Cable(ChangeLoggedModel): if not hasattr(self, 'termination_a_type'): raise ValidationError('Termination A type has not been specified') try: - self.termination_a_type.model_class().objects.get(pk=self.termination_a_id) + self.termination_a_type.model_class().objects.unrestricted().get(pk=self.termination_a_id) except ObjectDoesNotExist: raise ValidationError({ 'termination_a': 'Invalid ID for type {}'.format(self.termination_a_type) @@ -2179,7 +2190,7 @@ class Cable(ChangeLoggedModel): if not hasattr(self, 'termination_b_type'): raise ValidationError('Termination B type has not been specified') try: - self.termination_b_type.model_class().objects.get(pk=self.termination_b_id) + self.termination_b_type.model_class().objects.unrestricted().get(pk=self.termination_b_id) except ObjectDoesNotExist: raise ValidationError({ 'termination_b': 'Invalid ID for type {}'.format(self.termination_b_type) diff --git a/netbox/dcim/tests/test_api.py b/netbox/dcim/tests/test_api.py index b630741e9..2db6569a9 100644 --- a/netbox/dcim/tests/test_api.py +++ b/netbox/dcim/tests/test_api.py @@ -107,7 +107,7 @@ class SiteTest(APIViewTestCases.APIViewTestCase): Graph.objects.bulk_create(graphs) self.add_permissions('dcim.view_site') - url = reverse('dcim-api:site-graphs', kwargs={'pk': Site.objects.unrestricted().first().pk}) + url = reverse('dcim-api:site-graphs', kwargs={'pk': Site.objects.unrestricted().unrestricted().first().pk}) response = self.client.get(url, **self.header) self.assertEqual(len(response.data), 3) @@ -246,7 +246,7 @@ class RackTest(APIViewTestCases.APIViewTestCase): """ GET a single rack elevation. """ - rack = Rack.objects.first() + rack = Rack.objects.unrestricted().first() self.add_permissions('dcim.view_rack') url = reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk}) @@ -266,7 +266,7 @@ class RackTest(APIViewTestCases.APIViewTestCase): """ GET a single rack elevation in SVG format. """ - rack = Rack.objects.first() + rack = Rack.objects.unrestricted().first() self.add_permissions('dcim.view_rack') url = '{}?render=svg'.format(reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk})) @@ -281,9 +281,7 @@ class RackReservationTest(APIViewTestCases.APIViewTestCase): @classmethod def setUpTestData(cls): - user = User.objects.create(username='user1', is_active=True) - site = Site.objects.create(name='Test Site 1', slug='test-site-1') cls.racks = ( @@ -908,7 +906,7 @@ class DeviceTest(APIViewTestCases.APIViewTestCase): """ Check that creating a device with a duplicate name within a site fails. """ - device = Device.objects.first() + device = Device.objects.unrestricted().first() data = { 'device_type': device.device_type.pk, 'device_role': device.device_role.pk, @@ -1640,11 +1638,11 @@ class ConnectionTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Cable.objects.count(), 1) + self.assertEqual(Cable.objects.unrestricted().count(), 1) - cable = Cable.objects.get(pk=response.data['id']) - consoleport1 = ConsolePort.objects.get(pk=consoleport1.pk) - consoleserverport1 = ConsoleServerPort.objects.get(pk=consoleserverport1.pk) + cable = Cable.objects.unrestricted().get(pk=response.data['id']) + consoleport1 = ConsolePort.objects.unrestricted().get(pk=consoleport1.pk) + consoleserverport1 = ConsoleServerPort.objects.unrestricted().get(pk=consoleserverport1.pk) self.assertEqual(cable.termination_a, consoleport1) self.assertEqual(cable.termination_b, consoleserverport1) @@ -1705,12 +1703,12 @@ class ConnectionTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - cable = Cable.objects.get(pk=response.data['id']) + cable = Cable.objects.unrestricted().get(pk=response.data['id']) self.assertEqual(cable.termination_a.cable, cable) self.assertEqual(cable.termination_b.cable, cable) - consoleport1 = ConsolePort.objects.get(pk=consoleport1.pk) - consoleserverport1 = ConsoleServerPort.objects.get(pk=consoleserverport1.pk) + consoleport1 = ConsolePort.objects.unrestricted().get(pk=consoleport1.pk) + consoleserverport1 = ConsoleServerPort.objects.unrestricted().get(pk=consoleserverport1.pk) self.assertEqual(consoleport1.connected_endpoint, consoleserverport1) self.assertEqual(consoleserverport1.connected_endpoint, consoleport1) @@ -1735,11 +1733,11 @@ class ConnectionTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Cable.objects.count(), 1) + self.assertEqual(Cable.objects.unrestricted().count(), 1) - cable = Cable.objects.get(pk=response.data['id']) - powerport1 = PowerPort.objects.get(pk=powerport1.pk) - poweroutlet1 = PowerOutlet.objects.get(pk=poweroutlet1.pk) + cable = Cable.objects.unrestricted().get(pk=response.data['id']) + powerport1 = PowerPort.objects.unrestricted().get(pk=powerport1.pk) + poweroutlet1 = PowerOutlet.objects.unrestricted().get(pk=poweroutlet1.pk) self.assertEqual(cable.termination_a, powerport1) self.assertEqual(cable.termination_b, poweroutlet1) @@ -1771,11 +1769,11 @@ class ConnectionTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Cable.objects.count(), 1) + self.assertEqual(Cable.objects.unrestricted().count(), 1) - cable = Cable.objects.get(pk=response.data['id']) - interface1 = Interface.objects.get(pk=interface1.pk) - interface2 = Interface.objects.get(pk=interface2.pk) + cable = Cable.objects.unrestricted().get(pk=response.data['id']) + interface1 = Interface.objects.unrestricted().get(pk=interface1.pk) + interface2 = Interface.objects.unrestricted().get(pk=interface2.pk) self.assertEqual(cable.termination_a, interface1) self.assertEqual(cable.termination_b, interface2) @@ -1836,12 +1834,12 @@ class ConnectionTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - cable = Cable.objects.get(pk=response.data['id']) + cable = Cable.objects.unrestricted().get(pk=response.data['id']) self.assertEqual(cable.termination_a.cable, cable) self.assertEqual(cable.termination_b.cable, cable) - interface1 = Interface.objects.get(pk=interface1.pk) - interface2 = Interface.objects.get(pk=interface2.pk) + interface1 = Interface.objects.unrestricted().get(pk=interface1.pk) + interface2 = Interface.objects.unrestricted().get(pk=interface2.pk) self.assertEqual(interface1.connected_endpoint, interface2) self.assertEqual(interface2.connected_endpoint, interface1) @@ -1875,11 +1873,11 @@ class ConnectionTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Cable.objects.count(), 1) + self.assertEqual(Cable.objects.unrestricted().count(), 1) - cable = Cable.objects.get(pk=response.data['id']) - interface1 = Interface.objects.get(pk=interface1.pk) - circuittermination1 = CircuitTermination.objects.get(pk=circuittermination1.pk) + cable = Cable.objects.unrestricted().get(pk=response.data['id']) + interface1 = Interface.objects.unrestricted().get(pk=interface1.pk) + circuittermination1 = CircuitTermination.objects.unrestricted().get(pk=circuittermination1.pk) self.assertEqual(cable.termination_a, interface1) self.assertEqual(cable.termination_b, circuittermination1) @@ -1949,12 +1947,12 @@ class ConnectionTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - cable = Cable.objects.get(pk=response.data['id']) + cable = Cable.objects.unrestricted().get(pk=response.data['id']) self.assertEqual(cable.termination_a.cable, cable) self.assertEqual(cable.termination_b.cable, cable) - interface1 = Interface.objects.get(pk=interface1.pk) - circuittermination1 = CircuitTermination.objects.get(pk=circuittermination1.pk) + interface1 = Interface.objects.unrestricted().get(pk=interface1.pk) + circuittermination1 = CircuitTermination.objects.unrestricted().get(pk=circuittermination1.pk) self.assertEqual(interface1.connected_endpoint, circuittermination1) self.assertEqual(circuittermination1.connected_endpoint, interface1) @@ -2045,12 +2043,12 @@ class VirtualChassisTest(APIViewTestCases.APIViewTestCase): VirtualChassis(name='Virtual Chassis 3', master=devices[6], domain='domain-3'), ) VirtualChassis.objects.bulk_create(virtual_chassis) - Device.objects.filter(pk=devices[1].pk).update(virtual_chassis=virtual_chassis[0], vc_position=2) - Device.objects.filter(pk=devices[2].pk).update(virtual_chassis=virtual_chassis[0], vc_position=3) - Device.objects.filter(pk=devices[4].pk).update(virtual_chassis=virtual_chassis[1], vc_position=2) - Device.objects.filter(pk=devices[5].pk).update(virtual_chassis=virtual_chassis[1], vc_position=3) - Device.objects.filter(pk=devices[7].pk).update(virtual_chassis=virtual_chassis[2], vc_position=2) - Device.objects.filter(pk=devices[8].pk).update(virtual_chassis=virtual_chassis[2], vc_position=3) + Device.objects.unrestricted().filter(pk=devices[1].pk).update(virtual_chassis=virtual_chassis[0], vc_position=2) + Device.objects.unrestricted().filter(pk=devices[2].pk).update(virtual_chassis=virtual_chassis[0], vc_position=3) + Device.objects.unrestricted().filter(pk=devices[4].pk).update(virtual_chassis=virtual_chassis[1], vc_position=2) + Device.objects.unrestricted().filter(pk=devices[5].pk).update(virtual_chassis=virtual_chassis[1], vc_position=3) + Device.objects.unrestricted().filter(pk=devices[7].pk).update(virtual_chassis=virtual_chassis[2], vc_position=2) + Device.objects.unrestricted().filter(pk=devices[8].pk).update(virtual_chassis=virtual_chassis[2], vc_position=3) cls.update_data = { 'name': 'Virtual Chassis X',