Call restrict() when retrieving related Graphs

This commit is contained in:
Jeremy Stretch 2020-06-29 10:02:00 -04:00
parent 86d1370512
commit 0dbe248df8
4 changed files with 12 additions and 12 deletions

View File

@ -28,8 +28,8 @@ class ProviderViewSet(CustomFieldModelViewSet):
""" """
A convenience method for rendering graphs for a particular provider. A convenience method for rendering graphs for a particular provider.
""" """
provider = get_object_or_404(Provider, pk=pk) provider = get_object_or_404(self.queryset, pk=pk)
queryset = Graph.objects.filter(type__model='provider') queryset = Graph.objects.restrict(request.user).filter(type__model='provider')
serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': provider}) serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': provider})
return Response(serializer.data) return Response(serializer.data)

View File

@ -49,7 +49,7 @@ class ProviderTest(APIViewTestCases.APIViewTestCase):
""" """
Test retrieval of Graphs assigned to Providers. Test retrieval of Graphs assigned to Providers.
""" """
provider = self.model.objects.first() provider = self.model.objects.unrestricted().first()
ct = ContentType.objects.get(app_label='circuits', model='provider') ct = ContentType.objects.get(app_label='circuits', model='provider')
graphs = ( graphs = (
Graph(type=ct, name='Graph 1', source='http://example.com/graphs.py?provider={{ obj.slug }}&foo=1'), Graph(type=ct, name='Graph 1', source='http://example.com/graphs.py?provider={{ obj.slug }}&foo=1'),

View File

@ -103,8 +103,8 @@ class SiteViewSet(CustomFieldModelViewSet):
""" """
A convenience method for rendering graphs for a particular site. A convenience method for rendering graphs for a particular site.
""" """
site = get_object_or_404(Site, pk=pk) site = get_object_or_404(self.queryset, pk=pk)
queryset = Graph.objects.filter(type__model='site') queryset = Graph.objects.restrict(request.user).filter(type__model='site')
serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': site}) serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': site})
return Response(serializer.data) return Response(serializer.data)
@ -347,8 +347,8 @@ class DeviceViewSet(CustomFieldModelViewSet):
""" """
A convenience method for rendering graphs for a particular Device. A convenience method for rendering graphs for a particular Device.
""" """
device = get_object_or_404(Device, pk=pk) device = get_object_or_404(self.queryset, pk=pk)
queryset = Graph.objects.filter(type__model='device') queryset = Graph.objects.restrict(request.user).filter(type__model='device')
serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': device}) serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': device})
return Response(serializer.data) return Response(serializer.data)
@ -496,8 +496,8 @@ class InterfaceViewSet(CableTraceMixin, ModelViewSet):
""" """
A convenience method for rendering graphs for a particular interface. A convenience method for rendering graphs for a particular interface.
""" """
interface = get_object_or_404(Interface, pk=pk) interface = get_object_or_404(self.queryset, pk=pk)
queryset = Graph.objects.filter(type__model='interface') queryset = Graph.objects.restrict(request.user).filter(type__model='interface')
serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': interface}) serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': interface})
return Response(serializer.data) return Response(serializer.data)

View File

@ -107,7 +107,7 @@ class SiteTest(APIViewTestCases.APIViewTestCase):
Graph.objects.bulk_create(graphs) Graph.objects.bulk_create(graphs)
self.add_permissions('dcim.view_site') self.add_permissions('dcim.view_site')
url = reverse('dcim-api:site-graphs', kwargs={'pk': Site.objects.first().pk}) url = reverse('dcim-api:site-graphs', kwargs={'pk': Site.objects.unrestricted().first().pk})
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
self.assertEqual(len(response.data), 3) self.assertEqual(len(response.data), 3)
@ -878,7 +878,7 @@ class DeviceTest(APIViewTestCases.APIViewTestCase):
Graph.objects.bulk_create(graphs) Graph.objects.bulk_create(graphs)
self.add_permissions('dcim.view_device') self.add_permissions('dcim.view_device')
url = reverse('dcim-api:device-graphs', kwargs={'pk': Device.objects.first().pk}) url = reverse('dcim-api:device-graphs', kwargs={'pk': Device.objects.unrestricted().first().pk})
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
self.assertEqual(len(response.data), 3) self.assertEqual(len(response.data), 3)
@ -1245,7 +1245,7 @@ class InterfaceTest(APIViewTestCases.APIViewTestCase):
Graph.objects.bulk_create(graphs) Graph.objects.bulk_create(graphs)
self.add_permissions('dcim.view_interface') self.add_permissions('dcim.view_interface')
url = reverse('dcim-api:interface-graphs', kwargs={'pk': Interface.objects.first().pk}) url = reverse('dcim-api:interface-graphs', kwargs={'pk': Interface.objects.unrestricted().first().pk})
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
self.assertEqual(len(response.data), 3) self.assertEqual(len(response.data), 3)