diff --git a/netbox/utilities/counter.py b/netbox/utilities/counter.py index 8cb555c71..47b47feec 100644 --- a/netbox/utilities/counter.py +++ b/netbox/utilities/counter.py @@ -14,9 +14,9 @@ def post_save_receiver_counter(counter_instance, sender, instance, created, **kw # not created so check if field has changed field_name = f"{counter_instance.foreign_key_field.name}_id" - if field_name in instance.tracker.changed: + if field_name in instance.tracker: new_value = getattr(instance, field_name, None) - old_value = instance.tracker.changed[field_name] + old_value = instance.tracker.get(field_name) if (new_value is not None) and (new_value != old_value): counter_instance.adjust_count(new_value, 1) counter_instance.adjust_count(old_value, -1) diff --git a/netbox/utilities/mixins.py b/netbox/utilities/mixins.py index b56055a5e..88945615b 100644 --- a/netbox/utilities/mixins.py +++ b/netbox/utilities/mixins.py @@ -4,46 +4,75 @@ from netbox.registry import registry class Tracker: - def __init__(self, instance): - self.instance = instance - self.changed = {} + """ + An ephemeral instance employed to record which tracked fields on an instance have been modified. + """ + def __init__(self): + self._changed_fields = {} + + def __contains__(self, item): + return item in self._changed_fields + + def set(self, name, value): + """ + Mark an attribute as having been changed and record its original value. + """ + self._changed_fields[name] = value + + def get(self, name): + """ + Return the original value of a changed field. Raises KeyError if name is not found. + """ + return self._changed_fields[name] + + def clear(self, *names): + """ + Clear any fields that were recorded as having been changed. + """ + for name in names: + self._changed_fields.pop(name, None) + else: + self._changed_fields = {} class TrackingModelMixin: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + + # Mark the instance as initialized, to enable our custom __setattr__() self._initialized = True @property def tracker(self): + """ + Return the Tracker instance for this instance, first creating it if necessary. + """ if not hasattr(self._state, "_tracker"): - self._state._tracker = Tracker(self) + self._state._tracker = Tracker() return self._state._tracker def save(self, *args, **kwargs): super().save(*args, **kwargs) - if self.tracker.changed: - if update_fields := kwargs.get('update_fields', None): - for field in update_fields: - self.tracker.changed.pop(field, None) - else: - self.tracker.changed = {} + # Clear any tracked fields now that changes have been saved + update_fields = kwargs.get('update_fields', []) + self.tracker.clear(*update_fields) def __setattr__(self, name, value): if hasattr(self, "_initialized"): - change_tracking_fields = registry['counter_fields'][self.__class__] - if name in change_tracking_fields: - if name not in self.tracker.changed: + # Record any changes to a tracked field + if name in registry['counter_fields'][self.__class__]: + if name not in self.tracker: + # The attribute has been created or changed if name in self.__dict__: old_value = getattr(self, name) if value != old_value: - self.tracker.changed[name] = old_value + self.tracker.set(name, old_value) else: - self.tracker.changed[name] = DeferredAttribute - else: - if value == self.tracker.changed[name]: - self.tracker.changed.pop(name) + self.tracker.set(name, DeferredAttribute) + elif value == self.tracker.get(name): + # A previously changed attribute has been restored + self.tracker.clear(name) super().__setattr__(name, value)