Catch AssertionError from cable trace and throw ValidationError

This commit is contained in:
Daniel Sheppard 2024-06-02 22:06:34 -05:00
parent e18e6cf756
commit 47f3994f93
2 changed files with 20 additions and 8 deletions

View File

@ -237,7 +237,10 @@ class Cable(PrimaryModel):
if not termination.pk or termination not in b_terminations: if not termination.pk or termination not in b_terminations:
CableTermination(cable=self, cable_end='B', termination=termination).save() CableTermination(cable=self, cable_end='B', termination=termination).save()
trace_paths.send(Cable, instance=self, created=_created) try:
trace_paths.send(Cable, instance=self, created=_created)
except ValidationError as e:
raise ValidationError(f'{e}')
def get_status_color(self): def get_status_color(self):
return LinkStatusChoices.colors.get(self.status) return LinkStatusChoices.colors.get(self.status)
@ -532,7 +535,8 @@ class CablePath(models.Model):
# Ensure all originating terminations are attached to the same link # Ensure all originating terminations are attached to the same link
if len(terminations) > 1: if len(terminations) > 1:
assert all(t.link == terminations[0].link for t in terminations[1:]) assert all(t.link == terminations[0].link for t in terminations[1:]), \
"All originating terminations must start must be attached to the same link"
path = [] path = []
position_stack = [] position_stack = []
@ -543,12 +547,13 @@ class CablePath(models.Model):
while terminations: while terminations:
# Terminations must all be of the same type # Terminations must all be of the same type
assert all(isinstance(t, type(terminations[0])) for t in terminations[1:]) assert all(isinstance(t, type(terminations[0])) for t in terminations[1:]), \
"All mid-span terminations must have the same termination type"
# All mid-span terminations must all be attached to the same device # All mid-span terminations must all be attached to the same device
if not isinstance(terminations[0], PathEndpoint): if not isinstance(terminations[0], PathEndpoint):
assert all(isinstance(t, type(terminations[0])) for t in terminations[1:]) assert all(t.parent_object == terminations[0].parent_object for t in terminations[1:]), \
assert all(t.parent_object == terminations[0].parent_object for t in terminations[1:]) "All mid-span terminations must have the same parent object"
# Check for a split path (e.g. rear port fanning out to multiple front ports with # Check for a split path (e.g. rear port fanning out to multiple front ports with
# different cables attached) # different cables attached)
@ -571,8 +576,8 @@ class CablePath(models.Model):
return None return None
# Otherwise, halt the trace if no link exists # Otherwise, halt the trace if no link exists
break break
assert all(type(link) in (Cable, WirelessLink) for link in links) assert all(type(link) in (Cable, WirelessLink) for link in links), "All links must be cable or wireless"
assert all(isinstance(link, type(links[0])) for link in links) assert all(isinstance(link, type(links[0])) for link in links), "All links must match first link type"
# Step 3: Record asymmetric paths as split # Step 3: Record asymmetric paths as split
not_connected_terminations = [termination.link for termination in terminations if termination.link is None] not_connected_terminations = [termination.link for termination in terminations if termination.link is None]
@ -656,7 +661,7 @@ class CablePath(models.Model):
for rt in remote_terminations: for rt in remote_terminations:
position = positions.pop() position = positions.pop()
q_filter |= Q(rear_port_id=rt.pk, rear_port_position=position) q_filter |= Q(rear_port_id=rt.pk, rear_port_position=position)
assert q_filter is not Q() assert q_filter is not Q(), "Remote termination query filter is empty, please open a bug report"
front_ports = FrontPort.objects.filter(q_filter) front_ports = FrontPort.objects.filter(q_filter)
# Obtain the individual front ports based on the termination and position # Obtain the individual front ports based on the termination and position
elif position_stack: elif position_stack:

View File

@ -3,6 +3,7 @@ from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from django.contrib import messages from django.contrib import messages
from django.core.exceptions import ValidationError
from django.db import router, transaction from django.db import router, transaction
from django.db.models import ProtectedError, RestrictedError from django.db.models import ProtectedError, RestrictedError
from django.db.models.deletion import Collector from django.db.models.deletion import Collector
@ -307,6 +308,12 @@ class ObjectEditView(GetReturnURLMixin, BaseObjectView):
form.add_error(None, e.message) form.add_error(None, e.message)
clear_events.send(sender=self) clear_events.send(sender=self)
# Catch any validation errors thrown in the model.save() or form.save() methods
except ValidationError as e:
logger.debug(e.message)
form.add_error(None, e.message)
clear_events.send(sender=self)
else: else:
logger.debug("Form validation failed") logger.debug("Form validation failed")