Add GraphQL for virtualization

This commit is contained in:
jeremystretch 2021-06-25 15:31:43 -04:00
parent 881b18f6d0
commit 605b7c5b3e
8 changed files with 100 additions and 18 deletions

View File

@ -1,11 +0,0 @@
import graphene
from graphene_django.converter import convert_django_field
from ipam.fields import IPAddressField, IPNetworkField
@convert_django_field.register(IPAddressField)
@convert_django_field.register(IPNetworkField)
def convert_field_to_string(field, registry=None):
# TODO: Update to use get_django_field_description under django_graphene v3.0
return graphene.String(description=field.help_text, required=not field.null)

View File

@ -2,6 +2,9 @@ import graphene
from graphene_django.converter import convert_django_field from graphene_django.converter import convert_django_field
from taggit.managers import TaggableManager from taggit.managers import TaggableManager
from dcim.fields import MACAddressField
from ipam.fields import IPAddressField, IPNetworkField
@convert_django_field.register(TaggableManager) @convert_django_field.register(TaggableManager)
def convert_field_to_tags_list(field, registry=None): def convert_field_to_tags_list(field, registry=None):
@ -9,3 +12,11 @@ def convert_field_to_tags_list(field, registry=None):
Register conversion handler for django-taggit's TaggableManager Register conversion handler for django-taggit's TaggableManager
""" """
return graphene.List(graphene.String) return graphene.List(graphene.String)
@convert_django_field.register(IPAddressField)
@convert_django_field.register(IPNetworkField)
@convert_django_field.register(MACAddressField)
def convert_field_to_string(field, registry=None):
# TODO: Update to use get_django_field_description under django_graphene v3.0
return graphene.String(description=field.help_text, required=not field.null)

View File

