diff --git a/netbox/circuits/tests/test_api.py b/netbox/circuits/tests/test_api.py index 424b13d40..fd9e87412 100644 --- a/netbox/circuits/tests/test_api.py +++ b/netbox/circuits/tests/test_api.py @@ -15,7 +15,7 @@ class AppTest(APITestCase): self.assertEqual(response.status_code, 200) -class ProviderTest(APIViewTestCases.APIViewTestCase): +class ProviderTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase): model = Provider brief_fields = ['circuit_count', 'display', 'id', 'name', 'slug', 'url'] create_data = [ @@ -47,7 +47,7 @@ class ProviderTest(APIViewTestCases.APIViewTestCase): Provider.objects.bulk_create(providers) -class CircuitTypeTest(APIViewTestCases.APIViewTestCase): +class CircuitTypeTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase): model = CircuitType brief_fields = ['circuit_count', 'display', 'id', 'name', 'slug', 'url'] create_data = ( @@ -79,7 +79,7 @@ class CircuitTypeTest(APIViewTestCases.APIViewTestCase): CircuitType.objects.bulk_create(circuit_types) -class CircuitTest(APIViewTestCases.APIViewTestCase): +class CircuitTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase): model = Circuit brief_fields = ['cid', 'display', 'id', 'url'] bulk_update_data = { @@ -127,7 +127,7 @@ class CircuitTest(APIViewTestCases.APIViewTestCase): ] -class CircuitTerminationTest(APIViewTestCases.APIViewTestCase): +class CircuitTerminationTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase): model = CircuitTermination brief_fields = ['_occupied', 'cable', 'circuit', 'display', 'id', 'term_side', 'url'] @@ -180,7 +180,7 @@ class CircuitTerminationTest(APIViewTestCases.APIViewTestCase): } -class ProviderNetworkTest(APIViewTestCases.APIViewTestCase): +class ProviderNetworkTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase): model = ProviderNetwork brief_fields = ['display', 'id', 'name', 'url'] diff --git a/netbox/netbox/tests/test_graphql.py b/netbox/netbox/tests/test_graphql.py new file mode 100644 index 000000000..dd43bbbdd --- /dev/null +++ b/netbox/netbox/tests/test_graphql.py @@ -0,0 +1,27 @@ +from django.test import override_settings +from django.urls import reverse + +from utilities.testing import disable_warnings, TestCase + + +class GraphQLTestCase(TestCase): + + @override_settings(LOGIN_REQUIRED=True) + def test_graphiql_interface(self): + """ + Test rendering of the GraphiQL interactive web interface + """ + url = reverse('graphql') + header = { + 'HTTP_ACCEPT': 'text/html', + } + + # Authenticated request + response = self.client.get(url, **header) + self.assertHttpStatus(response, 200) + + # Non-authenticated request + self.client.logout() + response = self.client.get(url, **header) + with disable_warnings('django.request'): + self.assertHttpStatus(response, 302) diff --git a/netbox/netbox/urls.py b/netbox/netbox/urls.py index 9257f12b9..4f1ec38d2 100644 --- a/netbox/netbox/urls.py +++ b/netbox/netbox/urls.py @@ -63,7 +63,7 @@ _patterns = [ re_path(r'^api/swagger(?P.json|.yaml)$', schema_view.without_ui(), name='schema_swagger'), # GraphQL - path('graphql/', GraphQLView.as_view(graphiql=True, schema=schema)), + path('graphql/', GraphQLView.as_view(graphiql=True, schema=schema), name='graphql'), # Serving static media in Django to pipe it through LoginRequiredMiddleware path('media/', serve, {'document_root': settings.MEDIA_ROOT}), diff --git a/netbox/utilities/testing/api.py b/netbox/utilities/testing/api.py index b57c273fd..1a9414dc6 100644 --- a/netbox/utilities/testing/api.py +++ b/netbox/utilities/testing/api.py @@ -1,3 +1,5 @@ +import json + from django.conf import settings from django.contrib.auth.models import User from django.contrib.contenttypes.models import ContentType @@ -421,6 +423,49 @@ class APIViewTestCases: self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) self.assertEqual(self._get_queryset().count(), initial_count - 3) + class GraphQLTestCase(APITestCase): + + def test_graphql_get_object(self): + url = reverse('graphql') + object_type = self.model._meta.verbose_name.replace(' ', '_') + object_id = self._get_queryset().first().pk + query = f""" + {{ + {object_type}(id:{object_id}) {{ + id + }} + }} + """ + + # Non-authenticated requests should fail + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN) + + response = self.client.post(url, data={'query': query}, **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = json.loads(response.content) + self.assertNotIn('errors', data) + + def test_graphql_list_objects(self): + url = reverse('graphql') + object_type = self.model._meta.verbose_name_plural.replace(' ', '_') + query = f""" + {{ + {object_type} {{ + id + }} + }} + """ + + # Non-authenticated requests should fail + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN) + + response = self.client.post(url, data={'query': query}, **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = json.loads(response.content) + self.assertNotIn('errors', data) + class APIViewTestCase( GetObjectViewTestCase, ListObjectsViewTestCase,