diff --git a/netbox/extras/lookups.py b/netbox/extras/lookups.py index 33296340e..678239080 100644 --- a/netbox/extras/lookups.py +++ b/netbox/extras/lookups.py @@ -1,9 +1,39 @@ +from django.contrib.postgres.fields import ArrayField +from django.contrib.postgres.fields.ranges import RangeField from django.db.models import CharField, JSONField, Lookup from django.db.models.fields.json import KeyTextTransform from .fields import CachedValueField +class RangeContains(Lookup): + """ + Filter ArrayField(RangeField) columns where ANY element-range contains the scalar RHS. + + Usage (ORM): + Model.objects.filter(__range_contains=) + + Works with int4range[], int8range[], daterange[], tstzrange[], etc. + """ + + lookup_name = 'range_contains' + + def as_sql(self, compiler, connection): + # Compile LHS (the array-of-ranges column/expression) and RHS (scalar) + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) + + # Guard: only allow ArrayField whose base_field is a PostgreSQL RangeField + field = getattr(self.lhs, 'output_field', None) + if not (isinstance(field, ArrayField) and isinstance(field.base_field, RangeField)): + raise TypeError('range_contains is only valid for ArrayField(RangeField) columns') + + # Range-contains-element using EXISTS + UNNEST keeps the range on the LHS: r @> value + sql = f"EXISTS (SELECT 1 FROM unnest({lhs}) AS r WHERE r @> {rhs})" + params = lhs_params + rhs_params + return sql, params + + class Empty(Lookup): """ Filter on whether a string is empty. @@ -25,7 +55,7 @@ class JSONEmpty(Lookup): A key is considered empty if it is "", null, or does not exist. """ - lookup_name = "empty" + lookup_name = 'empty' def as_sql(self, compiler, connection): # self.lhs.lhs is the parent expression (could be a JSONField or another KeyTransform) @@ -69,6 +99,7 @@ class NetContainsOrEquals(Lookup): return 'CAST(%s AS INET) >>= %s' % (lhs, rhs), params +ArrayField.register_lookup(RangeContains) CharField.register_lookup(Empty) JSONField.register_lookup(JSONEmpty) CachedValueField.register_lookup(NetHost) diff --git a/netbox/ipam/filtersets.py b/netbox/ipam/filtersets.py index 1d201fc38..34bc34b48 100644 --- a/netbox/ipam/filtersets.py +++ b/netbox/ipam/filtersets.py @@ -908,7 +908,8 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet): method='filter_scope' ) contains_vid = django_filters.NumberFilter( - method='filter_contains_vid' + field_name='vid_ranges', + lookup_expr='range_contains', ) class Meta: @@ -931,21 +932,6 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet): scope_id=value ) - def filter_contains_vid(self, queryset, name, value): - """ - Return all VLANGroups which contain the given VLAN ID. - """ - table_name = VLANGroup._meta.db_table - # TODO: See if this can be optimized without compromising queryset integrity - # Expand VLAN ID ranges to query by integer - groups = VLANGroup.objects.raw( - f'SELECT id FROM {table_name}, unnest(vid_ranges) vid_range WHERE %s <@ vid_range', - params=(value,) - ) - return queryset.filter( - pk__in=[g.id for g in groups] - ) - class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet): region_id = TreeNodeMultipleChoiceFilter( diff --git a/netbox/ipam/graphql/filters.py b/netbox/ipam/graphql/filters.py index 35ddd47e4..4b2431aa2 100644 --- a/netbox/ipam/graphql/filters.py +++ b/netbox/ipam/graphql/filters.py @@ -19,7 +19,7 @@ from tenancy.graphql.filter_mixins import ContactFilterMixin, TenancyFilterMixin from virtualization.models import VMInterface if TYPE_CHECKING: - from netbox.graphql.filter_lookups import IntegerArrayLookup, IntegerLookup + from netbox.graphql.filter_lookups import IntegerLookup, IntegerRangeArrayLookup from circuits.graphql.filters import ProviderFilter from core.graphql.filters import ContentTypeFilter from dcim.graphql.filters import SiteFilter @@ -340,7 +340,7 @@ class VLANFilter(TenancyFilterMixin, PrimaryModelFilterMixin): @strawberry_django.filter_type(models.VLANGroup, lookups=True) class VLANGroupFilter(ScopedFilterMixin, OrganizationalModelFilterMixin): - vid_ranges: Annotated['IntegerArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = ( + vid_ranges: Annotated['IntegerRangeArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = ( strawberry_django.filter_field() ) diff --git a/netbox/ipam/tests/test_filtersets.py b/netbox/ipam/tests/test_filtersets.py index 54ad5df90..d2cd13dce 100644 --- a/netbox/ipam/tests/test_filtersets.py +++ b/netbox/ipam/tests/test_filtersets.py @@ -1723,6 +1723,10 @@ class VLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) params = {'contains_vid': 1} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8) + params = {'contains_vid': 12} # 11 is NOT in [1,11) + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) + params = {'contains_vid': 4095} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 0) def test_region(self): params = {'region': Region.objects.first().pk} diff --git a/netbox/ipam/tests/test_lookups.py b/netbox/ipam/tests/test_lookups.py new file mode 100644 index 000000000..5c8b7a770 --- /dev/null +++ b/netbox/ipam/tests/test_lookups.py @@ -0,0 +1,66 @@ +from django.test import TestCase +from django.db.backends.postgresql.psycopg_any import NumericRange +from ipam.models import VLANGroup + + +class VLANGroupRangeContainsLookupTests(TestCase): + @classmethod + def setUpTestData(cls): + # Two ranges: [1,11) and [20,31) + cls.g1 = VLANGroup.objects.create( + name='VlanGroup-A', + slug='VlanGroup-A', + vid_ranges=[NumericRange(1, 11), NumericRange(20, 31)], + ) + # One range: [100,201) + cls.g2 = VLANGroup.objects.create( + name='VlanGroup-B', + slug='VlanGroup-B', + vid_ranges=[NumericRange(100, 201)], + ) + cls.g_empty = VLANGroup.objects.create( + name='VlanGroup-empty', + slug='VlanGroup-empty', + vid_ranges=[], + ) + + def test_contains_value_in_first_range(self): + """ + Tests whether a specific value is contained within the first range in a queried + set of VLANGroup objects. + """ + names = list( + VLANGroup.objects.filter(vid_ranges__range_contains=10).values_list('name', flat=True).order_by('name') + ) + self.assertEqual(names, ['VlanGroup-A']) + + def test_contains_value_in_second_range(self): + """ + Tests if a value exists in the second range of VLANGroup objects and + validates the result against the expected list of names. + """ + names = list( + VLANGroup.objects.filter(vid_ranges__range_contains=25).values_list('name', flat=True).order_by('name') + ) + self.assertEqual(names, ['VlanGroup-A']) + + def test_upper_bound_is_exclusive(self): + """ + Tests if the upper bound of the range is exclusive in the filter method. + """ + # 11 is NOT in [1,11) + self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=11).exists()) + + def test_no_match_far_outside(self): + """ + Tests that no VLANGroup contains a VID within a specified range far outside + common VID bounds and returns `False`. + """ + self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=4095).exists()) + + def test_empty_array_never_matches(self): + """ + Tests the behavior of VLANGroup objects when an empty array is used to match a + specific condition. + """ + self.assertFalse(VLANGroup.objects.filter(pk=self.g_empty.pk, vid_ranges__range_contains=1).exists()) diff --git a/netbox/netbox/graphql/filter_lookups.py b/netbox/netbox/graphql/filter_lookups.py index 859236e4d..ef28d5731 100644 --- a/netbox/netbox/graphql/filter_lookups.py +++ b/netbox/netbox/graphql/filter_lookups.py @@ -24,6 +24,7 @@ __all__ = ( 'FloatLookup', 'IntegerArrayLookup', 'IntegerLookup', + 'IntegerRangeArrayLookup', 'JSONFilter', 'StringArrayLookup', 'TreeNodeFilter', @@ -217,3 +218,30 @@ class FloatArrayLookup(ArrayLookup[float]): @strawberry.input(one_of=True, description='Lookup for Array fields. Only one of the lookup fields can be set.') class StringArrayLookup(ArrayLookup[str]): pass + + +@strawberry.input(one_of=True, description='Lookups for an ArrayField(RangeField). Only one may be set.') +class RangeArrayValueLookup(Generic[T]): + """ + class for Array field of Range fields lookups + """ + + contains: T | None = strawberry.field( + default=strawberry.UNSET, description='Return rows where any stored range contains this value.' + ) + + @strawberry_django.filter_field + def filter(self, info: Info, queryset: QuerySet, prefix: str = '') -> Tuple[QuerySet, Q]: + """ + Map GraphQL: { : { contains: } } To Django ORM: __range_contains= + """ + if self.contains is strawberry.UNSET or self.contains is None: + return queryset, Q() + + # Build 'range_contains' so it works for nested paths too + return queryset, Q(**{f'{prefix}range_contains': self.contains}) + + +@strawberry.input(one_of=True, description='Lookups for an ArrayField(IntegerRangeField). Only one may be set.') +class IntegerRangeArrayLookup(RangeArrayValueLookup[int]): + pass