@ -4,6 +4,7 @@ from circuits.graphql.schema import CircuitsQuery
from extras.graphql.schema import ExtrasQuery from extras.graphql.schema import ExtrasQuery
from ipam.graphql.schema import IPAMQuery from ipam.graphql.schema import IPAMQuery
from tenancy.graphql.schema import TenancyQuery from tenancy.graphql.schema import TenancyQuery
from virtualization.graphql.schema import VirtualizationQuery
class Query( class Query(
@ -11,6 +12,7 @@ class Query(
ExtrasQuery, ExtrasQuery,
IPAMQuery, IPAMQuery,
TenancyQuery, TenancyQuery,
VirtualizationQuery,
graphene.ObjectType graphene.ObjectType
): ):
pass pass

View File

@ -425,10 +425,16 @@ class APIViewTestCases:
class GraphQLTestCase(APITestCase): class GraphQLTestCase(APITestCase):
def _get_graphql_base_name(self, plural=False):
if plural:
return getattr(self, 'graphql_base_name_plural',
self.model._meta.verbose_name_plural.lower().replace(' ', '_'))
return getattr(self, 'graphql_base_name', self.model._meta.verbose_name.lower().replace(' ', '_'))
@override_settings(LOGIN_REQUIRED=True) @override_settings(LOGIN_REQUIRED=True)
def test_graphql_get_object(self): def test_graphql_get_object(self):
url = reverse('graphql') url = reverse('graphql')
object_type = self.model._meta.verbose_name.lower().replace(' ', '_') object_type = self._get_graphql_base_name()
object_id = self._get_queryset().first().pk object_id = self._get_queryset().first().pk
query = f""" query = f"""
{{ {{
@ -459,7 +465,7 @@ class APIViewTestCases:
@override_settings(LOGIN_REQUIRED=True) @override_settings(LOGIN_REQUIRED=True)
def test_graphql_list_objects(self): def test_graphql_list_objects(self):
url = reverse('graphql') url = reverse('graphql')
object_type = self.model._meta.verbose_name_plural.lower().replace(' ', '_') object_type = self._get_graphql_base_name(plural=True)
query = f""" query = f"""
{{ {{
{object_type} {{ {object_type} {{

View File

@ -0,0 +1,21 @@
import graphene
from netbox.graphql.fields import ObjectField, ObjectListField
from .types import *
class VirtualizationQuery(graphene.ObjectType):
cluster = ObjectField(ClusterType)
clusters = ObjectListField(ClusterType)
cluster_group = ObjectField(ClusterGroupType)
cluster_groups = ObjectListField(ClusterGroupType)
cluster_type = ObjectField(ClusterTypeType)
cluster_types = ObjectListField(ClusterTypeType)
virtual_machine = ObjectField(VirtualMachineType)
virtual_machines = ObjectListField(VirtualMachineType)
vm_interface = ObjectField(VMInterfaceType)
vm_interfaces = ObjectListField(VMInterfaceType)

View File

@ -0,0 +1,50 @@
from virtualization import filtersets, models
from netbox.graphql.types import ObjectType, TaggedObjectType
__all__ = (
'ClusterType',
'ClusterGroupType',
'ClusterTypeType',
'VirtualMachineType',
'VMInterfaceType',
)
class ClusterType(TaggedObjectType):
class Meta:
model = models.Cluster
fields = '__all__'
filterset_class = filtersets.ClusterFilterSet
class ClusterGroupType(ObjectType):
class Meta:
model = models.ClusterGroup
fields = '__all__'
filterset_class = filtersets.ClusterGroupFilterSet
class ClusterTypeType(ObjectType):
class Meta:
model = models.ClusterType
fields = '__all__'
filterset_class = filtersets.ClusterTypeFilterSet
class VirtualMachineType(TaggedObjectType):
class Meta:
model = models.VirtualMachine
fields = '__all__'
filterset_class = filtersets.VirtualMachineFilterSet
class VMInterfaceType(ObjectType):
class Meta:
model = models.VMInterface
fields = '__all__'
filterset_class = filtersets.VMInterfaceFilterSet

View File

@ -17,7 +17,7 @@ class AppTest(APITestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
class ClusterTypeTest(APIViewTestCases.APIViewTestCase): class ClusterTypeTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase):
model = ClusterType model = ClusterType
brief_fields = ['cluster_count', 'display', 'id', 'name', 'slug', 'url'] brief_fields = ['cluster_count', 'display', 'id', 'name', 'slug', 'url']
create_data = [ create_data = [
@ -49,7 +49,7 @@ class ClusterTypeTest(APIViewTestCases.APIViewTestCase):
ClusterType.objects.bulk_create(cluster_types) ClusterType.objects.bulk_create(cluster_types)
class ClusterGroupTest(APIViewTestCases.APIViewTestCase): class ClusterGroupTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase):
model = ClusterGroup model = ClusterGroup
brief_fields = ['cluster_count', 'display', 'id', 'name', 'slug', 'url'] brief_fields = ['cluster_count', 'display', 'id', 'name', 'slug', 'url']
create_data = [ create_data = [
@ -81,7 +81,7 @@ class ClusterGroupTest(APIViewTestCases.APIViewTestCase):
ClusterGroup.objects.bulk_create(cluster_Groups) ClusterGroup.objects.bulk_create(cluster_Groups)
class ClusterTest(APIViewTestCases.APIViewTestCase): class ClusterTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase):
model = Cluster model = Cluster
brief_fields = ['display', 'id', 'name', 'url', 'virtualmachine_count'] brief_fields = ['display', 'id', 'name', 'url', 'virtualmachine_count']
bulk_update_data = { bulk_update_data = {
@ -129,7 +129,7 @@ class ClusterTest(APIViewTestCases.APIViewTestCase):
] ]
class VirtualMachineTest(APIViewTestCases.APIViewTestCase): class VirtualMachineTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase):
model = VirtualMachine model = VirtualMachine
brief_fields = ['display', 'id', 'name', 'url'] brief_fields = ['display', 'id', 'name', 'url']
bulk_update_data = { bulk_update_data = {
@ -205,13 +205,16 @@ class VirtualMachineTest(APIViewTestCases.APIViewTestCase):
self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
class VMInterfaceTest(APIViewTestCases.APIViewTestCase): class VMInterfaceTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase):
model = VMInterface model = VMInterface
brief_fields = ['display', 'id', 'name', 'url', 'virtual_machine'] brief_fields = ['display', 'id', 'name', 'url', 'virtual_machine']
bulk_update_data = { bulk_update_data = {
'description': 'New description', 'description': 'New description',
} }
graphql_base_name = 'vm_interface'
graphql_base_name_plural = 'vm_interfaces'
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):