mirror of
https://github.com/netbox-community/netbox.git
synced 2026-02-05 06:46:25 -06:00
This commit is contained in:
@@ -0,0 +1,50 @@
|
||||
import strawberry
|
||||
from strawberry.types.unset import UNSET
|
||||
from strawberry_django.pagination import _QS, apply
|
||||
|
||||
__all__ = (
|
||||
'OffsetPaginationInfo',
|
||||
'OffsetPaginationInput',
|
||||
'apply_pagination',
|
||||
)
|
||||
|
||||
|
||||
@strawberry.type
|
||||
class OffsetPaginationInfo:
|
||||
offset: int = 0
|
||||
limit: int | None = UNSET
|
||||
start: int | None = UNSET
|
||||
|
||||
|
||||
@strawberry.input
|
||||
class OffsetPaginationInput(OffsetPaginationInfo):
|
||||
"""
|
||||
Customized implementation of OffsetPaginationInput to support cursor-based pagination.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def apply_pagination(
|
||||
self,
|
||||
queryset: _QS,
|
||||
pagination: OffsetPaginationInput | None = None,
|
||||
*,
|
||||
related_field_id: str | None = None,
|
||||
) -> _QS:
|
||||
"""
|
||||
Replacement for the `apply_pagination()` method on StrawberryDjangoField to support cursor-based pagination.
|
||||
"""
|
||||
if pagination is not None and pagination.start not in (None, UNSET):
|
||||
if pagination.offset:
|
||||
raise ValueError('Cannot specify both `start` and `offset` in pagination.')
|
||||
if pagination.start < 0:
|
||||
raise ValueError('`start` must be greater than or equal to zero.')
|
||||
|
||||
# Filter the queryset to include only records with a primary key greater than or equal to the start value,
|
||||
# and force ordering by primary key to ensure consistent pagination across all records.
|
||||
queryset = queryset.filter(pk__gte=pagination.start).order_by('pk')
|
||||
|
||||
# Ignore `offset` when `start` is set
|
||||
pagination.offset = 0
|
||||
|
||||
return apply(pagination, queryset, related_field_id=related_field_id)
|
||||
@@ -12,10 +12,13 @@ from django.core.validators import URLValidator
|
||||
from django.utils.module_loading import import_string
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.utils import field_mapping
|
||||
from strawberry_django import pagination
|
||||
from strawberry_django.fields.field import StrawberryDjangoField
|
||||
|
||||
from core.exceptions import IncompatiblePluginError
|
||||
from netbox.config import PARAMS as CONFIG_PARAMS
|
||||
from netbox.constants import RQ_QUEUE_DEFAULT, RQ_QUEUE_HIGH, RQ_QUEUE_LOW
|
||||
from netbox.graphql.pagination import OffsetPaginationInput, apply_pagination
|
||||
from netbox.plugins import PluginConfig
|
||||
from netbox.registry import registry
|
||||
import storages.utils # type: ignore
|
||||
@@ -33,6 +36,12 @@ from .monkey import get_unique_validators
|
||||
# Override DRF's get_unique_validators() function with our own (see bug #19302)
|
||||
field_mapping.get_unique_validators = get_unique_validators
|
||||
|
||||
# Override strawberry-django's OffsetPaginationInput class to add the `start` parameter
|
||||
pagination.OffsetPaginationInput = OffsetPaginationInput
|
||||
|
||||
# Patch StrawberryDjangoField to use our custom `apply_pagination()` method with support for cursor-based pagination
|
||||
StrawberryDjangoField.apply_pagination = apply_pagination
|
||||
|
||||
|
||||
#
|
||||
# Environment setup
|
||||
|
||||
@@ -4,10 +4,8 @@ from django.test import override_settings
|
||||
from django.urls import reverse
|
||||
from rest_framework import status
|
||||
|
||||
from core.models import ObjectType
|
||||
from dcim.choices import LocationStatusChoices
|
||||
from dcim.models import Site, Location
|
||||
from users.models import ObjectPermission
|
||||
from utilities.testing import disable_warnings, APITestCase, TestCase
|
||||
|
||||
|
||||
@@ -45,17 +43,28 @@ class GraphQLTestCase(TestCase):
|
||||
|
||||
class GraphQLAPITestCase(APITestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
sites = (
|
||||
Site(name='Site 1', slug='site-1'),
|
||||
Site(name='Site 2', slug='site-2'),
|
||||
Site(name='Site 3', slug='site-3'),
|
||||
Site(name='Site 4', slug='site-4'),
|
||||
Site(name='Site 5', slug='site-5'),
|
||||
Site(name='Site 6', slug='site-6'),
|
||||
Site(name='Site 7', slug='site-7'),
|
||||
)
|
||||
Site.objects.bulk_create(sites)
|
||||
|
||||
@override_settings(LOGIN_REQUIRED=True)
|
||||
def test_graphql_filter_objects(self):
|
||||
"""
|
||||
Test the operation of filters for GraphQL API requests.
|
||||
"""
|
||||
sites = (
|
||||
Site(name='Site 1', slug='site-1'),
|
||||
Site(name='Site 2', slug='site-2'),
|
||||
Site(name='Site 3', slug='site-3'),
|
||||
)
|
||||
Site.objects.bulk_create(sites)
|
||||
self.add_permissions('dcim.view_site', 'dcim.view_location')
|
||||
url = reverse('graphql')
|
||||
|
||||
sites = Site.objects.all()[:3]
|
||||
Location.objects.create(
|
||||
site=sites[0],
|
||||
name='Location 1',
|
||||
@@ -75,18 +84,6 @@ class GraphQLAPITestCase(APITestCase):
|
||||
status=LocationStatusChoices.STATUS_ACTIVE
|
||||
),
|
||||
|
||||
# Add object-level permission
|
||||
obj_perm = ObjectPermission(
|
||||
name='Test permission',
|
||||
actions=['view']
|
||||
)
|
||||
obj_perm.save()
|
||||
obj_perm.users.add(self.user)
|
||||
obj_perm.object_types.add(ObjectType.objects.get_for_model(Location))
|
||||
obj_perm.object_types.add(ObjectType.objects.get_for_model(Site))
|
||||
|
||||
url = reverse('graphql')
|
||||
|
||||
# A valid request should return the filtered list
|
||||
query = '{location_list(filters: {site_id: "' + str(sites[0].pk) + '"}) {id site {id}}}'
|
||||
response = self.client.post(url, data={'query': query}, format="json", **self.header)
|
||||
@@ -133,10 +130,136 @@ class GraphQLAPITestCase(APITestCase):
|
||||
self.assertEqual(len(data['data']['location_list']), 0)
|
||||
|
||||
# Removing the permissions from location should result in an empty locations list
|
||||
obj_perm.object_types.remove(ObjectType.objects.get_for_model(Location))
|
||||
self.remove_permissions('dcim.view_location')
|
||||
query = '{site(id: ' + str(sites[0].pk) + ') {id locations {id}}}'
|
||||
response = self.client.post(url, data={'query': query}, format="json", **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
self.assertEqual(len(data['data']['site']['locations']), 0)
|
||||
|
||||
def test_offset_pagination(self):
|
||||
self.add_permissions('dcim.view_site')
|
||||
url = reverse('graphql')
|
||||
|
||||
# Test `limit` only
|
||||
query = """
|
||||
{
|
||||
site_list(pagination: {limit: 3}) {
|
||||
id name
|
||||
}
|
||||
}
|
||||
"""
|
||||
response = self.client.post(url, data={'query': query}, format='json', **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
self.assertEqual(len(data['data']['site_list']), 3)
|
||||
self.assertEqual(data['data']['site_list'][0]['name'], 'Site 1')
|
||||
self.assertEqual(data['data']['site_list'][1]['name'], 'Site 2')
|
||||
self.assertEqual(data['data']['site_list'][2]['name'], 'Site 3')
|
||||
|
||||
# Test `offset` only
|
||||
query = """
|
||||
{
|
||||
site_list(pagination: {offset: 3}) {
|
||||
id name
|
||||
}
|
||||
}
|
||||
"""
|
||||
response = self.client.post(url, data={'query': query}, format='json', **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
self.assertEqual(len(data['data']['site_list']), 4)
|
||||
self.assertEqual(data['data']['site_list'][0]['name'], 'Site 4')
|
||||
self.assertEqual(data['data']['site_list'][1]['name'], 'Site 5')
|
||||
self.assertEqual(data['data']['site_list'][2]['name'], 'Site 6')
|
||||
self.assertEqual(data['data']['site_list'][3]['name'], 'Site 7')
|
||||
|
||||
# Test `offset` & `limit`
|
||||
query = """
|
||||
{
|
||||
site_list(pagination: {offset: 3, limit: 3}) {
|
||||
id name
|
||||
}
|
||||
}
|
||||
"""
|
||||
response = self.client.post(url, data={'query': query}, format='json', **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
self.assertEqual(len(data['data']['site_list']), 3)
|
||||
self.assertEqual(data['data']['site_list'][0]['name'], 'Site 4')
|
||||
self.assertEqual(data['data']['site_list'][1]['name'], 'Site 5')
|
||||
self.assertEqual(data['data']['site_list'][2]['name'], 'Site 6')
|
||||
|
||||
def test_cursor_pagination(self):
|
||||
self.add_permissions('dcim.view_site')
|
||||
url = reverse('graphql')
|
||||
|
||||
# Page 1
|
||||
query = """
|
||||
{
|
||||
site_list(pagination: {start: 0, limit: 3}) {
|
||||
id name
|
||||
}
|
||||
}
|
||||
"""
|
||||
response = self.client.post(url, data={'query': query}, format='json', **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
self.assertEqual(len(data['data']['site_list']), 3)
|
||||
self.assertEqual(data['data']['site_list'][0]['name'], 'Site 1')
|
||||
self.assertEqual(data['data']['site_list'][1]['name'], 'Site 2')
|
||||
self.assertEqual(data['data']['site_list'][2]['name'], 'Site 3')
|
||||
|
||||
# Page 2
|
||||
start_id = int(data['data']['site_list'][-1]['id']) + 1
|
||||
query = """
|
||||
{
|
||||
site_list(pagination: {start: """ + str(start_id) + """, limit: 3}) {
|
||||
id name
|
||||
}
|
||||
}
|
||||
"""
|
||||
response = self.client.post(url, data={'query': query}, format='json', **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
self.assertEqual(len(data['data']['site_list']), 3)
|
||||
self.assertEqual(data['data']['site_list'][0]['name'], 'Site 4')
|
||||
self.assertEqual(data['data']['site_list'][1]['name'], 'Site 5')
|
||||
self.assertEqual(data['data']['site_list'][2]['name'], 'Site 6')
|
||||
|
||||
# Page 3
|
||||
start_id = int(data['data']['site_list'][-1]['id']) + 1
|
||||
query = """
|
||||
{
|
||||
site_list(pagination: {start: """ + str(start_id) + """, limit: 3}) {
|
||||
id name
|
||||
}
|
||||
}
|
||||
"""
|
||||
response = self.client.post(url, data={'query': query}, format='json', **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
self.assertEqual(len(data['data']['site_list']), 1)
|
||||
self.assertEqual(data['data']['site_list'][0]['name'], 'Site 7')
|
||||
|
||||
def test_pagination_conflict(self):
|
||||
url = reverse('graphql')
|
||||
query = """
|
||||
{
|
||||
site_list(pagination: {start: 1, offset: 1}) {
|
||||
id name
|
||||
}
|
||||
}
|
||||
"""
|
||||
response = self.client.post(url, data={'query': query}, format='json', **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertIn('errors', data)
|
||||
self.assertEqual(data['errors'][0]['message'], 'Cannot specify both `start` and `offset` in pagination.')
|
||||
|
||||
Reference in New Issue
Block a user