Change attr_type from list to str for MultipleChoiceFilter (#17638)

This commit is contained in:
Jeremy Stretch 2024-10-03 13:24:00 -04:00 committed by GitHub
parent 648aeaaf14
commit f11dc00fae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 42 additions and 9 deletions

View File

@ -271,7 +271,7 @@ class LocationFilterSet(TenancyFilterSet, ContactModelFilterSet, OrganizationalM
class Meta: class Meta:
model = Location model = Location
fields = ('id', 'name', 'slug', 'status', 'facility', 'description') fields = ('id', 'name', 'slug', 'facility', 'description')
def search(self, queryset, name, value): def search(self, queryset, name, value):
if not value.strip(): if not value.strip():

View File

@ -1,11 +1,12 @@
from functools import partial, partialmethod, wraps from functools import partialmethod
from typing import List from typing import List
import django_filters import django_filters
import strawberry import strawberry
import strawberry_django import strawberry_django
from django.core.exceptions import FieldDoesNotExist, ValidationError from django.core.exceptions import FieldDoesNotExist
from strawberry import auto from strawberry import auto
from ipam.fields import ASNField from ipam.fields import ASNField
from netbox.graphql.scalars import BigInt from netbox.graphql.scalars import BigInt
from utilities.fields import ColorField, CounterCacheField from utilities.fields import ColorField, CounterCacheField
@ -108,8 +109,7 @@ def map_strawberry_type(field):
elif issubclass(type(field), django_filters.TypedMultipleChoiceFilter): elif issubclass(type(field), django_filters.TypedMultipleChoiceFilter):
pass pass
elif issubclass(type(field), django_filters.MultipleChoiceFilter): elif issubclass(type(field), django_filters.MultipleChoiceFilter):
should_create_function = True attr_type = str | None
attr_type = List[str] | None
elif issubclass(type(field), django_filters.TypedChoiceFilter): elif issubclass(type(field), django_filters.TypedChoiceFilter):
pass pass
elif issubclass(type(field), django_filters.ChoiceFilter): elif issubclass(type(field), django_filters.ChoiceFilter):

View File

@ -5,8 +5,8 @@ from django.urls import reverse
from rest_framework import status from rest_framework import status
from core.models import ObjectType from core.models import ObjectType
from dcim.choices import LocationStatusChoices
from dcim.models import Site, Location from dcim.models import Site, Location
from ipam.models import ASN, RIR
from users.models import ObjectPermission from users.models import ObjectPermission
from utilities.testing import disable_warnings, APITestCase, TestCase from utilities.testing import disable_warnings, APITestCase, TestCase
@ -53,10 +53,27 @@ class GraphQLAPITestCase(APITestCase):
sites = ( sites = (
Site(name='Site 1', slug='site-1'), Site(name='Site 1', slug='site-1'),
Site(name='Site 2', slug='site-2'), Site(name='Site 2', slug='site-2'),
Site(name='Site 3', slug='site-3'),
) )
Site.objects.bulk_create(sites) Site.objects.bulk_create(sites)
Location.objects.create(site=sites[0], name='Location 1', slug='location-1'), Location.objects.create(
Location.objects.create(site=sites[1], name='Location 2', slug='location-2'), site=sites[0],
name='Location 1',
slug='location-1',
status=LocationStatusChoices.STATUS_PLANNED
),
Location.objects.create(
site=sites[1],
name='Location 2',
slug='location-2',
status=LocationStatusChoices.STATUS_STAGING
),
Location.objects.create(
site=sites[1],
name='Location 3',
slug='location-3',
status=LocationStatusChoices.STATUS_ACTIVE
),
# Add object-level permission # Add object-level permission
obj_perm = ObjectPermission( obj_perm = ObjectPermission(
@ -68,8 +85,9 @@ class GraphQLAPITestCase(APITestCase):
obj_perm.object_types.add(ObjectType.objects.get_for_model(Location)) obj_perm.object_types.add(ObjectType.objects.get_for_model(Location))
obj_perm.object_types.add(ObjectType.objects.get_for_model(Site)) obj_perm.object_types.add(ObjectType.objects.get_for_model(Site))
# A valid request should return the filtered list
url = reverse('graphql') url = reverse('graphql')
# A valid request should return the filtered list
query = '{location_list(filters: {site_id: "' + str(sites[0].pk) + '"}) {id site {id}}}' query = '{location_list(filters: {site_id: "' + str(sites[0].pk) + '"}) {id site {id}}}'
response = self.client.post(url, data={'query': query}, format="json", **self.header) response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK) self.assertHttpStatus(response, status.HTTP_200_OK)
@ -78,6 +96,21 @@ class GraphQLAPITestCase(APITestCase):
self.assertEqual(len(data['data']['location_list']), 1) self.assertEqual(len(data['data']['location_list']), 1)
self.assertIsNotNone(data['data']['location_list'][0]['site']) self.assertIsNotNone(data['data']['location_list'][0]['site'])
# Test OR logic
query = """{
location_list( filters: {
status: \"""" + LocationStatusChoices.STATUS_PLANNED + """\",
OR: {status: \"""" + LocationStatusChoices.STATUS_STAGING + """\"}
}) {
id site {id}
}
}"""
response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = json.loads(response.content)
self.assertNotIn('errors', data)
self.assertEqual(len(data['data']['location_list']), 2)
# An invalid request should return an empty list # An invalid request should return an empty list
query = '{location_list(filters: {site_id: "99999"}) {id site {id}}}' # Invalid site ID query = '{location_list(filters: {site_id: "99999"}) {id site {id}}}' # Invalid site ID
response = self.client.post(url, data={'query': query}, format="json", **self.header) response = self.client.post(url, data={'query': query}, format="json", **self.header)