diff --git a/netbox/dcim/fields.py b/netbox/dcim/fields.py index 23b9b7fd5..21af2ed14 100644 --- a/netbox/dcim/fields.py +++ b/netbox/dcim/fields.py @@ -1,11 +1,11 @@ from django.contrib.postgres.fields import ArrayField -from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.core.exceptions import ValidationError from django.core.validators import MinValueValidator, MaxValueValidator from django.db import models from netaddr import AddrFormatError, EUI, mac_unix_expanded from ipam.constants import BGP_ASN_MAX, BGP_ASN_MIN +from .lookups import PathContains class ASNField(models.BigIntegerField): @@ -61,3 +61,6 @@ class PathField(ArrayField): def __init__(self, **kwargs): kwargs['base_field'] = models.CharField(max_length=40) super().__init__(**kwargs) + + +PathField.register_lookup(PathContains) diff --git a/netbox/dcim/lookups.py b/netbox/dcim/lookups.py new file mode 100644 index 000000000..03acc478a --- /dev/null +++ b/netbox/dcim/lookups.py @@ -0,0 +1,10 @@ +from django.contrib.postgres.fields.array import ArrayContains + +from dcim.utils import object_to_path_node + + +class PathContains(ArrayContains): + + def get_prep_lookup(self): + self.rhs = [object_to_path_node(self.rhs)] + return super().get_prep_lookup() diff --git a/netbox/dcim/signals.py b/netbox/dcim/signals.py index 6079417c1..ee006c9d7 100644 --- a/netbox/dcim/signals.py +++ b/netbox/dcim/signals.py @@ -7,7 +7,7 @@ from django.dispatch import receiver from .choices import CableStatusChoices from .models import Cable, CablePath, Device, PathEndpoint, VirtualChassis -from .utils import object_to_path_node, trace_path +from .utils import trace_path def create_cablepath(node): @@ -24,8 +24,7 @@ def rebuild_paths(obj): """ Rebuild all CablePaths which traverse the specified node """ - node = object_to_path_node(obj) - cable_paths = CablePath.objects.filter(path__contains=[node]) + cable_paths = CablePath.objects.filter(path__contains=obj) with transaction.atomic(): for cp in cable_paths: @@ -86,7 +85,7 @@ def update_connected_endpoints(instance, created, **kwargs): # may change in the future.) However, we do need to capture status changes and update # any CablePaths accordingly. if instance.status != CableStatusChoices.STATUS_CONNECTED: - CablePath.objects.filter(path__contains=[object_to_path_node(instance)]).update(is_active=False) + CablePath.objects.filter(path__contains=instance).update(is_active=False) else: rebuild_paths(instance) @@ -109,7 +108,7 @@ def nullify_connected_endpoints(instance, **kwargs): instance.termination_b.save() # Delete and retrace any dependent cable paths - for cablepath in CablePath.objects.filter(path__contains=[object_to_path_node(instance)]): + for cablepath in CablePath.objects.filter(path__contains=instance): path, destination, is_active = trace_path(cablepath.origin) if path: CablePath.objects.filter(pk=cablepath.pk).update( diff --git a/netbox/dcim/tests/test_cablepaths.py b/netbox/dcim/tests/test_cablepaths.py index 65de412cf..cfe63929d 100644 --- a/netbox/dcim/tests/test_cablepaths.py +++ b/netbox/dcim/tests/test_cablepaths.py @@ -4,7 +4,7 @@ from django.test import TestCase from circuits.models import * from dcim.choices import CableStatusChoices from dcim.models import * -from dcim.utils import objects_to_path +from dcim.utils import object_to_path_node class CablePathTestCase(TestCase): @@ -146,7 +146,7 @@ class CablePathTestCase(TestCase): kwargs['destination_type__isnull'] = True kwargs['destination_id__isnull'] = True if path is not None: - kwargs['path'] = objects_to_path(*path) + kwargs['path'] = [object_to_path_node(obj) for obj in path] if is_active is not None: kwargs['is_active'] = is_active if msg is None: diff --git a/netbox/dcim/utils.py b/netbox/dcim/utils.py index 186ea72e5..d36cb1ad3 100644 --- a/netbox/dcim/utils.py +++ b/netbox/dcim/utils.py @@ -8,10 +8,6 @@ def object_to_path_node(obj): return f'{obj._meta.model_name}:{obj.pk}' -def objects_to_path(*obj_list): - return [object_to_path_node(obj) for obj in obj_list] - - def path_node_to_object(repr): model_name, object_id = repr.split(':') model_class = ContentType.objects.get(model=model_name).model_class() diff --git a/netbox/dcim/views.py b/netbox/dcim/views.py index 3608b3792..63711a863 100644 --- a/netbox/dcim/views.py +++ b/netbox/dcim/views.py @@ -38,7 +38,6 @@ from .models import ( PowerPort, PowerPortTemplate, Rack, RackGroup, RackReservation, RackRole, RearPort, RearPortTemplate, Region, Site, VirtualChassis, ) -from .utils import object_to_path_node class BulkDisconnectView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): @@ -1974,9 +1973,7 @@ class PathTraceView(ObjectView): path = obj._path # Otherwise, find all CablePaths which traverse the specified object else: - related_paths = CablePath.objects.filter( - path__contains=[object_to_path_node(obj)] - ).prefetch_related('origin') + related_paths = CablePath.objects.filter(path__contains=obj).prefetch_related('origin') # Check for specification of a particular path (when tracing pass-through ports) try: path_id = int(request.GET.get('cablepath_id'))