9856 GraphQLView

This commit is contained in:
Arthur 2024-03-05 08:30:34 -08:00
parent 13bf2c1940
commit 14f04453bb
6 changed files with 33 additions and 39 deletions

View File

@ -541,7 +541,7 @@ class LocationType(VLANGroupsMixin, ImageAttachmentsMixin, ContactsMixin, Organi
return self.vlan_groups.all() return self.vlan_groups.all()
@strawberry_django.field @strawberry_django.field
def parent(self) -> Annotated["LocationType", strawberry.lazy('dcim.graphql.types')]: def parent(self) -> Annotated["LocationType", strawberry.lazy('dcim.graphql.types')] | None:
return self.parent return self.parent
@strawberry_django.field @strawberry_django.field

View File

@ -24,13 +24,13 @@ if TYPE_CHECKING:
class ChangelogMixin: class ChangelogMixin:
@strawberry_django.field @strawberry_django.field
def changelog(self) -> List[Annotated["ObjectChangeType", strawberry.lazy('.types')]]: def changelog(self, info) -> List[Annotated["ObjectChangeType", strawberry.lazy('.types')]]:
content_type = ContentType.objects.get_for_model(self) content_type = ContentType.objects.get_for_model(self)
object_changes = ObjectChange.objects.filter( object_changes = ObjectChange.objects.filter(
changed_object_type=content_type, changed_object_type=content_type,
changed_object_id=self.pk changed_object_id=self.pk
) )
return object_changes.restrict(info.context.user, 'view') return object_changes.restrict(info.context.request.user, 'view')
@strawberry.type @strawberry.type
@ -53,16 +53,16 @@ class CustomFieldsMixin:
class ImageAttachmentsMixin: class ImageAttachmentsMixin:
@strawberry_django.field @strawberry_django.field
def image_attachments(self) -> List[Annotated["ImageAttachmentType", strawberry.lazy('.types')]]: def image_attachments(self, info) -> List[Annotated["ImageAttachmentType", strawberry.lazy('.types')]]:
return self.images.restrict(info.context.user, 'view') return self.images.restrict(info.context.request.user, 'view')
@strawberry.type @strawberry.type
class JournalEntriesMixin: class JournalEntriesMixin:
@strawberry_django.field @strawberry_django.field
def journal_entries(self) -> List[Annotated["JournalEntryType", strawberry.lazy('.types')]]: def journal_entries(self, info) -> List[Annotated["JournalEntryType", strawberry.lazy('.types')]]:
return self.journal_entries.restrict(info.context.user, 'view') return self.journal_entries.restrict(info.context.request.user, 'view')
@strawberry.type @strawberry.type

View File

@ -10,11 +10,11 @@ class IPAddressesMixin:
ip_addresses = graphene.List('ipam.graphql.types.IPAddressType') ip_addresses = graphene.List('ipam.graphql.types.IPAddressType')
def resolve_ip_addresses(self, info): def resolve_ip_addresses(self, info):
return self.ip_addresses.restrict(info.context.user, 'view') return self.ip_addresses.restrict(info.context.request.user, 'view')
class VLANGroupsMixin: class VLANGroupsMixin:
vlan_groups = graphene.List('ipam.graphql.types.VLANGroupType') vlan_groups = graphene.List('ipam.graphql.types.VLANGroupType')
def resolve_vlan_groups(self, info): def resolve_vlan_groups(self, info):
return self.vlan_groups.restrict(info.context.user, 'view') return self.vlan_groups.restrict(info.context.request.user, 'view')

View File

@ -2,19 +2,21 @@ from django.conf import settings
from django.contrib.auth.views import redirect_to_login from django.contrib.auth.views import redirect_to_login
from django.http import HttpResponseNotFound, HttpResponseForbidden from django.http import HttpResponseNotFound, HttpResponseForbidden
from django.urls import reverse from django.urls import reverse
from graphene_django.views import GraphQLView as GraphQLView_ from django.views.decorators.csrf import csrf_exempt
from rest_framework.exceptions import AuthenticationFailed from rest_framework.exceptions import AuthenticationFailed
from strawberry.django.views import GraphQLView
from netbox.api.authentication import TokenAuthentication from netbox.api.authentication import TokenAuthentication
from netbox.config import get_config from netbox.config import get_config
class GraphQLView(GraphQLView_): class NetBoxGraphQLView(GraphQLView):
""" """
Extends graphene_django's GraphQLView to support DRF's token-based authentication. Extends strawberry's GraphQLView to support DRF's token-based authentication.
""" """
graphiql_template = 'graphiql.html' graphiql_template = 'graphiql.html'
@csrf_exempt
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
config = get_config() config = get_config()

View File

@ -1,14 +1,13 @@
from django.conf import settings from django.conf import settings
from django.conf.urls import include from django.conf.urls import include
from django.urls import path from django.urls import path
from django.views.decorators.csrf import csrf_exempt
from django.views.static import serve from django.views.static import serve
from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView
from account.views import LoginView, LogoutView from account.views import LoginView, LogoutView
from netbox.api.views import APIRootView, StatusView from netbox.api.views import APIRootView, StatusView
from netbox.graphql.schema import schema from netbox.graphql.schema import schema
from netbox.graphql.views import GraphQLView from netbox.graphql.views import NetBoxGraphQLView
from netbox.plugins.urls import plugin_patterns, plugin_api_patterns from netbox.plugins.urls import plugin_patterns, plugin_api_patterns
from netbox.views import HomeView, StaticMediaFailureView, SearchView, htmx from netbox.views import HomeView, StaticMediaFailureView, SearchView, htmx
from strawberry.django.views import GraphQLView from strawberry.django.views import GraphQLView
@ -61,7 +60,7 @@ _patterns = [
path('api/schema/redoc/', SpectacularRedocView.as_view(url_name='schema'), name='api_redocs'), path('api/schema/redoc/', SpectacularRedocView.as_view(url_name='schema'), name='api_redocs'),
# GraphQL # GraphQL
path('graphql/', GraphQLView.as_view(schema=schema), name='graphql'), path('graphql/', NetBoxGraphQLView.as_view(schema=schema), name='graphql'),
# Serving static media in Django to pipe it through LoginRequiredMiddleware # Serving static media in Django to pipe it through LoginRequiredMiddleware
path('media/<path:path>', serve, {'document_root': settings.MEDIA_ROOT}), path('media/<path:path>', serve, {'document_root': settings.MEDIA_ROOT}),

View File

@ -1,5 +1,6 @@
import inspect import inspect
import json import json
import strawberry_django
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
@ -18,7 +19,7 @@ from .base import ModelTestCase
from .utils import disable_warnings from .utils import disable_warnings
from ipam.graphql.types import IPAddressFamilyType from ipam.graphql.types import IPAddressFamilyType
from strawberry.type import StrawberryList
__all__ = ( __all__ = (
'APITestCase', 'APITestCase',
@ -447,36 +448,26 @@ class APIViewTestCases:
# Compile list of fields to include # Compile list of fields to include
fields_string = '' fields_string = ''
for field_name, field in type_class.__dataclass_fields__.items(): for field in type_class.__strawberry_definition__.fields:
# for field_name, field in type_class._meta.fields.items(): # for field_name, field in type_class._meta.fields.items():
print(f"field_name: {field_name} field: {field}") print(f"field_name: {field.name} type: {field.type}")
is_string_array = False
if type(field.type) is GQLList:
if field.type.of_type is GQLString:
is_string_array = True
elif type(field.type.of_type) is GQLNonNull and field.type.of_type.of_type is GQLString:
is_string_array = True
if type(field) is GQLDynamic: if type(field.type) is StrawberryList:
fields_string += f'{field.name} {{ id }}\n'
elif field.type is strawberry_django.fields.types.DjangoModelType:
# Dynamic fields must specify a subselection # Dynamic fields must specify a subselection
fields_string += f'{field_name} {{ id }}\n' fields_string += f'{field.name} {{ id }}\n'
# TODO: Improve field detection logic to avoid nested ArrayFields # TODO: Improve field detection logic to avoid nested ArrayFields
elif field_name == 'extra_choices': elif field.name == 'extra_choices':
continue continue
elif inspect.isclass(field.type) and issubclass(field.type, GQLUnion): # elif type(field.type) is GQLList and not is_string_array:
# Union types dont' have an id or consistent values # # TODO: Come up with something more elegant
continue # # Temporary hack to support automated testing of reverse generic relations
elif type(field.type) is GQLList and inspect.isclass(field.type.of_type) and issubclass(field.type.of_type, GQLUnion): # fields_string += f'{field_name} {{ id }}\n'
# Union types dont' have an id or consistent values
continue
elif type(field.type) is GQLList and not is_string_array:
# TODO: Come up with something more elegant
# Temporary hack to support automated testing of reverse generic relations
fields_string += f'{field_name} {{ id }}\n'
elif inspect.isclass(field.type) and issubclass(field.type, IPAddressFamilyType): elif inspect.isclass(field.type) and issubclass(field.type, IPAddressFamilyType):
fields_string += f'{field_name} {{ value, label }}\n' fields_string += f'{field.name} {{ value, label }}\n'
else: else:
fields_string += f'{field_name}\n' fields_string += f'{field.name}\n'
query = f""" query = f"""
{{ {{
@ -486,6 +477,7 @@ class APIViewTestCases:
}} }}
""" """
print(query)
return query return query
@override_settings(LOGIN_REQUIRED=True) @override_settings(LOGIN_REQUIRED=True)
@ -498,6 +490,7 @@ class APIViewTestCases:
# Non-authenticated requests should fail # Non-authenticated requests should fail
with disable_warnings('django.request'): with disable_warnings('django.request'):
print(f"url: {url}")
self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN) self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
# Add object-level permission # Add object-level permission