From 47ac506d5cb9db3605e814f4534e27ace58ab6ee Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Mon, 27 Oct 2025 16:32:09 -0400 Subject: [PATCH] Add a test to validate versioned GraphQL types --- netbox/netbox/tests/test_graphql.py | 50 +++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/netbox/netbox/tests/test_graphql.py b/netbox/netbox/tests/test_graphql.py index ca231526f..f5d69b03e 100644 --- a/netbox/netbox/tests/test_graphql.py +++ b/netbox/netbox/tests/test_graphql.py @@ -1,12 +1,15 @@ import json +import strawberry from django.test import override_settings from django.urls import reverse from rest_framework import status +from strawberry.types.lazy_type import LazyType from core.models import ObjectType from dcim.choices import LocationStatusChoices from dcim.models import Site, Location +from netbox.graphql.schema import QueryV1, QueryV2 from users.models import ObjectPermission from utilities.testing import disable_warnings, APITestCase, TestCase @@ -45,6 +48,53 @@ class GraphQLTestCase(TestCase): class GraphQLAPITestCase(APITestCase): + def test_versioned_types(self): + """ + Check that the GraphQL types defined for each version of the schema (V1 and V2) are correct. + """ + schemas = ( + (1, QueryV1), + (2, QueryV2), + ) + + def _get_class_name(field): + try: + if type(field.type) is strawberry.types.base.StrawberryList: + # Skip scalars + if field.type.of_type in (str, int): + return + if type(field.type.of_type) is LazyType: + return field.type.of_type.type_name + return field.type.of_type.__name__ + if hasattr(field.type, 'name'): + return field.type.__name__ + except AttributeError: + # Unknown field type + return + + def _check_version(class_name, version): + if version == 1: + self.assertTrue(class_name.endswith('V1'), f"{class_name} (v1) is not a V1 type") + elif version == 2: + self.assertFalse(class_name.endswith('V1'), f"{class_name} (v2) is a V1 type") + + for version, query in schemas: + schema = strawberry.Schema(query=query) + query_type = schema.get_type_by_name(query.__name__) + + # Iterate through root fields + for field in query_type.fields: + # Check for V1 suffix on class names + if type_class := _get_class_name(field): + _check_version(type_class, version) + + # Iterate through nested fields + subquery_type = schema.get_type_by_name(type_class) + for subfield in subquery_type.fields: + # Check for V1 suffix on class names + if type_class := _get_class_name(subfield): + _check_version(type_class, version) + @override_settings(LOGIN_REQUIRED=True) def test_graphql_filter_objects(self): """