Update views to restrict all querysets

This commit is contained in:
Jeremy Stretch 2020-06-01 11:43:49 -04:00
parent 5574aaa8cb
commit 3c334a0238
9 changed files with 108 additions and 78 deletions

View File

@ -33,7 +33,7 @@ class ProviderView(ObjectView):
def get(self, request, slug):
provider = get_object_or_404(self.queryset, slug=slug)
circuits = Circuit.objects.filter(
circuits = Circuit.objects.restrict(request.user, 'view').filter(
provider=provider
).prefetch_related(
'type', 'tenant', 'terminations__site'
@ -138,12 +138,12 @@ class CircuitView(ObjectView):
def get(self, request, pk):
circuit = get_object_or_404(self.queryset, pk=pk)
termination_a = CircuitTermination.objects.prefetch_related(
termination_a = CircuitTermination.objects.restrict(request.user, 'view').prefetch_related(
'site__region', 'connected_endpoint__device'
).filter(
circuit=circuit, term_side=CircuitTerminationSideChoices.SIDE_A
).first()
termination_z = CircuitTermination.objects.prefetch_related(
termination_z = CircuitTermination.objects.restrict(request.user, 'view').prefetch_related(
'site__region', 'connected_endpoint__device'
).filter(
circuit=circuit, term_side=CircuitTerminationSideChoices.SIDE_Z

View File

@ -19,8 +19,9 @@ from django.views.generic import View
from circuits.models import Circuit
from extras.models import Graph
from extras.views import ObjectConfigContextView
from ipam.models import Prefix, VLAN
from ipam.models import Prefix, Service, VLAN
from ipam.tables import InterfaceIPAddressTable, InterfaceVLANTable
from secrets.models import Secret
from utilities.forms import ConfirmationForm
from utilities.paginator import EnhancedPaginator
from utilities.permissions import get_permission_for_model
@ -197,14 +198,16 @@ class SiteView(ObjectView):
site = get_object_or_404(self.queryset, slug=slug)
stats = {
'rack_count': Rack.objects.filter(site=site).count(),
'device_count': Device.objects.filter(site=site).count(),
'prefix_count': Prefix.objects.filter(site=site).count(),
'vlan_count': VLAN.objects.filter(site=site).count(),
'circuit_count': Circuit.objects.filter(terminations__site=site).count(),
'vm_count': VirtualMachine.objects.filter(cluster__site=site).count(),
'rack_count': Rack.objects.restrict(request.user, 'view').filter(site=site).count(),
'device_count': Device.objects.restrict(request.user, 'view').filter(site=site).count(),
'prefix_count': Prefix.objects.restrict(request.user, 'view').filter(site=site).count(),
'vlan_count': VLAN.objects.restrict(request.user, 'view').filter(site=site).count(),
'circuit_count': Circuit.objects.restrict(request.user, 'view').filter(terminations__site=site).count(),
'vm_count': VirtualMachine.objects.restrict(request.user, 'view').filter(cluster__site=site).count(),
}
rack_groups = RackGroup.objects.filter(site=site).annotate(rack_count=Count('racks'))
rack_groups = RackGroup.objects.restrict(request.user, 'view').filter(site=site).annotate(
rack_count=Count('racks')
)
show_graphs = Graph.objects.filter(type__model='site').exists()
return render(request, 'dcim/site.html', {
@ -372,7 +375,7 @@ class RackView(ObjectView):
rack = get_object_or_404(self.queryset, pk=pk)
nonracked_devices = Device.objects.filter(
nonracked_devices = Device.objects.restrict(request.user, 'view').filter(
rack=rack,
position__isnull=True,
parent_bay__isnull=True
@ -384,8 +387,8 @@ class RackView(ObjectView):
next_rack = peer_racks.filter(name__gt=rack.name).order_by('name').first()
prev_rack = peer_racks.filter(name__lt=rack.name).order_by('-name').first()
reservations = RackReservation.objects.filter(rack=rack)
power_feeds = PowerFeed.objects.filter(rack=rack).prefetch_related('power_panel')
reservations = RackReservation.objects.restrict(request.user, 'view').filter(rack=rack)
power_feeds = PowerFeed.objects.restrict(request.user, 'view').filter(rack=rack).prefetch_related('power_panel')
return render(request, 'dcim/rack.html', {
'rack': rack,
@ -558,35 +561,35 @@ class DeviceTypeView(ObjectView):
# Component tables
consoleport_table = tables.ConsolePortTemplateTable(
ConsolePortTemplate.objects.filter(device_type=devicetype),
ConsolePortTemplate.objects.restrict(request.user, 'view').filter(device_type=devicetype),
orderable=False
)
consoleserverport_table = tables.ConsoleServerPortTemplateTable(
ConsoleServerPortTemplate.objects.filter(device_type=devicetype),
ConsoleServerPortTemplate.objects.restrict(request.user, 'view').filter(device_type=devicetype),
orderable=False
)
powerport_table = tables.PowerPortTemplateTable(
PowerPortTemplate.objects.filter(device_type=devicetype),
PowerPortTemplate.objects.restrict(request.user, 'view').filter(device_type=devicetype),
orderable=False
)
poweroutlet_table = tables.PowerOutletTemplateTable(
PowerOutletTemplate.objects.filter(device_type=devicetype),
PowerOutletTemplate.objects.restrict(request.user, 'view').filter(device_type=devicetype),
orderable=False
)
interface_table = tables.InterfaceTemplateTable(
list(InterfaceTemplate.objects.filter(device_type=devicetype)),
list(InterfaceTemplate.objects.restrict(request.user, 'view').filter(device_type=devicetype)),
orderable=False
)
front_port_table = tables.FrontPortTemplateTable(
FrontPortTemplate.objects.filter(device_type=devicetype),
FrontPortTemplate.objects.restrict(request.user, 'view').filter(device_type=devicetype),
orderable=False
)
rear_port_table = tables.RearPortTemplateTable(
RearPortTemplate.objects.filter(device_type=devicetype),
RearPortTemplate.objects.restrict(request.user, 'view').filter(device_type=devicetype),
orderable=False
)
devicebay_table = tables.DeviceBayTemplateTable(
DeviceBayTemplate.objects.filter(device_type=devicetype),
DeviceBayTemplate.objects.restrict(request.user, 'view').filter(device_type=devicetype),
orderable=False
)
if request.user.has_perm('dcim.change_devicetype'):
@ -995,47 +998,61 @@ class DeviceView(ObjectView):
# VirtualChassis members
if device.virtual_chassis is not None:
vc_members = Device.objects.filter(
vc_members = Device.objects.restrict(request.user, 'view').filter(
virtual_chassis=device.virtual_chassis
).order_by('vc_position')
else:
vc_members = []
# Console ports
console_ports = device.consoleports.prefetch_related('connected_endpoint__device', 'cable')
console_ports = ConsolePort.objects.restrict(request.user, 'view').filter(device=device).prefetch_related(
'connected_endpoint__device', 'cable',
)
# Console server ports
consoleserverports = device.consoleserverports.prefetch_related('connected_endpoint__device', 'cable')
consoleserverports = ConsoleServerPort.objects.restrict(request.user, 'view').filter(
device=device
).prefetch_related(
'connected_endpoint__device', 'cable',
)
# Power ports
power_ports = device.powerports.prefetch_related('_connected_poweroutlet__device', 'cable')
power_ports = PowerPort.objects.restrict(request.user, 'view').filter(device=device).prefetch_related(
'_connected_poweroutlet__device', 'cable',
)
# Power outlets
poweroutlets = device.poweroutlets.prefetch_related('connected_endpoint__device', 'cable', 'power_port')
poweroutlets = PowerOutlet.objects.restrict(request.user, 'view').filter(device=device).prefetch_related(
'connected_endpoint__device', 'cable', 'power_port',
)
# Interfaces
interfaces = device.vc_interfaces.prefetch_related(
interfaces = device.vc_interfaces.restrict(request.user, 'view').filter(device=device).prefetch_related(
'lag', '_connected_interface__device', '_connected_circuittermination__circuit', 'cable',
'cable__termination_a', 'cable__termination_b', 'ip_addresses', 'tags'
)
# Front ports
front_ports = device.frontports.prefetch_related('rear_port', 'cable')
front_ports = FrontPort.objects.restrict(request.user, 'view').filter(device=device).prefetch_related(
'rear_port', 'cable',
)
# Rear ports
rear_ports = device.rearports.prefetch_related('cable')
rear_ports = RearPort.objects.restrict(request.user, 'view').filter(device=device).prefetch_related('cable')
# Device bays
device_bays = device.device_bays.prefetch_related('installed_device__device_type__manufacturer')
device_bays = DeviceBay.objects.restrict(request.user, 'view').filter(device=device).prefetch_related(
'installed_device__device_type__manufacturer',
)
# Services
services = device.services.all()
services = Service.objects.restrict(request.user, 'view').filter(device=device)
# Secrets
secrets = device.secrets.all()
secrets = Secret.objects.restrict(request.user, 'view').filter(device=device)
# Find up to ten devices in the same site with the same functional role for quick reference.
related_devices = Device.objects.filter(
related_devices = Device.objects.restrict(request.user, 'view').filter(
site=device.site, device_role=device.device_role
).exclude(
pk=device.pk
@ -1068,7 +1085,7 @@ class DeviceInventoryView(ObjectView):
def get(self, request, pk):
device = get_object_or_404(self.queryset, pk=pk)
inventory_items = InventoryItem.objects.filter(
inventory_items = InventoryItem.objects.restrict(request.user, 'view').filter(
device=device, parent=None
).prefetch_related(
'manufacturer', 'child_items'
@ -1102,7 +1119,9 @@ class DeviceLLDPNeighborsView(ObjectView):
def get(self, request, pk):
device = get_object_or_404(self.queryset, pk=pk)
interfaces = device.vc_interfaces.exclude(type__in=NONCONNECTABLE_IFACE_TYPES).prefetch_related(
interfaces = device.vc_interfaces.restrict(request.user, 'view').exclude(
type__in=NONCONNECTABLE_IFACE_TYPES
).prefetch_related(
'_connected_interface__device'
)
@ -1423,7 +1442,7 @@ class InterfaceView(ObjectView):
# Get assigned IP addresses
ipaddress_table = InterfaceIPAddressTable(
data=interface.ip_addresses.prefetch_related('vrf', 'tenant'),
data=interface.ip_addresses.restrict(request.user, 'view').prefetch_related('vrf', 'tenant'),
orderable=False
)

View File

@ -163,7 +163,7 @@ class ObjectConfigContextView(ObjectView):
def get(self, request, pk):
obj = get_object_or_404(self.queryset, pk=pk)
source_contexts = ConfigContext.objects.get_for_object(obj)
source_contexts = ConfigContext.objects.restrict(request.user, 'view').get_for_object(obj)
model_name = self.queryset.model._meta.model_name
# Determine user's preferred output format
@ -207,13 +207,17 @@ class ObjectChangeView(ObjectView):
objectchange = get_object_or_404(self.queryset, pk=pk)
related_changes = ObjectChange.objects.filter(request_id=objectchange.request_id).exclude(pk=objectchange.pk)
related_changes = ObjectChange.objects.restrict(request.user, 'view').filter(
request_id=objectchange.request_id
).exclude(
pk=objectchange.pk
)
related_changes_table = ObjectChangeTable(
data=related_changes[:50],
orderable=False
)
objectchanges = ObjectChange.objects.filter(
objectchanges = ObjectChange.objects.restrict(request.user, 'view').filter(
changed_object_type=objectchange.changed_object_type,
changed_object_id=objectchange.changed_object_id,
)
@ -255,7 +259,7 @@ class ObjectChangeLogView(View):
# Gather all changes for this object (and its related objects)
content_type = ContentType.objects.get_for_model(model)
objectchanges = ObjectChange.objects.prefetch_related(
objectchanges = ObjectChange.objects.restrict(request.user, 'view').prefetch_related(
'user', 'changed_object_type'
).filter(
Q(changed_object_type=content_type, changed_object_id=obj.pk) |

View File

@ -1,10 +1,10 @@
from django.db import models
from django.db.models import Manager
from ipam.lookups import Host, Inet
from utilities.querysets import RestrictedQuerySet
class IPAddressManager(models.Manager):
class IPAddressManager(Manager.from_queryset(RestrictedQuerySet)):
def get_queryset(self):
"""
@ -14,5 +14,4 @@ class IPAddressManager(models.Manager):
then re-cast this value to INET() so that records will be ordered properly. We are essentially re-casting each
IP address as a /32 or /128.
"""
qs = RestrictedQuerySet(self.model, using=self._db)
return qs.order_by(Inet(Host('address')))
return super().get_queryset().order_by(Inet(Host('address')))

View File

@ -3,14 +3,13 @@ from django.conf import settings
from django.db.models import Count, Q
from django.db.models.expressions import RawSQL
from django.shortcuts import get_object_or_404, redirect, render
from django.views.generic import View
from django_tables2 import RequestConfig
from dcim.models import Device, Interface
from utilities.paginator import EnhancedPaginator
from utilities.views import (
BulkCreateView, BulkDeleteView, BulkEditView, BulkImportView, ObjectView, ObjectDeleteView, ObjectEditView,
ObjectListView, ObjectPermissionRequiredMixin,
ObjectListView,
)
from virtualization.models import VirtualMachine
from . import filters, forms, tables
@ -125,7 +124,7 @@ class VRFView(ObjectView):
def get(self, request, pk):
vrf = get_object_or_404(self.queryset, pk=pk)
prefix_count = Prefix.objects.filter(vrf=vrf).count()
prefix_count = Prefix.objects.restrict(request.user, 'view').filter(vrf=vrf).count()
return render(request, 'ipam/vrf.html', {
'vrf': vrf,
@ -305,7 +304,7 @@ class AggregateView(ObjectView):
aggregate = get_object_or_404(self.queryset, pk=pk)
# Find all child prefixes contained by this aggregate
child_prefixes = Prefix.objects.filter(
child_prefixes = Prefix.objects.restrict(request.user, 'view').filter(
prefix__net_contained_or_equal=str(aggregate.prefix)
).prefetch_related(
'site', 'role'
@ -429,12 +428,14 @@ class PrefixView(ObjectView):
prefix = get_object_or_404(self.queryset, pk=pk)
try:
aggregate = Aggregate.objects.get(prefix__net_contains_or_equals=str(prefix.prefix))
aggregate = Aggregate.objects.restrict(request.user, 'view').get(
prefix__net_contains_or_equals=str(prefix.prefix)
)
except Aggregate.DoesNotExist:
aggregate = None
# Parent prefixes table
parent_prefixes = Prefix.objects.filter(
parent_prefixes = Prefix.objects.restrict(request.user, 'view').filter(
Q(vrf=prefix.vrf) | Q(vrf__isnull=True)
).filter(
prefix__net_contains=str(prefix.prefix)
@ -445,7 +446,7 @@ class PrefixView(ObjectView):
parent_prefix_table.exclude = ('vrf',)
# Duplicate prefixes table
duplicate_prefixes = Prefix.objects.filter(
duplicate_prefixes = Prefix.objects.restrict(request.user, 'view').filter(
vrf=prefix.vrf, prefix=str(prefix.prefix)
).exclude(
pk=prefix.pk
@ -471,7 +472,7 @@ class PrefixPrefixesView(ObjectView):
prefix = get_object_or_404(self.queryset, pk=pk)
# Child prefixes table
child_prefixes = prefix.get_child_prefixes().prefetch_related(
child_prefixes = prefix.get_child_prefixes().restrict(request.user, 'view').prefetch_related(
'site', 'vlan', 'role',
).annotate_depth(limit=0)
@ -515,7 +516,7 @@ class PrefixIPAddressesView(ObjectView):
prefix = get_object_or_404(self.queryset, pk=pk)
# Find all IPAddresses belonging to this Prefix
ipaddresses = prefix.get_child_ips().prefetch_related(
ipaddresses = prefix.get_child_ips().restrict(request.user, 'view').prefetch_related(
'vrf', 'interface__device', 'primary_ip4_for', 'primary_ip6_for'
)
@ -607,7 +608,7 @@ class IPAddressView(ObjectView):
ipaddress = get_object_or_404(self.queryset, pk=pk)
# Parent prefixes table
parent_prefixes = Prefix.objects.filter(
parent_prefixes = Prefix.objects.restrict(request.user, 'view').filter(
vrf=ipaddress.vrf, prefix__net_contains=str(ipaddress.address.ip)
).prefetch_related(
'site', 'role'
@ -616,7 +617,7 @@ class IPAddressView(ObjectView):
parent_prefixes_table.exclude = ('vrf',)
# Duplicate IPs table
duplicate_ips = IPAddress.objects.filter(
duplicate_ips = IPAddress.objects.restrict(request.user, 'view').filter(
vrf=ipaddress.vrf, address=str(ipaddress.address)
).exclude(
pk=ipaddress.pk
@ -629,14 +630,13 @@ class IPAddressView(ObjectView):
duplicate_ips_table = tables.IPAddressTable(list(duplicate_ips), orderable=False)
# Related IP table
related_ips = IPAddress.objects.prefetch_related(
related_ips = IPAddress.objects.restrict(request.user, 'view').prefetch_related(
'interface__device'
).exclude(
address=str(ipaddress.address)
).filter(
vrf=ipaddress.vrf, address__net_contained_or_equal=str(ipaddress.address)
)
related_ips_table = tables.IPAddressTable(related_ips, orderable=False)
paginate = {
@ -785,7 +785,7 @@ class VLANGroupVLANsView(ObjectView):
def get(self, request, pk):
vlan_group = get_object_or_404(self.queryset, pk=pk)
vlans = VLAN.objects.filter(group_id=pk)
vlans = VLAN.objects.restrict(request.user, 'view').filter(group_id=pk)
vlans = add_available_vlans(vlan_group, vlans)
vlan_table = tables.VLANDetailTable(vlans)
@ -832,7 +832,9 @@ class VLANView(ObjectView):
def get(self, request, pk):
vlan = get_object_or_404(self.queryset, pk=pk)
prefixes = Prefix.objects.filter(vlan=vlan).prefetch_related('vrf', 'site', 'role')
prefixes = Prefix.objects.restrict(request.user, 'view').filter(vlan=vlan).prefetch_related(
'vrf', 'site', 'role'
)
prefix_table = tables.PrefixTable(list(prefixes), orderable=False)
prefix_table.exclude = ('vlan',)
@ -848,7 +850,7 @@ class VLANMembersView(ObjectView):
def get(self, request, pk):
vlan = get_object_or_404(self.queryset, pk=pk)
members = vlan.get_members().prefetch_related('device', 'virtual_machine')
members = vlan.get_members().restrict(request.user, 'view').prefetch_related('device', 'virtual_machine')
members_table = tables.VLANMemberTable(members)

View File

@ -64,17 +64,17 @@ class TenantView(ObjectView):
tenant = get_object_or_404(self.queryset, slug=slug)
stats = {
'site_count': Site.objects.filter(tenant=tenant).count(),
'rack_count': Rack.objects.filter(tenant=tenant).count(),
'rackreservation_count': RackReservation.objects.filter(tenant=tenant).count(),
'device_count': Device.objects.filter(tenant=tenant).count(),
'vrf_count': VRF.objects.filter(tenant=tenant).count(),
'prefix_count': Prefix.objects.filter(tenant=tenant).count(),
'ipaddress_count': IPAddress.objects.filter(tenant=tenant).count(),
'vlan_count': VLAN.objects.filter(tenant=tenant).count(),
'circuit_count': Circuit.objects.filter(tenant=tenant).count(),
'virtualmachine_count': VirtualMachine.objects.filter(tenant=tenant).count(),
'cluster_count': Cluster.objects.filter(tenant=tenant).count(),
'site_count': Site.objects.restrict(request.user, 'view').filter(tenant=tenant).count(),
'rack_count': Rack.objects.restrict(request.user, 'view').filter(tenant=tenant).count(),
'rackreservation_count': RackReservation.objects.restrict(request.user, 'view').filter(tenant=tenant).count(),
'device_count': Device.objects.restrict(request.user, 'view').filter(tenant=tenant).count(),
'vrf_count': VRF.objects.restrict(request.user, 'view').filter(tenant=tenant).count(),
'prefix_count': Prefix.objects.restrict(request.user, 'view').filter(tenant=tenant).count(),
'ipaddress_count': IPAddress.objects.restrict(request.user, 'view').filter(tenant=tenant).count(),
'vlan_count': VLAN.objects.restrict(request.user, 'view').filter(tenant=tenant).count(),
'circuit_count': Circuit.objects.restrict(request.user, 'view').filter(tenant=tenant).count(),
'virtualmachine_count': VirtualMachine.objects.restrict(request.user, 'view').filter(tenant=tenant).count(),
'cluster_count': Cluster.objects.restrict(request.user, 'view').filter(tenant=tenant).count(),
}
return render(request, 'tenancy/tenant.html', {

View File

@ -28,15 +28,22 @@ class RestrictedQuerySet(QuerySet):
model_name = self.model._meta.model_name
permission_required = f'{app_label}.{action}_{model_name}'
# TODO: Handle anonymous users
if not user.is_authenticated:
return self
# Determine what constraints (if any) have been placed on this user for this action and model
# TODO: Find a better way to ensure permissions are cached
if not hasattr(user, '_object_perm_cache'):
user.get_all_permissions()
obj_perm_attrs = user._object_perm_cache[permission_required]
# User has not been granted any permission
if permission_required not in user._object_perm_cache:
return self.none()
# Filter the queryset to include only objects with allowed attributes
attrs = Q()
for perm_attrs in obj_perm_attrs:
for perm_attrs in user._object_perm_cache[permission_required]:
if perm_attrs:
attrs |= Q(**perm_attrs)

View File

@ -187,7 +187,6 @@ class VirtualMachineTestCase(ViewTestCases.PrimaryObjectViewTestCase):
# TODO: Update base class to DeviceComponentViewTestCase
class InterfaceTestCase(
ViewTestCases.GetObjectViewTestCase,
ViewTestCases.EditObjectViewTestCase,
ViewTestCases.DeleteObjectViewTestCase,
ViewTestCases.BulkCreateObjectsViewTestCase,

View File

@ -89,7 +89,7 @@ class ClusterView(ObjectView):
def get(self, request, pk):
cluster = get_object_or_404(self.queryset, pk=pk)
devices = Device.objects.filter(cluster=cluster).prefetch_related(
devices = Device.objects.restrict(request.user, 'view').filter(cluster=cluster).prefetch_related(
'site', 'rack', 'tenant', 'device_type__manufacturer'
)
device_table = DeviceTable(list(devices), orderable=False)
@ -235,8 +235,8 @@ class VirtualMachineView(ObjectView):
def get(self, request, pk):
virtualmachine = get_object_or_404(self.queryset, pk=pk)
interfaces = Interface.objects.filter(virtual_machine=virtualmachine)
services = Service.objects.filter(virtual_machine=virtualmachine)
interfaces = Interface.objects.restrict(request.user, 'view').filter(virtual_machine=virtualmachine)
services = Service.objects.restrict(request.user, 'view').filter(virtual_machine=virtualmachine)
return render(request, 'virtualization/virtualmachine.html', {
'virtualmachine': virtualmachine,