diff --git a/netbox/dcim/graphql/types.py b/netbox/dcim/graphql/types.py index 7f238b987..81df74d4a 100644 --- a/netbox/dcim/graphql/types.py +++ b/netbox/dcim/graphql/types.py @@ -541,7 +541,7 @@ class LocationType(VLANGroupsMixin, ImageAttachmentsMixin, ContactsMixin, Organi return self.vlan_groups.all() @strawberry_django.field - def parent(self) -> Annotated["LocationType", strawberry.lazy('dcim.graphql.types')]: + def parent(self) -> Annotated["LocationType", strawberry.lazy('dcim.graphql.types')] | None: return self.parent @strawberry_django.field diff --git a/netbox/extras/graphql/mixins.py b/netbox/extras/graphql/mixins.py index 04c06c9c3..91014fbbd 100644 --- a/netbox/extras/graphql/mixins.py +++ b/netbox/extras/graphql/mixins.py @@ -24,13 +24,13 @@ if TYPE_CHECKING: class ChangelogMixin: @strawberry_django.field - def changelog(self) -> List[Annotated["ObjectChangeType", strawberry.lazy('.types')]]: + def changelog(self, info) -> List[Annotated["ObjectChangeType", strawberry.lazy('.types')]]: content_type = ContentType.objects.get_for_model(self) object_changes = ObjectChange.objects.filter( changed_object_type=content_type, changed_object_id=self.pk ) - return object_changes.restrict(info.context.user, 'view') + return object_changes.restrict(info.context.request.user, 'view') @strawberry.type @@ -53,16 +53,16 @@ class CustomFieldsMixin: class ImageAttachmentsMixin: @strawberry_django.field - def image_attachments(self) -> List[Annotated["ImageAttachmentType", strawberry.lazy('.types')]]: - return self.images.restrict(info.context.user, 'view') + def image_attachments(self, info) -> List[Annotated["ImageAttachmentType", strawberry.lazy('.types')]]: + return self.images.restrict(info.context.request.user, 'view') @strawberry.type class JournalEntriesMixin: @strawberry_django.field - def journal_entries(self) -> List[Annotated["JournalEntryType", strawberry.lazy('.types')]]: - return self.journal_entries.restrict(info.context.user, 'view') + def journal_entries(self, info) -> List[Annotated["JournalEntryType", strawberry.lazy('.types')]]: + return self.journal_entries.restrict(info.context.request.user, 'view') @strawberry.type diff --git a/netbox/ipam/graphql/mixins.py b/netbox/ipam/graphql/mixins.py index 283414df3..38c7657a5 100644 --- a/netbox/ipam/graphql/mixins.py +++ b/netbox/ipam/graphql/mixins.py @@ -10,11 +10,11 @@ class IPAddressesMixin: ip_addresses = graphene.List('ipam.graphql.types.IPAddressType') def resolve_ip_addresses(self, info): - return self.ip_addresses.restrict(info.context.user, 'view') + return self.ip_addresses.restrict(info.context.request.user, 'view') class VLANGroupsMixin: vlan_groups = graphene.List('ipam.graphql.types.VLANGroupType') def resolve_vlan_groups(self, info): - return self.vlan_groups.restrict(info.context.user, 'view') + return self.vlan_groups.restrict(info.context.request.user, 'view') diff --git a/netbox/netbox/graphql/views.py b/netbox/netbox/graphql/views.py index e1573dba6..d39f13807 100644 --- a/netbox/netbox/graphql/views.py +++ b/netbox/netbox/graphql/views.py @@ -2,19 +2,21 @@ from django.conf import settings from django.contrib.auth.views import redirect_to_login from django.http import HttpResponseNotFound, HttpResponseForbidden from django.urls import reverse -from graphene_django.views import GraphQLView as GraphQLView_ +from django.views.decorators.csrf import csrf_exempt from rest_framework.exceptions import AuthenticationFailed +from strawberry.django.views import GraphQLView from netbox.api.authentication import TokenAuthentication from netbox.config import get_config -class GraphQLView(GraphQLView_): +class NetBoxGraphQLView(GraphQLView): """ - Extends graphene_django's GraphQLView to support DRF's token-based authentication. + Extends strawberry's GraphQLView to support DRF's token-based authentication. """ graphiql_template = 'graphiql.html' + @csrf_exempt def dispatch(self, request, *args, **kwargs): config = get_config() diff --git a/netbox/netbox/urls.py b/netbox/netbox/urls.py index cf1086f99..1ce929513 100644 --- a/netbox/netbox/urls.py +++ b/netbox/netbox/urls.py @@ -1,14 +1,13 @@ from django.conf import settings from django.conf.urls import include from django.urls import path -from django.views.decorators.csrf import csrf_exempt from django.views.static import serve from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView from account.views import LoginView, LogoutView from netbox.api.views import APIRootView, StatusView from netbox.graphql.schema import schema -from netbox.graphql.views import GraphQLView +from netbox.graphql.views import NetBoxGraphQLView from netbox.plugins.urls import plugin_patterns, plugin_api_patterns from netbox.views import HomeView, StaticMediaFailureView, SearchView, htmx from strawberry.django.views import GraphQLView @@ -61,7 +60,7 @@ _patterns = [ path('api/schema/redoc/', SpectacularRedocView.as_view(url_name='schema'), name='api_redocs'), # GraphQL - path('graphql/', GraphQLView.as_view(schema=schema), name='graphql'), + path('graphql/', NetBoxGraphQLView.as_view(schema=schema), name='graphql'), # Serving static media in Django to pipe it through LoginRequiredMiddleware path('media/', serve, {'document_root': settings.MEDIA_ROOT}), diff --git a/netbox/utilities/testing/api.py b/netbox/utilities/testing/api.py index ce2817777..384d67707 100644 --- a/netbox/utilities/testing/api.py +++ b/netbox/utilities/testing/api.py @@ -1,5 +1,6 @@ import inspect import json +import strawberry_django from django.conf import settings from django.contrib.auth import get_user_model @@ -18,7 +19,7 @@ from .base import ModelTestCase from .utils import disable_warnings from ipam.graphql.types import IPAddressFamilyType - +from strawberry.type import StrawberryList __all__ = ( 'APITestCase', @@ -447,36 +448,26 @@ class APIViewTestCases: # Compile list of fields to include fields_string = '' - for field_name, field in type_class.__dataclass_fields__.items(): + for field in type_class.__strawberry_definition__.fields: # for field_name, field in type_class._meta.fields.items(): - print(f"field_name: {field_name} field: {field}") - is_string_array = False - if type(field.type) is GQLList: - if field.type.of_type is GQLString: - is_string_array = True - elif type(field.type.of_type) is GQLNonNull and field.type.of_type.of_type is GQLString: - is_string_array = True + print(f"field_name: {field.name} type: {field.type}") - if type(field) is GQLDynamic: + if type(field.type) is StrawberryList: + fields_string += f'{field.name} {{ id }}\n' + elif field.type is strawberry_django.fields.types.DjangoModelType: # Dynamic fields must specify a subselection - fields_string += f'{field_name} {{ id }}\n' + fields_string += f'{field.name} {{ id }}\n' # TODO: Improve field detection logic to avoid nested ArrayFields - elif field_name == 'extra_choices': + elif field.name == 'extra_choices': continue - elif inspect.isclass(field.type) and issubclass(field.type, GQLUnion): - # Union types dont' have an id or consistent values - continue - elif type(field.type) is GQLList and inspect.isclass(field.type.of_type) and issubclass(field.type.of_type, GQLUnion): - # Union types dont' have an id or consistent values - continue - elif type(field.type) is GQLList and not is_string_array: - # TODO: Come up with something more elegant - # Temporary hack to support automated testing of reverse generic relations - fields_string += f'{field_name} {{ id }}\n' + # elif type(field.type) is GQLList and not is_string_array: + # # TODO: Come up with something more elegant + # # Temporary hack to support automated testing of reverse generic relations + # fields_string += f'{field_name} {{ id }}\n' elif inspect.isclass(field.type) and issubclass(field.type, IPAddressFamilyType): - fields_string += f'{field_name} {{ value, label }}\n' + fields_string += f'{field.name} {{ value, label }}\n' else: - fields_string += f'{field_name}\n' + fields_string += f'{field.name}\n' query = f""" {{ @@ -486,6 +477,7 @@ class APIViewTestCases: }} """ + print(query) return query @override_settings(LOGIN_REQUIRED=True) @@ -498,6 +490,7 @@ class APIViewTestCases: # Non-authenticated requests should fail with disable_warnings('django.request'): + print(f"url: {url}") self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN) # Add object-level permission