mirror of
https://github.com/netbox-community/netbox.git
synced 2025-12-09 01:49:35 -06:00
feat(extras): Add range_contains ORM lookup
Introduce a generic lookup for ArrayField(RangeField) that matches rows where a scalar value is contained by any range in the array (e.g. VLANGroup.vid_ranges). Replace the raw-SQL helper in the VLANGroup FilterSet (`contains_vid`) with the ORM lookup for better maintainability. Add tests for the lookup and the FilterSet behavior. Closes #20497
This commit is contained in:
parent
2abc5ac69a
commit
33d4759871
@ -1,9 +1,39 @@
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.contrib.postgres.fields.ranges import RangeField
|
||||
from django.db.models import CharField, JSONField, Lookup
|
||||
from django.db.models.fields.json import KeyTextTransform
|
||||
|
||||
from .fields import CachedValueField
|
||||
|
||||
|
||||
class RangeContains(Lookup):
|
||||
"""
|
||||
Filter ArrayField(RangeField) columns where ANY element-range contains the scalar RHS.
|
||||
|
||||
Usage (ORM):
|
||||
Model.objects.filter(<range_array_field>__range_contains=<scalar>)
|
||||
|
||||
Works with int4range[], int8range[], daterange[], tstzrange[], etc.
|
||||
"""
|
||||
|
||||
lookup_name = 'range_contains'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Compile LHS (the array-of-ranges column/expression) and RHS (scalar)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
|
||||
# Guard: only allow ArrayField whose base_field is a PostgreSQL RangeField
|
||||
field = getattr(self.lhs, 'output_field', None)
|
||||
if not (isinstance(field, ArrayField) and isinstance(field.base_field, RangeField)):
|
||||
raise TypeError('range_contains is only valid for ArrayField(RangeField) columns')
|
||||
|
||||
# Range-contains-element using EXISTS + UNNEST keeps the range on the LHS: r @> value
|
||||
sql = f"EXISTS (SELECT 1 FROM unnest({lhs}) AS r WHERE r @> {rhs})"
|
||||
params = lhs_params + rhs_params
|
||||
return sql, params
|
||||
|
||||
|
||||
class Empty(Lookup):
|
||||
"""
|
||||
Filter on whether a string is empty.
|
||||
@ -25,7 +55,7 @@ class JSONEmpty(Lookup):
|
||||
|
||||
A key is considered empty if it is "", null, or does not exist.
|
||||
"""
|
||||
lookup_name = "empty"
|
||||
lookup_name = 'empty'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# self.lhs.lhs is the parent expression (could be a JSONField or another KeyTransform)
|
||||
@ -69,6 +99,7 @@ class NetContainsOrEquals(Lookup):
|
||||
return 'CAST(%s AS INET) >>= %s' % (lhs, rhs), params
|
||||
|
||||
|
||||
ArrayField.register_lookup(RangeContains)
|
||||
CharField.register_lookup(Empty)
|
||||
JSONField.register_lookup(JSONEmpty)
|
||||
CachedValueField.register_lookup(NetHost)
|
||||
|
||||
@ -908,7 +908,8 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
|
||||
method='filter_scope'
|
||||
)
|
||||
contains_vid = django_filters.NumberFilter(
|
||||
method='filter_contains_vid'
|
||||
field_name='vid_ranges',
|
||||
lookup_expr='range_contains',
|
||||
)
|
||||
|
||||
class Meta:
|
||||
@ -931,21 +932,6 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
|
||||
scope_id=value
|
||||
)
|
||||
|
||||
def filter_contains_vid(self, queryset, name, value):
|
||||
"""
|
||||
Return all VLANGroups which contain the given VLAN ID.
|
||||
"""
|
||||
table_name = VLANGroup._meta.db_table
|
||||
# TODO: See if this can be optimized without compromising queryset integrity
|
||||
# Expand VLAN ID ranges to query by integer
|
||||
groups = VLANGroup.objects.raw(
|
||||
f'SELECT id FROM {table_name}, unnest(vid_ranges) vid_range WHERE %s <@ vid_range',
|
||||
params=(value,)
|
||||
)
|
||||
return queryset.filter(
|
||||
pk__in=[g.id for g in groups]
|
||||
)
|
||||
|
||||
|
||||
class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
|
||||
region_id = TreeNodeMultipleChoiceFilter(
|
||||
|
||||
@ -19,7 +19,7 @@ from tenancy.graphql.filter_mixins import ContactFilterMixin, TenancyFilterMixin
|
||||
from virtualization.models import VMInterface
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from netbox.graphql.filter_lookups import IntegerArrayLookup, IntegerLookup
|
||||
from netbox.graphql.filter_lookups import IntegerLookup, IntegerRangeArrayLookup
|
||||
from circuits.graphql.filters import ProviderFilter
|
||||
from core.graphql.filters import ContentTypeFilter
|
||||
from dcim.graphql.filters import SiteFilter
|
||||
@ -340,7 +340,7 @@ class VLANFilter(TenancyFilterMixin, PrimaryModelFilterMixin):
|
||||
|
||||
@strawberry_django.filter_type(models.VLANGroup, lookups=True)
|
||||
class VLANGroupFilter(ScopedFilterMixin, OrganizationalModelFilterMixin):
|
||||
vid_ranges: Annotated['IntegerArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = (
|
||||
vid_ranges: Annotated['IntegerRangeArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = (
|
||||
strawberry_django.filter_field()
|
||||
)
|
||||
|
||||
|
||||
@ -1723,6 +1723,10 @@ class VLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
|
||||
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
|
||||
params = {'contains_vid': 1}
|
||||
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
|
||||
params = {'contains_vid': 12} # 11 is NOT in [1,11)
|
||||
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
|
||||
params = {'contains_vid': 4095}
|
||||
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 0)
|
||||
|
||||
def test_region(self):
|
||||
params = {'region': Region.objects.first().pk}
|
||||
|
||||
66
netbox/ipam/tests/test_lookups.py
Normal file
66
netbox/ipam/tests/test_lookups.py
Normal file
@ -0,0 +1,66 @@
|
||||
from django.test import TestCase
|
||||
from django.db.backends.postgresql.psycopg_any import NumericRange
|
||||
from ipam.models import VLANGroup
|
||||
|
||||
|
||||
class VLANGroupRangeContainsLookupTests(TestCase):
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
# Two ranges: [1,11) and [20,31)
|
||||
cls.g1 = VLANGroup.objects.create(
|
||||
name='VlanGroup-A',
|
||||
slug='VlanGroup-A',
|
||||
vid_ranges=[NumericRange(1, 11), NumericRange(20, 31)],
|
||||
)
|
||||
# One range: [100,201)
|
||||
cls.g2 = VLANGroup.objects.create(
|
||||
name='VlanGroup-B',
|
||||
slug='VlanGroup-B',
|
||||
vid_ranges=[NumericRange(100, 201)],
|
||||
)
|
||||
cls.g_empty = VLANGroup.objects.create(
|
||||
name='VlanGroup-empty',
|
||||
slug='VlanGroup-empty',
|
||||
vid_ranges=[],
|
||||
)
|
||||
|
||||
def test_contains_value_in_first_range(self):
|
||||
"""
|
||||
Tests whether a specific value is contained within the first range in a queried
|
||||
set of VLANGroup objects.
|
||||
"""
|
||||
names = list(
|
||||
VLANGroup.objects.filter(vid_ranges__range_contains=10).values_list('name', flat=True).order_by('name')
|
||||
)
|
||||
self.assertEqual(names, ['VlanGroup-A'])
|
||||
|
||||
def test_contains_value_in_second_range(self):
|
||||
"""
|
||||
Tests if a value exists in the second range of VLANGroup objects and
|
||||
validates the result against the expected list of names.
|
||||
"""
|
||||
names = list(
|
||||
VLANGroup.objects.filter(vid_ranges__range_contains=25).values_list('name', flat=True).order_by('name')
|
||||
)
|
||||
self.assertEqual(names, ['VlanGroup-A'])
|
||||
|
||||
def test_upper_bound_is_exclusive(self):
|
||||
"""
|
||||
Tests if the upper bound of the range is exclusive in the filter method.
|
||||
"""
|
||||
# 11 is NOT in [1,11)
|
||||
self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=11).exists())
|
||||
|
||||
def test_no_match_far_outside(self):
|
||||
"""
|
||||
Tests that no VLANGroup contains a VID within a specified range far outside
|
||||
common VID bounds and returns `False`.
|
||||
"""
|
||||
self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=4095).exists())
|
||||
|
||||
def test_empty_array_never_matches(self):
|
||||
"""
|
||||
Tests the behavior of VLANGroup objects when an empty array is used to match a
|
||||
specific condition.
|
||||
"""
|
||||
self.assertFalse(VLANGroup.objects.filter(pk=self.g_empty.pk, vid_ranges__range_contains=1).exists())
|
||||
@ -24,6 +24,7 @@ __all__ = (
|
||||
'FloatLookup',
|
||||
'IntegerArrayLookup',
|
||||
'IntegerLookup',
|
||||
'IntegerRangeArrayLookup',
|
||||
'JSONFilter',
|
||||
'StringArrayLookup',
|
||||
'TreeNodeFilter',
|
||||
@ -217,3 +218,30 @@ class FloatArrayLookup(ArrayLookup[float]):
|
||||
@strawberry.input(one_of=True, description='Lookup for Array fields. Only one of the lookup fields can be set.')
|
||||
class StringArrayLookup(ArrayLookup[str]):
|
||||
pass
|
||||
|
||||
|
||||
@strawberry.input(one_of=True, description='Lookups for an ArrayField(RangeField). Only one may be set.')
|
||||
class RangeArrayValueLookup(Generic[T]):
|
||||
"""
|
||||
class for Array field of Range fields lookups
|
||||
"""
|
||||
|
||||
contains: T | None = strawberry.field(
|
||||
default=strawberry.UNSET, description='Return rows where any stored range contains this value.'
|
||||
)
|
||||
|
||||
@strawberry_django.filter_field
|
||||
def filter(self, info: Info, queryset: QuerySet, prefix: str = '') -> Tuple[QuerySet, Q]:
|
||||
"""
|
||||
Map GraphQL: { <field>: { contains: <T> } } To Django ORM: <field>__range_contains=<T>
|
||||
"""
|
||||
if self.contains is strawberry.UNSET or self.contains is None:
|
||||
return queryset, Q()
|
||||
|
||||
# Build '<prefix>range_contains' so it works for nested paths too
|
||||
return queryset, Q(**{f'{prefix}range_contains': self.contains})
|
||||
|
||||
|
||||
@strawberry.input(one_of=True, description='Lookups for an ArrayField(IntegerRangeField). Only one may be set.')
|
||||
class IntegerRangeArrayLookup(RangeArrayValueLookup[int]):
|
||||
pass
|
||||
|
||||
Loading…
Reference in New Issue
Block a user