feat(extras): Add range_contains ORM lookup

Introduce a generic lookup for ArrayField(RangeField) that matches rows
where a scalar value is contained by any range in the array
(e.g. VLANGroup.vid_ranges).
Replace the raw-SQL helper in the VLANGroup FilterSet (`contains_vid`)
with the ORM lookup for better maintainability.
Add tests for the lookup and the FilterSet behavior.

Closes #20497
This commit is contained in:
Martin Hauser 2025-10-08 13:00:56 +02:00 committed by Jeremy Stretch
parent 2abc5ac69a
commit 33d4759871
6 changed files with 134 additions and 19 deletions

View File

@ -1,9 +1,39 @@
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.fields.ranges import RangeField
from django.db.models import CharField, JSONField, Lookup from django.db.models import CharField, JSONField, Lookup
from django.db.models.fields.json import KeyTextTransform from django.db.models.fields.json import KeyTextTransform
from .fields import CachedValueField from .fields import CachedValueField
class RangeContains(Lookup):
"""
Filter ArrayField(RangeField) columns where ANY element-range contains the scalar RHS.
Usage (ORM):
Model.objects.filter(<range_array_field>__range_contains=<scalar>)
Works with int4range[], int8range[], daterange[], tstzrange[], etc.
"""
lookup_name = 'range_contains'
def as_sql(self, compiler, connection):
# Compile LHS (the array-of-ranges column/expression) and RHS (scalar)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
# Guard: only allow ArrayField whose base_field is a PostgreSQL RangeField
field = getattr(self.lhs, 'output_field', None)
if not (isinstance(field, ArrayField) and isinstance(field.base_field, RangeField)):
raise TypeError('range_contains is only valid for ArrayField(RangeField) columns')
# Range-contains-element using EXISTS + UNNEST keeps the range on the LHS: r @> value
sql = f"EXISTS (SELECT 1 FROM unnest({lhs}) AS r WHERE r @> {rhs})"
params = lhs_params + rhs_params
return sql, params
class Empty(Lookup): class Empty(Lookup):
""" """
Filter on whether a string is empty. Filter on whether a string is empty.
@ -25,7 +55,7 @@ class JSONEmpty(Lookup):
A key is considered empty if it is "", null, or does not exist. A key is considered empty if it is "", null, or does not exist.
""" """
lookup_name = "empty" lookup_name = 'empty'
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
# self.lhs.lhs is the parent expression (could be a JSONField or another KeyTransform) # self.lhs.lhs is the parent expression (could be a JSONField or another KeyTransform)
@ -69,6 +99,7 @@ class NetContainsOrEquals(Lookup):
return 'CAST(%s AS INET) >>= %s' % (lhs, rhs), params return 'CAST(%s AS INET) >>= %s' % (lhs, rhs), params
ArrayField.register_lookup(RangeContains)
CharField.register_lookup(Empty) CharField.register_lookup(Empty)
JSONField.register_lookup(JSONEmpty) JSONField.register_lookup(JSONEmpty)
CachedValueField.register_lookup(NetHost) CachedValueField.register_lookup(NetHost)

View File

@ -908,7 +908,8 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
method='filter_scope' method='filter_scope'
) )
contains_vid = django_filters.NumberFilter( contains_vid = django_filters.NumberFilter(
method='filter_contains_vid' field_name='vid_ranges',
lookup_expr='range_contains',
) )
class Meta: class Meta:
@ -931,21 +932,6 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
scope_id=value scope_id=value
) )
def filter_contains_vid(self, queryset, name, value):
"""
Return all VLANGroups which contain the given VLAN ID.
"""
table_name = VLANGroup._meta.db_table
# TODO: See if this can be optimized without compromising queryset integrity
# Expand VLAN ID ranges to query by integer
groups = VLANGroup.objects.raw(
f'SELECT id FROM {table_name}, unnest(vid_ranges) vid_range WHERE %s <@ vid_range',
params=(value,)
)
return queryset.filter(
pk__in=[g.id for g in groups]
)
class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet): class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
region_id = TreeNodeMultipleChoiceFilter( region_id = TreeNodeMultipleChoiceFilter(

View File

@ -19,7 +19,7 @@ from tenancy.graphql.filter_mixins import ContactFilterMixin, TenancyFilterMixin
from virtualization.models import VMInterface from virtualization.models import VMInterface
if TYPE_CHECKING: if TYPE_CHECKING:
from netbox.graphql.filter_lookups import IntegerArrayLookup, IntegerLookup from netbox.graphql.filter_lookups import IntegerLookup, IntegerRangeArrayLookup
from circuits.graphql.filters import ProviderFilter from circuits.graphql.filters import ProviderFilter
from core.graphql.filters import ContentTypeFilter from core.graphql.filters import ContentTypeFilter
from dcim.graphql.filters import SiteFilter from dcim.graphql.filters import SiteFilter
@ -340,7 +340,7 @@ class VLANFilter(TenancyFilterMixin, PrimaryModelFilterMixin):
@strawberry_django.filter_type(models.VLANGroup, lookups=True) @strawberry_django.filter_type(models.VLANGroup, lookups=True)
class VLANGroupFilter(ScopedFilterMixin, OrganizationalModelFilterMixin): class VLANGroupFilter(ScopedFilterMixin, OrganizationalModelFilterMixin):
vid_ranges: Annotated['IntegerArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = ( vid_ranges: Annotated['IntegerRangeArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = (
strawberry_django.filter_field() strawberry_django.filter_field()
) )

View File

@ -1723,6 +1723,10 @@ class VLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'contains_vid': 1} params = {'contains_vid': 1}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
params = {'contains_vid': 12} # 11 is NOT in [1,11)
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
params = {'contains_vid': 4095}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 0)
def test_region(self): def test_region(self):
params = {'region': Region.objects.first().pk} params = {'region': Region.objects.first().pk}

