Change attr_type from list to str for MultipleChoiceFilter (#17638)

This commit is contained in:
Jeremy Stretch 2024-10-03 13:24:00 -04:00 committed by GitHub
parent 648aeaaf14
commit f11dc00fae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 42 additions and 9 deletions

View File

@ -271,7 +271,7 @@ class LocationFilterSet(TenancyFilterSet, ContactModelFilterSet, OrganizationalM
class Meta:
model = Location
fields = ('id', 'name', 'slug', 'status', 'facility', 'description')
fields = ('id', 'name', 'slug', 'facility', 'description')
def search(self, queryset, name, value):
if not value.strip():

View File

@ -1,11 +1,12 @@
from functools import partial, partialmethod, wraps
from functools import partialmethod
from typing import List
import django_filters
import strawberry
import strawberry_django
from django.core.exceptions import FieldDoesNotExist, ValidationError
from django.core.exceptions import FieldDoesNotExist
from strawberry import auto
from ipam.fields import ASNField
from netbox.graphql.scalars import BigInt
from utilities.fields import ColorField, CounterCacheField
@ -108,8 +109,7 @@ def map_strawberry_type(field):
elif issubclass(type(field), django_filters.TypedMultipleChoiceFilter):
pass
elif issubclass(type(field), django_filters.MultipleChoiceFilter):
should_create_function = True
attr_type = List[str] | None
attr_type = str | None
elif issubclass(type(field), django_filters.TypedChoiceFilter):
pass
elif issubclass(type(field), django_filters.ChoiceFilter):

View File

@ -5,8 +5,8 @@ from django.urls import reverse
from rest_framework import status
from core.models import ObjectType
from dcim.choices import LocationStatusChoices
from dcim.models import Site, Location
from ipam.models import ASN, RIR
from users.models import ObjectPermission
from utilities.testing import disable_warnings, APITestCase, TestCase
@ -53,10 +53,27 @@ class GraphQLAPITestCase(APITestCase):
sites = (
Site(name='Site 1', slug='site-1'),
Site(name='Site 2', slug='site-2'),
Site(name='Site 3', slug='site-3'),
)
Site.objects.bulk_create(sites)
Location.objects.create(site=sites[0], name='Location 1', slug='location-1'),
Location.objects.create(site=sites[1], name='Location 2', slug='location-2'),
Location.objects.create(
site=sites[0],
name='Location 1',
slug='location-1',
status=LocationStatusChoices.STATUS_PLANNED
),
Location.objects.create(
site=sites[1],
name='Location 2',
slug='location-2',
status=LocationStatusChoices.STATUS_STAGING
),
Location.objects.create(
site=sites[1],
name='Location 3',
slug='location-3',
status=LocationStatusChoices.STATUS_ACTIVE
),
# Add object-level permission
obj_perm = ObjectPermission(
@ -68,8 +85,9 @@ class GraphQLAPITestCase(APITestCase):
obj_perm.object_types.add(ObjectType.objects.get_for_model(Location))
obj_perm.object_types.add(ObjectType.objects.get_for_model(Site))
# A valid request should return the filtered list
url = reverse('graphql')
# A valid request should return the filtered list
query = '{location_list(filters: {site_id: "' + str(sites[0].pk) + '"}) {id site {id}}}'
response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
@ -78,6 +96,21 @@ class GraphQLAPITestCase(APITestCase):
self.assertEqual(len(data['data']['location_list']), 1)
self.assertIsNotNone(data['data']['location_list'][0]['site'])
# Test OR logic
query = """{
location_list( filters: {
status: \"""" + LocationStatusChoices.STATUS_PLANNED + """\",
OR: {status: \"""" + LocationStatusChoices.STATUS_STAGING + """\"}
}) {
id site {id}
}
}"""
response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = json.loads(response.content)
self.assertNotIn('errors', data)
self.assertEqual(len(data['data']['location_list']), 2)
# An invalid request should return an empty list
query = '{location_list(filters: {site_id: "99999"}) {id site {id}}}' # Invalid site ID
response = self.client.post(url, data={'query': query}, format="json", **self.header)