mirror of
https://github.com/netbox-community/netbox.git
synced 2025-12-18 19:32:24 -06:00
9856 Replace graphene with Strawberry (#15141)
* 9856 base strawberry integration * 9856 user and group * 9856 user and circuits base * 9856 extras and mixins * 9856 fk * 9856 update strawberry version * 9856 update imports * 9856 compatability fixes * 9856 compatability fixes * 9856 update strawberry types * 9856 update strawberry types * 9856 core schema * 9856 dcim schema * 9856 extras schema * 9856 ipam and tenant schema * 9856 virtualization, vpn, wireless schema * 9856 fix old decorator * 9856 cleanup * 9856 cleanup * 9856 fixes to circuits type specifiers * 9856 fixes to circuits type specifiers * 9856 update types * 9856 GFK working * 9856 GFK working * 9856 _name * 9856 misc fixes * 9856 type updates * 9856 _name to types * 9856 update types * 9856 update types * 9856 update types * 9856 update types * 9856 update types * 9856 update types * 9856 update types * 9856 update types * 9856 update types * 9856 GraphQLView * 9856 GraphQLView * 9856 fix OrganizationalObjectType * 9856 single item query for schema * 9856 circuits graphql tests working * 9856 test fixes * 9856 test fixes * 9856 test fixes * 9856 test fix vpn * 9856 test fixes * 9856 test fixes * 9856 test fixes * 9856 circuits test sans DjangoModelType * 9856 core test sans DjangoModelType * 9856 temp checkin * 9856 fix extas FK * 9856 fix tenancy FK * 9856 fix virtualization FK * 9856 fix vpn FK * 9856 fix wireless FK * 9856 fix ipam FK * 9856 fix partial dcim FK * 9856 fix dcim FK * 9856 fix virtualization FK * 9856 fix tests / remove debug code * 9856 fix test imagefield * 9856 cleanup graphene * 9856 fix plugin schema * 9856 fix requirements * 9856 fix requirements * 9856 fix docs * 9856 fix docs * 9856 temp fix tests * 9856 first filterset * 9856 first filterset * 9856 fix tests * 9856 fix tests * 9856 working auto filter generation * 9856 filter types * 9856 filter types * 9856 filter types * 9856 fix graphiql test * 9856 fix counter fields and merge feature * 9856 temp fix tests * 9856 fix tests * 9856 fix tenancy, ipam filter definitions * 9856 cleanup * 9856 cleanup * 9856 cleanup * 9856 review changes * 9856 review changes * 9856 review changes * 9856 fix base-requirements * 9856 add wrapper to graphiql * 9856 remove old graphiql debug toolbar * 9856 review changes * 9856 update strawberry * 9856 remove superfluous check --------- Co-authored-by: Jeremy Stretch <jstretch@netboxlabs.com>
This commit is contained in:
@@ -1,252 +0,0 @@
|
||||
import functools
|
||||
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.db.models import ForeignKey
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.fields.reverse_related import ManyToOneRel
|
||||
from graphene import InputObjectType
|
||||
from graphene.types.generic import GenericScalar
|
||||
from graphene.types.resolver import default_resolver
|
||||
from graphene_django import DjangoObjectType
|
||||
from graphql import GraphQLResolveInfo, GraphQLSchema
|
||||
from graphql.execution.execute import get_field_def
|
||||
from graphql.language.ast import FragmentSpreadNode, InlineFragmentNode, VariableNode
|
||||
from graphql.pyutils import Path
|
||||
from graphql.type.definition import GraphQLInterfaceType, GraphQLUnionType
|
||||
|
||||
__all__ = (
|
||||
'gql_query_optimizer',
|
||||
)
|
||||
|
||||
|
||||
def gql_query_optimizer(queryset, info, **options):
|
||||
return QueryOptimizer(info).optimize(queryset)
|
||||
|
||||
|
||||
class QueryOptimizer(object):
|
||||
def __init__(self, info, **options):
|
||||
self.root_info = info
|
||||
|
||||
def optimize(self, queryset):
|
||||
info = self.root_info
|
||||
field_def = get_field_def(info.schema, info.parent_type, info.field_nodes[0])
|
||||
|
||||
field_names = self._optimize_gql_selections(
|
||||
self._get_type(field_def),
|
||||
info.field_nodes[0],
|
||||
)
|
||||
|
||||
qs = queryset.prefetch_related(*field_names)
|
||||
return qs
|
||||
|
||||
def _get_type(self, field_def):
|
||||
a_type = field_def.type
|
||||
while hasattr(a_type, "of_type"):
|
||||
a_type = a_type.of_type
|
||||
return a_type
|
||||
|
||||
def _get_graphql_schema(self, schema):
|
||||
if isinstance(schema, GraphQLSchema):
|
||||
return schema
|
||||
else:
|
||||
return schema.graphql_schema
|
||||
|
||||
def _get_possible_types(self, graphql_type):
|
||||
if isinstance(graphql_type, (GraphQLInterfaceType, GraphQLUnionType)):
|
||||
graphql_schema = self._get_graphql_schema(self.root_info.schema)
|
||||
return graphql_schema.get_possible_types(graphql_type)
|
||||
else:
|
||||
return (graphql_type,)
|
||||
|
||||
def _get_base_model(self, graphql_types):
|
||||
models = tuple(t.graphene_type._meta.model for t in graphql_types)
|
||||
for model in models:
|
||||
if all(issubclass(m, model) for m in models):
|
||||
return model
|
||||
return None
|
||||
|
||||
def handle_inline_fragment(self, selection, schema, possible_types, field_names):
|
||||
fragment_type_name = selection.type_condition.name.value
|
||||
graphql_schema = self._get_graphql_schema(schema)
|
||||
fragment_type = graphql_schema.get_type(fragment_type_name)
|
||||
fragment_possible_types = self._get_possible_types(fragment_type)
|
||||
for fragment_possible_type in fragment_possible_types:
|
||||
fragment_model = fragment_possible_type.graphene_type._meta.model
|
||||
parent_model = self._get_base_model(possible_types)
|
||||
if not parent_model:
|
||||
continue
|
||||
path_from_parent = fragment_model._meta.get_path_from_parent(parent_model)
|
||||
select_related_name = LOOKUP_SEP.join(p.join_field.name for p in path_from_parent)
|
||||
if not select_related_name:
|
||||
continue
|
||||
sub_field_names = self._optimize_gql_selections(
|
||||
fragment_possible_type,
|
||||
selection,
|
||||
)
|
||||
field_names.append(select_related_name)
|
||||
return
|
||||
|
||||
def handle_fragment_spread(self, field_names, name, field_type):
|
||||
fragment = self.root_info.fragments[name]
|
||||
sub_field_names = self._optimize_gql_selections(
|
||||
field_type,
|
||||
fragment,
|
||||
)
|
||||
|
||||
def _optimize_gql_selections(self, field_type, field_ast):
|
||||
field_names = []
|
||||
selection_set = field_ast.selection_set
|
||||
if not selection_set:
|
||||
return field_names
|
||||
optimized_fields_by_model = {}
|
||||
schema = self.root_info.schema
|
||||
graphql_schema = self._get_graphql_schema(schema)
|
||||
graphql_type = graphql_schema.get_type(field_type.name)
|
||||
|
||||
possible_types = self._get_possible_types(graphql_type)
|
||||
for selection in selection_set.selections:
|
||||
if isinstance(selection, InlineFragmentNode):
|
||||
self.handle_inline_fragment(selection, schema, possible_types, field_names)
|
||||
else:
|
||||
name = selection.name.value
|
||||
if isinstance(selection, FragmentSpreadNode):
|
||||
self.handle_fragment_spread(field_names, name, field_type)
|
||||
else:
|
||||
for possible_type in possible_types:
|
||||
selection_field_def = possible_type.fields.get(name)
|
||||
if not selection_field_def:
|
||||
continue
|
||||
|
||||
graphene_type = possible_type.graphene_type
|
||||
model = getattr(graphene_type._meta, "model", None)
|
||||
if model and name not in optimized_fields_by_model:
|
||||
field_model = optimized_fields_by_model[name] = model
|
||||
if field_model == model:
|
||||
self._optimize_field(
|
||||
field_names,
|
||||
model,
|
||||
selection,
|
||||
selection_field_def,
|
||||
possible_type,
|
||||
)
|
||||
return field_names
|
||||
|
||||
def _get_field_info(self, field_names, model, selection, field_def):
|
||||
name = None
|
||||
model_field = None
|
||||
name = self._get_name_from_resolver(field_def.resolve)
|
||||
if not name and callable(field_def.resolve) and not isinstance(field_def.resolve, functools.partial):
|
||||
name = selection.name.value
|
||||
if name:
|
||||
model_field = self._get_model_field_from_name(model, name)
|
||||
|
||||
return (name, model_field)
|
||||
|
||||
def _optimize_field(self, field_names, model, selection, field_def, parent_type):
|
||||
name, model_field = self._get_field_info(field_names, model, selection, field_def)
|
||||
if model_field:
|
||||
self._optimize_field_by_name(field_names, model, selection, field_def, name, model_field)
|
||||
|
||||
return
|
||||
|
||||
def _optimize_field_by_name(self, field_names, model, selection, field_def, name, model_field):
|
||||
if model_field.many_to_one or model_field.one_to_one:
|
||||
sub_field_names = self._optimize_gql_selections(
|
||||
self._get_type(field_def),
|
||||
selection,
|
||||
)
|
||||
if name not in field_names:
|
||||
field_names.append(name)
|
||||
|
||||
for field in sub_field_names:
|
||||
prefetch_key = f"{name}__{field}"
|
||||
if prefetch_key not in field_names:
|
||||
field_names.append(prefetch_key)
|
||||
|
||||
if model_field.one_to_many or model_field.many_to_many:
|
||||
sub_field_names = self._optimize_gql_selections(
|
||||
self._get_type(field_def),
|
||||
selection,
|
||||
)
|
||||
|
||||
if isinstance(model_field, ManyToOneRel):
|
||||
sub_field_names.append(model_field.field.name)
|
||||
|
||||
field_names.append(name)
|
||||
for field in sub_field_names:
|
||||
prefetch_key = f"{name}__{field}"
|
||||
if prefetch_key not in field_names:
|
||||
field_names.append(prefetch_key)
|
||||
|
||||
return
|
||||
|
||||
def _get_optimization_hints(self, resolver):
|
||||
return getattr(resolver, "optimization_hints", None)
|
||||
|
||||
def _get_value(self, info, value):
|
||||
if isinstance(value, VariableNode):
|
||||
var_name = value.name.value
|
||||
value = info.variable_values.get(var_name)
|
||||
return value
|
||||
elif isinstance(value, InputObjectType):
|
||||
return value.__dict__
|
||||
else:
|
||||
return GenericScalar.parse_literal(value)
|
||||
|
||||
def _get_name_from_resolver(self, resolver):
|
||||
optimization_hints = self._get_optimization_hints(resolver)
|
||||
if optimization_hints:
|
||||
name_fn = optimization_hints.model_field
|
||||
if name_fn:
|
||||
return name_fn()
|
||||
if self._is_resolver_for_id_field(resolver):
|
||||
return "id"
|
||||
elif isinstance(resolver, functools.partial):
|
||||
resolver_fn = resolver
|
||||
if resolver_fn.func != default_resolver:
|
||||
# Some resolvers have the partial function as the second
|
||||
# argument.
|
||||
for arg in resolver_fn.args:
|
||||
if isinstance(arg, (str, functools.partial)):
|
||||
break
|
||||
else:
|
||||
# No suitable instances found, default to first arg
|
||||
arg = resolver_fn.args[0]
|
||||
resolver_fn = arg
|
||||
if isinstance(resolver_fn, functools.partial) and resolver_fn.func == default_resolver:
|
||||
return resolver_fn.args[0]
|
||||
if self._is_resolver_for_id_field(resolver_fn):
|
||||
return "id"
|
||||
return resolver_fn
|
||||
|
||||
def _is_resolver_for_id_field(self, resolver):
|
||||
resolve_id = DjangoObjectType.resolve_id
|
||||
return resolver == resolve_id
|
||||
|
||||
def _get_model_field_from_name(self, model, name):
|
||||
try:
|
||||
return model._meta.get_field(name)
|
||||
except FieldDoesNotExist:
|
||||
descriptor = model.__dict__.get(name)
|
||||
if not descriptor:
|
||||
return None
|
||||
return getattr(descriptor, "rel", None) or getattr(descriptor, "related", None) # Django < 1.9
|
||||
|
||||
def _is_foreign_key_id(self, model_field, name):
|
||||
return isinstance(model_field, ForeignKey) and model_field.name != name and model_field.get_attname() == name
|
||||
|
||||
def _create_resolve_info(self, field_name, field_asts, return_type, parent_type):
|
||||
return GraphQLResolveInfo(
|
||||
field_name,
|
||||
field_asts,
|
||||
return_type,
|
||||
parent_type,
|
||||
Path(None, 0, None),
|
||||
schema=self.root_info.schema,
|
||||
fragments=self.root_info.fragments,
|
||||
root_value=self.root_info.root_value,
|
||||
operation=self.root_info.operation,
|
||||
variable_values=self.root_info.variable_values,
|
||||
context=self.root_info.context,
|
||||
is_awaitable=self.root_info.is_awaitable,
|
||||
)
|
||||
@@ -1,12 +1,12 @@
|
||||
import inspect
|
||||
import json
|
||||
import strawberry_django
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.urls import reverse
|
||||
from django.test import override_settings
|
||||
from graphene.types import Dynamic as GQLDynamic, List as GQLList, Union as GQLUnion, String as GQLString, NonNull as GQLNonNull
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
@@ -19,7 +19,10 @@ from .base import ModelTestCase
|
||||
from .utils import disable_warnings
|
||||
|
||||
from ipam.graphql.types import IPAddressFamilyType
|
||||
|
||||
from strawberry.field import StrawberryField
|
||||
from strawberry.lazy_type import LazyType
|
||||
from strawberry.type import StrawberryList, StrawberryOptional
|
||||
from strawberry.union import StrawberryUnion
|
||||
|
||||
__all__ = (
|
||||
'APITestCase',
|
||||
@@ -447,34 +450,34 @@ class APIViewTestCases:
|
||||
|
||||
# Compile list of fields to include
|
||||
fields_string = ''
|
||||
for field_name, field in type_class._meta.fields.items():
|
||||
is_string_array = False
|
||||
if type(field.type) is GQLList:
|
||||
if field.type.of_type is GQLString:
|
||||
is_string_array = True
|
||||
elif type(field.type.of_type) is GQLNonNull and field.type.of_type.of_type is GQLString:
|
||||
is_string_array = True
|
||||
|
||||
if type(field) is GQLDynamic:
|
||||
# Dynamic fields must specify a subselection
|
||||
fields_string += f'{field_name} {{ id }}\n'
|
||||
# TODO: Improve field detection logic to avoid nested ArrayFields
|
||||
elif field_name == 'extra_choices':
|
||||
file_fields = (strawberry_django.fields.types.DjangoFileType, strawberry_django.fields.types.DjangoImageType)
|
||||
for field in type_class.__strawberry_definition__.fields:
|
||||
if (
|
||||
field.type in file_fields or (
|
||||
type(field.type) is StrawberryOptional and field.type.of_type in file_fields
|
||||
)
|
||||
):
|
||||
# image / file fields nullable or not...
|
||||
fields_string += f'{field.name} {{ name }}\n'
|
||||
elif type(field.type) is StrawberryList and type(field.type.of_type) is LazyType:
|
||||
# List of related objects (queryset)
|
||||
fields_string += f'{field.name} {{ id }}\n'
|
||||
elif type(field.type) is StrawberryList and type(field.type.of_type) is StrawberryUnion:
|
||||
# this would require a fragment query
|
||||
continue
|
||||
elif inspect.isclass(field.type) and issubclass(field.type, GQLUnion):
|
||||
# Union types dont' have an id or consistent values
|
||||
elif type(field.type) is StrawberryUnion:
|
||||
# this would require a fragment query
|
||||
continue
|
||||
elif type(field.type) is GQLList and inspect.isclass(field.type.of_type) and issubclass(field.type.of_type, GQLUnion):
|
||||
# Union types dont' have an id or consistent values
|
||||
continue
|
||||
elif type(field.type) is GQLList and not is_string_array:
|
||||
# TODO: Come up with something more elegant
|
||||
# Temporary hack to support automated testing of reverse generic relations
|
||||
fields_string += f'{field_name} {{ id }}\n'
|
||||
elif type(field.type) is StrawberryOptional and type(field.type.of_type) is LazyType:
|
||||
fields_string += f'{field.name} {{ id }}\n'
|
||||
elif hasattr(field, 'is_relation') and field.is_relation:
|
||||
# Note: StrawberryField types do not have is_relation
|
||||
fields_string += f'{field.name} {{ id }}\n'
|
||||
elif inspect.isclass(field.type) and issubclass(field.type, IPAddressFamilyType):
|
||||
fields_string += f'{field_name} {{ value, label }}\n'
|
||||
fields_string += f'{field.name} {{ value, label }}\n'
|
||||
else:
|
||||
fields_string += f'{field_name}\n'
|
||||
fields_string += f'{field.name}\n'
|
||||
|
||||
query = f"""
|
||||
{{
|
||||
@@ -496,7 +499,10 @@ class APIViewTestCases:
|
||||
|
||||
# Non-authenticated requests should fail
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
|
||||
header = {
|
||||
'HTTP_ACCEPT': 'application/json',
|
||||
}
|
||||
self.assertHttpStatus(self.client.post(url, data={'query': query}, format="json", **header), status.HTTP_403_FORBIDDEN)
|
||||
|
||||
# Add object-level permission
|
||||
obj_perm = ObjectPermission(
|
||||
@@ -507,7 +513,7 @@ class APIViewTestCases:
|
||||
obj_perm.users.add(self.user)
|
||||
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
|
||||
|
||||
response = self.client.post(url, data={'query': query}, **self.header)
|
||||
response = self.client.post(url, data={'query': query}, format="json", **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
@@ -521,7 +527,10 @@ class APIViewTestCases:
|
||||
|
||||
# Non-authenticated requests should fail
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
|
||||
header = {
|
||||
'HTTP_ACCEPT': 'application/json',
|
||||
}
|
||||
self.assertHttpStatus(self.client.post(url, data={'query': query}, format="json", **header), status.HTTP_403_FORBIDDEN)
|
||||
|
||||
# Add object-level permission
|
||||
obj_perm = ObjectPermission(
|
||||
@@ -532,7 +541,7 @@ class APIViewTestCases:
|
||||
obj_perm.users.add(self.user)
|
||||
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
|
||||
|
||||
response = self.client.post(url, data={'query': query}, **self.header)
|
||||
response = self.client.post(url, data={'query': query}, format="json", **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
|
||||
Reference in New Issue
Block a user