View File

@ -0,0 +1,66 @@
from django.test import TestCase
from django.db.backends.postgresql.psycopg_any import NumericRange
from ipam.models import VLANGroup
class VLANGroupRangeContainsLookupTests(TestCase):
@classmethod
def setUpTestData(cls):
# Two ranges: [1,11) and [20,31)
cls.g1 = VLANGroup.objects.create(
name='VlanGroup-A',
slug='VlanGroup-A',
vid_ranges=[NumericRange(1, 11), NumericRange(20, 31)],
)
# One range: [100,201)
cls.g2 = VLANGroup.objects.create(
name='VlanGroup-B',
slug='VlanGroup-B',
vid_ranges=[NumericRange(100, 201)],
)
cls.g_empty = VLANGroup.objects.create(
name='VlanGroup-empty',
slug='VlanGroup-empty',
vid_ranges=[],
)
def test_contains_value_in_first_range(self):
"""
Tests whether a specific value is contained within the first range in a queried
set of VLANGroup objects.
"""
names = list(
VLANGroup.objects.filter(vid_ranges__range_contains=10).values_list('name', flat=True).order_by('name')
)
self.assertEqual(names, ['VlanGroup-A'])
def test_contains_value_in_second_range(self):
"""
Tests if a value exists in the second range of VLANGroup objects and
validates the result against the expected list of names.
"""
names = list(
VLANGroup.objects.filter(vid_ranges__range_contains=25).values_list('name', flat=True).order_by('name')
)
self.assertEqual(names, ['VlanGroup-A'])
def test_upper_bound_is_exclusive(self):
"""
Tests if the upper bound of the range is exclusive in the filter method.
"""
# 11 is NOT in [1,11)
self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=11).exists())
def test_no_match_far_outside(self):
"""
Tests that no VLANGroup contains a VID within a specified range far outside
common VID bounds and returns `False`.
"""
self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=4095).exists())
def test_empty_array_never_matches(self):
"""
Tests the behavior of VLANGroup objects when an empty array is used to match a
specific condition.
"""
self.assertFalse(VLANGroup.objects.filter(pk=self.g_empty.pk, vid_ranges__range_contains=1).exists())

View File

@ -24,6 +24,7 @@ __all__ = (
'FloatLookup', 'FloatLookup',
'IntegerArrayLookup', 'IntegerArrayLookup',
'IntegerLookup', 'IntegerLookup',
'IntegerRangeArrayLookup',
'JSONFilter', 'JSONFilter',
'StringArrayLookup', 'StringArrayLookup',
'TreeNodeFilter', 'TreeNodeFilter',
@ -217,3 +218,30 @@ class FloatArrayLookup(ArrayLookup[float]):
@strawberry.input(one_of=True, description='Lookup for Array fields. Only one of the lookup fields can be set.') @strawberry.input(one_of=True, description='Lookup for Array fields. Only one of the lookup fields can be set.')
class StringArrayLookup(ArrayLookup[str]): class StringArrayLookup(ArrayLookup[str]):
pass pass
@strawberry.input(one_of=True, description='Lookups for an ArrayField(RangeField). Only one may be set.')
class RangeArrayValueLookup(Generic[T]):
"""
class for Array field of Range fields lookups
"""
contains: T | None = strawberry.field(
default=strawberry.UNSET, description='Return rows where any stored range contains this value.'
)
@strawberry_django.filter_field
def filter(self, info: Info, queryset: QuerySet, prefix: str = '') -> Tuple[QuerySet, Q]:
"""
Map GraphQL: { <field>: { contains: <T> } } To Django ORM: <field>__range_contains=<T>
"""
if self.contains is strawberry.UNSET or self.contains is None:
return queryset, Q()
# Build '<prefix>range_contains' so it works for nested paths too
return queryset, Q(**{f'{prefix}range_contains': self.contains})
@strawberry.input(one_of=True, description='Lookups for an ArrayField(IntegerRangeField). Only one may be set.')
class IntegerRangeArrayLookup(RangeArrayValueLookup[int]):
pass