mirror of
https://github.com/netbox-community/netbox.git
synced 2025-12-11 02:49:35 -06:00
feat(extras): Add range_contains ORM lookup
Introduce a generic lookup for ArrayField(RangeField) that matches rows where a scalar value is contained by any range in the array (e.g. VLANGroup.vid_ranges). Replace the raw-SQL helper in the VLANGroup FilterSet (`contains_vid`) with the ORM lookup for better maintainability. Add tests for the lookup and the FilterSet behavior. Closes #20497
This commit is contained in:
parent
2abc5ac69a
commit
33d4759871
@ -1,9 +1,39 @@
|
|||||||
|
from django.contrib.postgres.fields import ArrayField
|
||||||
|
from django.contrib.postgres.fields.ranges import RangeField
|
||||||
from django.db.models import CharField, JSONField, Lookup
|
from django.db.models import CharField, JSONField, Lookup
|
||||||
from django.db.models.fields.json import KeyTextTransform
|
from django.db.models.fields.json import KeyTextTransform
|
||||||
|
|
||||||
from .fields import CachedValueField
|
from .fields import CachedValueField
|
||||||
|
|
||||||
|
|
||||||
|
class RangeContains(Lookup):
|
||||||
|
"""
|
||||||
|
Filter ArrayField(RangeField) columns where ANY element-range contains the scalar RHS.
|
||||||
|
|
||||||
|
Usage (ORM):
|
||||||
|
Model.objects.filter(<range_array_field>__range_contains=<scalar>)
|
||||||
|
|
||||||
|
Works with int4range[], int8range[], daterange[], tstzrange[], etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
lookup_name = 'range_contains'
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection):
|
||||||
|
# Compile LHS (the array-of-ranges column/expression) and RHS (scalar)
|
||||||
|
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||||
|
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||||
|
|
||||||
|
# Guard: only allow ArrayField whose base_field is a PostgreSQL RangeField
|
||||||
|
field = getattr(self.lhs, 'output_field', None)
|
||||||
|
if not (isinstance(field, ArrayField) and isinstance(field.base_field, RangeField)):
|
||||||
|
raise TypeError('range_contains is only valid for ArrayField(RangeField) columns')
|
||||||
|
|
||||||
|
# Range-contains-element using EXISTS + UNNEST keeps the range on the LHS: r @> value
|
||||||
|
sql = f"EXISTS (SELECT 1 FROM unnest({lhs}) AS r WHERE r @> {rhs})"
|
||||||
|
params = lhs_params + rhs_params
|
||||||
|
return sql, params
|
||||||
|
|
||||||
|
|
||||||
class Empty(Lookup):
|
class Empty(Lookup):
|
||||||
"""
|
"""
|
||||||
Filter on whether a string is empty.
|
Filter on whether a string is empty.
|
||||||
@ -25,7 +55,7 @@ class JSONEmpty(Lookup):
|
|||||||
|
|
||||||
A key is considered empty if it is "", null, or does not exist.
|
A key is considered empty if it is "", null, or does not exist.
|
||||||
"""
|
"""
|
||||||
lookup_name = "empty"
|
lookup_name = 'empty'
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
# self.lhs.lhs is the parent expression (could be a JSONField or another KeyTransform)
|
# self.lhs.lhs is the parent expression (could be a JSONField or another KeyTransform)
|
||||||
@ -69,6 +99,7 @@ class NetContainsOrEquals(Lookup):
|
|||||||
return 'CAST(%s AS INET) >>= %s' % (lhs, rhs), params
|
return 'CAST(%s AS INET) >>= %s' % (lhs, rhs), params
|
||||||
|
|
||||||
|
|
||||||
|
ArrayField.register_lookup(RangeContains)
|
||||||
CharField.register_lookup(Empty)
|
CharField.register_lookup(Empty)
|
||||||
JSONField.register_lookup(JSONEmpty)
|
JSONField.register_lookup(JSONEmpty)
|
||||||
CachedValueField.register_lookup(NetHost)
|
CachedValueField.register_lookup(NetHost)
|
||||||
|
|||||||
@ -908,7 +908,8 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
|
|||||||
method='filter_scope'
|
method='filter_scope'
|
||||||
)
|
)
|
||||||
contains_vid = django_filters.NumberFilter(
|
contains_vid = django_filters.NumberFilter(
|
||||||
method='filter_contains_vid'
|
field_name='vid_ranges',
|
||||||
|
lookup_expr='range_contains',
|
||||||
)
|
)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
@ -931,21 +932,6 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
|
|||||||
scope_id=value
|
scope_id=value
|
||||||
)
|
)
|
||||||
|
|
||||||
def filter_contains_vid(self, queryset, name, value):
|
|
||||||
"""
|
|
||||||
Return all VLANGroups which contain the given VLAN ID.
|
|
||||||
"""
|
|
||||||
table_name = VLANGroup._meta.db_table
|
|
||||||
# TODO: See if this can be optimized without compromising queryset integrity
|
|
||||||
# Expand VLAN ID ranges to query by integer
|
|
||||||
groups = VLANGroup.objects.raw(
|
|
||||||
f'SELECT id FROM {table_name}, unnest(vid_ranges) vid_range WHERE %s <@ vid_range',
|
|
||||||
params=(value,)
|
|
||||||
)
|
|
||||||
return queryset.filter(
|
|
||||||
pk__in=[g.id for g in groups]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
|
class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
|
||||||
region_id = TreeNodeMultipleChoiceFilter(
|
region_id = TreeNodeMultipleChoiceFilter(
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from tenancy.graphql.filter_mixins import ContactFilterMixin, TenancyFilterMixin
|
|||||||
from virtualization.models import VMInterface
|
from virtualization.models import VMInterface
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from netbox.graphql.filter_lookups import IntegerArrayLookup, IntegerLookup
|
from netbox.graphql.filter_lookups import IntegerLookup, IntegerRangeArrayLookup
|
||||||
from circuits.graphql.filters import ProviderFilter
|
from circuits.graphql.filters import ProviderFilter
|
||||||
from core.graphql.filters import ContentTypeFilter
|
from core.graphql.filters import ContentTypeFilter
|
||||||
from dcim.graphql.filters import SiteFilter
|
from dcim.graphql.filters import SiteFilter
|
||||||
@ -340,7 +340,7 @@ class VLANFilter(TenancyFilterMixin, PrimaryModelFilterMixin):
|
|||||||
|
|
||||||
@strawberry_django.filter_type(models.VLANGroup, lookups=True)
|
@strawberry_django.filter_type(models.VLANGroup, lookups=True)
|
||||||
class VLANGroupFilter(ScopedFilterMixin, OrganizationalModelFilterMixin):
|
class VLANGroupFilter(ScopedFilterMixin, OrganizationalModelFilterMixin):
|
||||||
vid_ranges: Annotated['IntegerArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = (
|
vid_ranges: Annotated['IntegerRangeArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = (
|
||||||
strawberry_django.filter_field()
|
strawberry_django.filter_field()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1723,6 +1723,10 @@ class VLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
|
|||||||
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
|
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
|
||||||
params = {'contains_vid': 1}
|
params = {'contains_vid': 1}
|
||||||
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
|
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
|
||||||
|
params = {'contains_vid': 12} # 11 is NOT in [1,11)
|
||||||
|
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
|
||||||
|
params = {'contains_vid': 4095}
|
||||||
|
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 0)
|
||||||
|
|
||||||
def test_region(self):
|
def test_region(self):
|
||||||
params = {'region': Region.objects.first().pk}
|
params = {'region': Region.objects.first().pk}
|
||||||
|
|||||||
66
netbox/ipam/tests/test_lookups.py
Normal file
66
netbox/ipam/tests/test_lookups.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
from django.test import TestCase
|
||||||
|
from django.db.backends.postgresql.psycopg_any import NumericRange
|
||||||
|
from ipam.models import VLANGroup
|
||||||
|
|
||||||
|
|
||||||
|
class VLANGroupRangeContainsLookupTests(TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpTestData(cls):
|
||||||
|
# Two ranges: [1,11) and [20,31)
|
||||||
|
cls.g1 = VLANGroup.objects.create(
|
||||||
|
name='VlanGroup-A',
|
||||||
|
slug='VlanGroup-A',
|
||||||
|
vid_ranges=[NumericRange(1, 11), NumericRange(20, 31)],
|
||||||
|
)
|
||||||
|
# One range: [100,201)
|
||||||
|
cls.g2 = VLANGroup.objects.create(
|
||||||
|
name='VlanGroup-B',
|
||||||
|
slug='VlanGroup-B',
|
||||||
|
vid_ranges=[NumericRange(100, 201)],
|
||||||
|
)
|
||||||
|
cls.g_empty = VLANGroup.objects.create(
|
||||||
|
name='VlanGroup-empty',
|
||||||
|
slug='VlanGroup-empty',
|
||||||
|
vid_ranges=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_contains_value_in_first_range(self):
|
||||||
|
"""
|
||||||
|
Tests whether a specific value is contained within the first range in a queried
|
||||||
|
set of VLANGroup objects.
|
||||||
|
"""
|
||||||
|
names = list(
|
||||||
|
VLANGroup.objects.filter(vid_ranges__range_contains=10).values_list('name', flat=True).order_by('name')
|
||||||
|
)
|
||||||
|
self.assertEqual(names, ['VlanGroup-A'])
|
||||||
|
|
||||||
|
def test_contains_value_in_second_range(self):
|
||||||
|
"""
|
||||||
|
Tests if a value exists in the second range of VLANGroup objects and
|
||||||
|
validates the result against the expected list of names.
|
||||||
|
"""
|
||||||
|
names = list(
|
||||||
|
VLANGroup.objects.filter(vid_ranges__range_contains=25).values_list('name', flat=True).order_by('name')
|
||||||
|
)
|
||||||
|
self.assertEqual(names, ['VlanGroup-A'])
|
||||||
|
|
||||||
|
def test_upper_bound_is_exclusive(self):
|
||||||
|
"""
|
||||||
|
Tests if the upper bound of the range is exclusive in the filter method.
|
||||||
|
"""
|
||||||
|
# 11 is NOT in [1,11)
|
||||||
|
self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=11).exists())
|
||||||
|
|
||||||
|
def test_no_match_far_outside(self):
|
||||||
|
"""
|
||||||
|
Tests that no VLANGroup contains a VID within a specified range far outside
|
||||||
|
common VID bounds and returns `False`.
|
||||||
|
"""
|
||||||
|
self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=4095).exists())
|
||||||
|
|
||||||
|
def test_empty_array_never_matches(self):
|
||||||
|
"""
|
||||||
|
Tests the behavior of VLANGroup objects when an empty array is used to match a
|
||||||
|
specific condition.
|
||||||
|
"""
|
||||||
|
self.assertFalse(VLANGroup.objects.filter(pk=self.g_empty.pk, vid_ranges__range_contains=1).exists())
|
||||||
@ -24,6 +24,7 @@ __all__ = (
|
|||||||
'FloatLookup',
|
'FloatLookup',
|
||||||
'IntegerArrayLookup',
|
'IntegerArrayLookup',
|
||||||
'IntegerLookup',
|
'IntegerLookup',
|
||||||
|
'IntegerRangeArrayLookup',
|
||||||
'JSONFilter',
|
'JSONFilter',
|
||||||
'StringArrayLookup',
|
'StringArrayLookup',
|
||||||
'TreeNodeFilter',
|
'TreeNodeFilter',
|
||||||
@ -217,3 +218,30 @@ class FloatArrayLookup(ArrayLookup[float]):
|
|||||||
@strawberry.input(one_of=True, description='Lookup for Array fields. Only one of the lookup fields can be set.')
|
@strawberry.input(one_of=True, description='Lookup for Array fields. Only one of the lookup fields can be set.')
|
||||||
class StringArrayLookup(ArrayLookup[str]):
|
class StringArrayLookup(ArrayLookup[str]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@strawberry.input(one_of=True, description='Lookups for an ArrayField(RangeField). Only one may be set.')
|
||||||
|
class RangeArrayValueLookup(Generic[T]):
|
||||||
|
"""
|
||||||
|
class for Array field of Range fields lookups
|
||||||
|
"""
|
||||||
|
|
||||||
|
contains: T | None = strawberry.field(
|
||||||
|
default=strawberry.UNSET, description='Return rows where any stored range contains this value.'
|
||||||
|
)
|
||||||
|
|
||||||
|
@strawberry_django.filter_field
|
||||||
|
def filter(self, info: Info, queryset: QuerySet, prefix: str = '') -> Tuple[QuerySet, Q]:
|
||||||
|
"""
|
||||||
|
Map GraphQL: { <field>: { contains: <T> } } To Django ORM: <field>__range_contains=<T>
|
||||||
|
"""
|
||||||
|
if self.contains is strawberry.UNSET or self.contains is None:
|
||||||
|
return queryset, Q()
|
||||||
|
|
||||||
|
# Build '<prefix>range_contains' so it works for nested paths too
|
||||||
|
return queryset, Q(**{f'{prefix}range_contains': self.contains})
|
||||||
|
|
||||||
|
|
||||||
|
@strawberry.input(one_of=True, description='Lookups for an ArrayField(IntegerRangeField). Only one may be set.')
|
||||||
|
class IntegerRangeArrayLookup(RangeArrayValueLookup[int]):
|
||||||
|
pass
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user