Create API for Tracker; add comments

This commit is contained in:
Jeremy Stretch 2023-07-24 13:18:05 -04:00
parent c2ef15607e
commit 073ad5329a
2 changed files with 49 additions and 20 deletions

View File

@ -14,9 +14,9 @@ def post_save_receiver_counter(counter_instance, sender, instance, created, **kw
# not created so check if field has changed # not created so check if field has changed
field_name = f"{counter_instance.foreign_key_field.name}_id" field_name = f"{counter_instance.foreign_key_field.name}_id"
if field_name in instance.tracker.changed: if field_name in instance.tracker:
new_value = getattr(instance, field_name, None) new_value = getattr(instance, field_name, None)
old_value = instance.tracker.changed[field_name] old_value = instance.tracker.get(field_name)
if (new_value is not None) and (new_value != old_value): if (new_value is not None) and (new_value != old_value):
counter_instance.adjust_count(new_value, 1) counter_instance.adjust_count(new_value, 1)
counter_instance.adjust_count(old_value, -1) counter_instance.adjust_count(old_value, -1)

View File

@ -4,46 +4,75 @@ from netbox.registry import registry
class Tracker: class Tracker:
def __init__(self, instance): """
self.instance = instance An ephemeral instance employed to record which tracked fields on an instance have been modified.
self.changed = {} """
def __init__(self):
self._changed_fields = {}
def __contains__(self, item):
return item in self._changed_fields
def set(self, name, value):
"""
Mark an attribute as having been changed and record its original value.
"""
self._changed_fields[name] = value
def get(self, name):
"""
Return the original value of a changed field. Raises KeyError if name is not found.
"""
return self._changed_fields[name]
def clear(self, *names):
"""
Clear any fields that were recorded as having been changed.
"""
for name in names:
self._changed_fields.pop(name, None)
else:
self._changed_fields = {}
class TrackingModelMixin: class TrackingModelMixin:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Mark the instance as initialized, to enable our custom __setattr__()
self._initialized = True self._initialized = True
@property @property
def tracker(self): def tracker(self):
"""
Return the Tracker instance for this instance, first creating it if necessary.
"""
if not hasattr(self._state, "_tracker"): if not hasattr(self._state, "_tracker"):
self._state._tracker = Tracker(self) self._state._tracker = Tracker()
return self._state._tracker return self._state._tracker
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
super().save(*args, **kwargs) super().save(*args, **kwargs)
if self.tracker.changed: # Clear any tracked fields now that changes have been saved
if update_fields := kwargs.get('update_fields', None): update_fields = kwargs.get('update_fields', [])
for field in update_fields: self.tracker.clear(*update_fields)
self.tracker.changed.pop(field, None)
else:
self.tracker.changed = {}
def __setattr__(self, name, value): def __setattr__(self, name, value):
if hasattr(self, "_initialized"): if hasattr(self, "_initialized"):
change_tracking_fields = registry['counter_fields'][self.__class__] # Record any changes to a tracked field
if name in change_tracking_fields: if name in registry['counter_fields'][self.__class__]:
if name not in self.tracker.changed: if name not in self.tracker:
# The attribute has been created or changed
if name in self.__dict__: if name in self.__dict__:
old_value = getattr(self, name) old_value = getattr(self, name)
if value != old_value: if value != old_value:
self.tracker.changed[name] = old_value self.tracker.set(name, old_value)
else: else:
self.tracker.changed[name] = DeferredAttribute self.tracker.set(name, DeferredAttribute)
else: elif value == self.tracker.get(name):
if value == self.tracker.changed[name]: # A previously changed attribute has been restored
self.tracker.changed.pop(name) self.tracker.clear(name)
super().__setattr__(name, value) super().__setattr__(name, value)