From 5375c6108e6330580e3d335645dab86b8cebe271 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 17 May 2023 16:08:43 -0700 Subject: [PATCH] 6347 track item move --- netbox/netbox/models/__init__.py | 3 ++ netbox/utilities/counter.py | 27 +++++++++++++--- netbox/utilities/mixins.py | 54 ++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 4 deletions(-) create mode 100644 netbox/utilities/mixins.py diff --git a/netbox/netbox/models/__init__.py b/netbox/netbox/models/__init__.py index c0f679e4f..40a4dda71 100644 --- a/netbox/netbox/models/__init__.py +++ b/netbox/netbox/models/__init__.py @@ -7,6 +7,8 @@ from mptt.models import MPTTModel, TreeForeignKey from netbox.models.features import * from utilities.mptt import TreeManager from utilities.querysets import RestrictedQuerySet +from utilities.mixins import TrackingModelMixin + __all__ = ( 'ChangeLoggedModel', @@ -18,6 +20,7 @@ __all__ = ( class NetBoxFeatureSet( + TrackingModelMixin, ChangeLoggingMixin, CustomFieldsMixin, CustomLinksMixin, diff --git a/netbox/utilities/counter.py b/netbox/utilities/counter.py index c929fc358..ccba80587 100644 --- a/netbox/utilities/counter.py +++ b/netbox/utilities/counter.py @@ -2,6 +2,7 @@ from django.db.models import F from django.db.models.signals import post_delete, post_save, pre_save from .fields import CounterCacheField +from .mixins import TrackingModelMixin counters = {} @@ -21,6 +22,11 @@ class Counter(object): self.child_model = self.foreign_key_field.model self.parent_model = self.foreign_key_field.related_model + # add the field to be tracked for changes incase of update + field_name = f"{self.foreign_key_field.name}_id" + if field_name not in self.child_model.change_tracking_fields: + self.child_model.change_tracking_fields.append(field_name) + self.connect() def validate(self): @@ -29,6 +35,10 @@ class Counter(object): raise TypeError( f"{self.counter_name} should be a CounterCacheField on {self.parent_model}, but is {type(counter_field)}" ) + if not isinstance(parent_model, TrackingModelMixin): + raise TypeError( + f"{self.parent_model} should be derived from TrackingModelMixin" + ) def connect(self): """ @@ -39,14 +49,24 @@ class Counter(object): def post_save_receiver_counter(sender, instance, created, **kwargs): if created: - self.adjust_count(instance, 1) + self.adjust_count(self.parent_id(instance), 1) + return + + # not created so check if field has changed + field_name = f"{self.foreign_key_field.name}_id" + if field_name in instance.tracker.changed: + new_value = getattr(instance, field_name, None) + old_value = instance.tracker.changed[field_name] + if (new_value is not None) and (new_value != old_value): + self.adjust_count(new_value, 1) + self.adjust_count(old_value, -1) post_save.connect( post_save_receiver_counter, sender=self.child_model, weak=False, dispatch_uid=f'{counted_name}_post_save' ) def post_delete_receiver_counter(sender, instance, **kwargs): - self.adjust_count(instance, -1) + self.adjust_count(self.parent_id(instance), -1) post_delete.connect( post_delete_receiver_counter, @@ -63,8 +83,7 @@ class Counter(object): def set_counter_field(self, parent_id, value): return self.parent_model.objects.filter(pk=parent_id).update(**{self.counter_name: value}) - def adjust_count(self, child, amount): - parent_id = self.parent_id(child) + def adjust_count(self, parent_id, amount): return self.set_counter_field(parent_id, F(self.counter_name) + amount) diff --git a/netbox/utilities/mixins.py b/netbox/utilities/mixins.py new file mode 100644 index 000000000..971680894 --- /dev/null +++ b/netbox/utilities/mixins.py @@ -0,0 +1,54 @@ +from django.db.models.query_utils import DeferredAttribute + + +class Tracker(object): + def __init__(self, instance): + self.instance = instance + self.newly_created = False + self.changed = {} + self.tracked_fields = self.instance.change_tracking_fields + + +class TrackingModelMixin(object): + change_tracking_fields = [] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._initialized = True + + @property + def tracker(self): + if hasattr(self._state, "_tracker"): + tracker = self._state._tracker + else: + tracker = self._state._tracker = Tracker(self) + return tracker + + def save(self, *args, **kwargs): + if not self.change_tracking_fields: + return super().save(*args, **kwargs) + + self.tracker.newly_created = self._state.adding + super().save(*args, **kwargs) + if self.tracker.changed: + if update_fields := kwargs.get('update_fields', None): + for field in update_fields: + self.tracker.changed.pop(field, None) + else: + self.tracker.changed = {} + + def __setattr__(self, name, value): + if hasattr(self, "_initialized") and self.change_tracking_fields: + if name in self.tracker.tracked_fields: + if name not in self.tracker.changed: + if name in self.__dict__: + old_value = getattr(self, name) + if value != old_value: + self.tracker.changed[name] = old_value + else: + self.tracker.changed[name] = DeferredAttribute + else: + if value == self.tracker.changed[name]: + self.tracker.changed.pop(name) + + super().__setattr__(name, value)