diff --git a/netbox/dcim/models/cables.py b/netbox/dcim/models/cables.py index 69e07ed94..5adf30022 100644 --- a/netbox/dcim/models/cables.py +++ b/netbox/dcim/models/cables.py @@ -18,6 +18,7 @@ from utilities.conversion import to_meters from utilities.exceptions import AbortRequest from utilities.fields import ColorField, GenericArrayForeignKey from utilities.querysets import RestrictedQuerySet +from utilities.serialization import serialize_object from wireless.models import WirelessLink from .device_components import FrontPort, RearPort, PathEndpoint @@ -119,6 +120,9 @@ class Cable(PrimaryModel): pk = self.pk or self._pk return self.label or f'#{pk}' + def get_status_color(self): + return LinkStatusChoices.colors.get(self.status) + @property def a_terminations(self): if hasattr(self, '_a_terminations'): @@ -208,7 +212,7 @@ class Cable(PrimaryModel): for termination in self.b_terminations: CableTermination(cable=self, cable_end='B', termination=termination).clean() - def save(self, *args, **kwargs): + def save(self, *args, force_insert=False, force_update=False, using=None, update_fields=None): _created = self.pk is None # Store the given length (if any) in meters for use in database ordering @@ -221,39 +225,69 @@ class Cable(PrimaryModel): if self.length is None: self.length_unit = None - super().save(*args, **kwargs) + # If this is a new Cable, save it before attempting to create its CableTerminations + if self._state.adding: + super().save(*args, force_insert=True, using=using, update_fields=update_fields) + # Update the private PK used in __str__() + self._pk = self.pk - # Update the private pk used in __str__ in case this is a new object (i.e. just got its pk) - self._pk = self.pk - - # Retrieve existing A/B terminations for the Cable - a_terminations = {ct.termination: ct for ct in self.terminations.filter(cable_end='A')} - b_terminations = {ct.termination: ct for ct in self.terminations.filter(cable_end='B')} - - # Delete stale CableTerminations if self._terminations_modified: - for termination, ct in a_terminations.items(): - if termination.pk and termination not in self.a_terminations: - ct.delete() - for termination, ct in b_terminations.items(): - if termination.pk and termination not in self.b_terminations: - ct.delete() + self.update_terminations() + + super().save(*args, force_update=True, using=using, update_fields=update_fields) - # Save new CableTerminations (if any) - if self._terminations_modified: - for termination in self.a_terminations: - if not termination.pk or termination not in a_terminations: - CableTermination(cable=self, cable_end='A', termination=termination).save() - for termination in self.b_terminations: - if not termination.pk or termination not in b_terminations: - CableTermination(cable=self, cable_end='B', termination=termination).save() try: trace_paths.send(Cable, instance=self, created=_created) except UnsupportedCablePath as e: raise AbortRequest(e) - def get_status_color(self): - return LinkStatusChoices.colors.get(self.status) + def serialize_object(self, exclude=None): + data = serialize_object(self, exclude=exclude or []) + + # Add A & B terminations to the serialized data + a_terminations, b_terminations = self.get_terminations() + data['a_terminations'] = sorted([ct.pk for ct in a_terminations.values()]) + data['b_terminations'] = sorted([ct.pk for ct in b_terminations.values()]) + + return data + + def get_terminations(self): + """ + Return two dictionaries mapping A & B side terminating objects to their corresponding CableTerminations + for this Cable. + """ + a_terminations = {} + b_terminations = {} + + for ct in CableTermination.objects.filter(cable=self).prefetch_related('termination'): + if ct.cable_end == CableEndChoices.SIDE_A: + a_terminations[ct.termination] = ct + else: + b_terminations[ct.termination] = ct + + return a_terminations, b_terminations + + def update_terminations(self): + """ + Create/delete CableTerminations for this Cable to reflect its current state. + """ + a_terminations, b_terminations = self.get_terminations() + + # Delete any stale CableTerminations + for termination, ct in a_terminations.items(): + if termination.pk and termination not in self.a_terminations: + ct.delete() + for termination, ct in b_terminations.items(): + if termination.pk and termination not in self.b_terminations: + ct.delete() + + # Save any new CableTerminations + for termination in self.a_terminations: + if not termination.pk or termination not in a_terminations: + CableTermination(cable=self, cable_end='A', termination=termination).save() + for termination in self.b_terminations: + if not termination.pk or termination not in b_terminations: + CableTermination(cable=self, cable_end='B', termination=termination).save() class CableTermination(ChangeLoggedModel): diff --git a/netbox/utilities/testing/api.py b/netbox/utilities/testing/api.py index 8df8f4438..1fe881367 100644 --- a/netbox/utilities/testing/api.py +++ b/netbox/utilities/testing/api.py @@ -247,9 +247,9 @@ class APIViewTestCases: if issubclass(self.model, ChangeLoggingMixin): objectchange = ObjectChange.objects.get( changed_object_type=ContentType.objects.get_for_model(instance), - changed_object_id=instance.pk + changed_object_id=instance.pk, + action=ObjectChangeActionChoices.ACTION_CREATE, ) - self.assertEqual(objectchange.action, ObjectChangeActionChoices.ACTION_CREATE) self.assertEqual(objectchange.message, data['changelog_message']) def test_bulk_create_objects(self): @@ -298,11 +298,11 @@ class APIViewTestCases: ] objectchanges = ObjectChange.objects.filter( changed_object_type=ContentType.objects.get_for_model(self.model), - changed_object_id__in=id_list + changed_object_id__in=id_list, + action=ObjectChangeActionChoices.ACTION_CREATE, ) self.assertEqual(len(objectchanges), len(self.create_data)) for oc in objectchanges: - self.assertEqual(oc.action, ObjectChangeActionChoices.ACTION_CREATE) self.assertEqual(oc.message, changelog_message) class UpdateObjectViewTestCase(APITestCase): diff --git a/netbox/utilities/testing/views.py b/netbox/utilities/testing/views.py index da8a87098..99a6dd43a 100644 --- a/netbox/utilities/testing/views.py +++ b/netbox/utilities/testing/views.py @@ -655,11 +655,11 @@ class ViewTestCases: self.assertIsNotNone(request_id, "Unable to determine request ID from response") objectchanges = ObjectChange.objects.filter( changed_object_type=ContentType.objects.get_for_model(self.model), - request_id=request_id + request_id=request_id, + action=ObjectChangeActionChoices.ACTION_CREATE, ) self.assertEqual(len(objectchanges), len(self.csv_data) - 1) for oc in objectchanges: - self.assertEqual(oc.action, ObjectChangeActionChoices.ACTION_CREATE) self.assertEqual(oc.message, data['changelog_message']) @override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])