diff --git a/netbox/ipam/filtersets.py b/netbox/ipam/filtersets.py index 2e9f56bbc..9e4a84eba 100644 --- a/netbox/ipam/filtersets.py +++ b/netbox/ipam/filtersets.py @@ -16,6 +16,7 @@ from virtualization.models import VirtualMachine, VMInterface from .choices import * from .models import * +from rest_framework import serializers __all__ = ( 'AggregateFilterSet', @@ -599,8 +600,24 @@ class IPAddressFilterSet(NetBoxModelFilterSet, TenancyFilterSet): return queryset.none() return queryset.filter(q) + def parse_inet_addresses(self, value): + try: + parsed = [] + for addr in value: + if netaddr.valid_ipv4(addr) or netaddr.valid_ipv6(addr): + parsed.append(addr) + continue + network = netaddr.IPNetwork(addr) + parsed.append(str(network)) + return parsed + except (AddrFormatError, ValueError): + raise serializers.ValidationError({ + 'address': f'Invalid address {addr}. It must be a valid IPv4 or IPv6 address or network' + }) + def filter_address(self, queryset, name, value): try: + value = self.parse_inet_addresses(value) return queryset.filter(address__net_in=value) except ValidationError: return queryset.none() diff --git a/netbox/ipam/tests/test_filtersets.py b/netbox/ipam/tests/test_filtersets.py index 13b3ae163..e4c4abd0b 100644 --- a/netbox/ipam/tests/test_filtersets.py +++ b/netbox/ipam/tests/test_filtersets.py @@ -10,6 +10,7 @@ from ipam.models import * from utilities.testing import ChangeLoggedFilterSetTests, create_test_device, create_test_virtualmachine from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface from tenancy.models import Tenant, TenantGroup +from rest_framework import serializers class ASNTestCase(TestCase, ChangeLoggedFilterSetTests): @@ -851,6 +852,32 @@ class IPAddressTestCase(TestCase, ChangeLoggedFilterSetTests): params = {'address': ['2001:db8::1/64', '2001:db8::1/65']} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + # Check for valid edge cases. Note that Postgres inet type + # only accepts netmasks in the int form, so the filterset + # casts netmasks in the xxx.xxx.xxx.xxx format. + params = {'address': ['24']} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 0) + params = {'address': ['10.0.0.1/255.255.255.0']} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) + params = {'address': ['10.0.0.1/255.255.255.0', '10.0.0.1/25']} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + # Check for invalid input. + params = {'address': ['/24']} + with self.assertRaises(serializers.ValidationError) as cm: + self.filterset(params, self.queryset).qs.count() + self.assertRegex(cm.exception.detail['address'], r'^Invalid address.*') + + params = {'address': ['10.0.0.1/255.255.555.0']} + with self.assertRaises(serializers.ValidationError) as cm: + self.filterset(params, self.queryset).qs.count() + self.assertRegex(cm.exception.detail['address'], r'^Invalid address.*') + + params = {'address': ['10.0.0.1', '/24']} + with self.assertRaises(serializers.ValidationError) as cm: + self.filterset(params, self.queryset).qs.count() + self.assertRegex(cm.exception.detail['address'], r'^Invalid address.*') + def test_mask_length(self): params = {'mask_length': '24'} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 5)