Move change_tracking_fields from TrackingModelMixin to the registry

This commit is contained in:
Jeremy Stretch 2023-07-24 10:24:09 -04:00
parent 4156fb31b8
commit 0ee05a6810
4 changed files with 16 additions and 12 deletions

View File

@ -8,6 +8,10 @@ The registry can be inspected by importing `registry` from `extras.registry`.
## Stores ## Stores
### `counter_fields`
A dictionary mapping of models to foreign keys with which cached counter fields are associated.
### `data_backends` ### `data_backends`
A dictionary mapping data backend types to their respective classes. These are used to interact with [remote data sources](../models/core/datasource.md). A dictionary mapping data backend types to their respective classes. These are used to interact with [remote data sources](../models/core/datasource.md).

View File

@ -21,6 +21,7 @@ class Registry(dict):
# Initialize the global registry # Initialize the global registry
registry = Registry({ registry = Registry({
'counter_fields': collections.defaultdict(set),
'data_backends': dict(), 'data_backends': dict(),
'denormalized_fields': collections.defaultdict(list), 'denormalized_fields': collections.defaultdict(list),
'model_features': dict(), 'model_features': dict(),

View File

@ -1,10 +1,10 @@
from django.apps import apps from django.apps import apps
from django.db.models import F from django.db.models import F
from django.db.models.signals import post_delete, post_save, pre_save from django.db.models.signals import post_delete, post_save
from functools import partial from functools import partial
from netbox.registry import registry
from .fields import CounterCacheField from .fields import CounterCacheField
from .mixins import TrackingModelMixin
def post_save_receiver_counter(counter_instance, sender, instance, created, **kwargs): def post_save_receiver_counter(counter_instance, sender, instance, created, **kwargs):
@ -41,10 +41,9 @@ class Counter:
self.child_model = self.foreign_key_field.model self.child_model = self.foreign_key_field.model
self.parent_model = self.foreign_key_field.related_model self.parent_model = self.foreign_key_field.related_model
# add the field to be tracked for changes incase of update # add the field to be tracked for changes in case of update
field_name = f"{self.foreign_key_field.name}_id" change_tracking_fields = registry['counter_fields'][self.child_model]
if hasattr(self.child_model, 'change_tracking_fields') and field_name not in self.child_model.change_tracking_fields: change_tracking_fields.add(f"{self.foreign_key_field.name}_id")
self.child_model.change_tracking_fields.append(field_name)
self.connect() self.connect()

View File

@ -1,5 +1,7 @@
from django.db.models.query_utils import DeferredAttribute from django.db.models.query_utils import DeferredAttribute
from netbox.registry import registry
class Tracker: class Tracker:
def __init__(self, instance): def __init__(self, instance):
@ -8,7 +10,6 @@ class Tracker:
class TrackingModelMixin: class TrackingModelMixin:
change_tracking_fields = []
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -21,10 +22,8 @@ class TrackingModelMixin:
return self._state._tracker return self._state._tracker
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
if not self.change_tracking_fields:
return super().save(*args, **kwargs)
super().save(*args, **kwargs) super().save(*args, **kwargs)
if self.tracker.changed: if self.tracker.changed:
if update_fields := kwargs.get('update_fields', None): if update_fields := kwargs.get('update_fields', None):
for field in update_fields: for field in update_fields:
@ -33,8 +32,9 @@ class TrackingModelMixin:
self.tracker.changed = {} self.tracker.changed = {}
def __setattr__(self, name, value): def __setattr__(self, name, value):
if hasattr(self, "_initialized") and self.change_tracking_fields: if hasattr(self, "_initialized"):
if name in self.tracker.instance.change_tracking_fields: change_tracking_fields = registry['counter_fields'][self.__class__]
if name in change_tracking_fields:
if name not in self.tracker.changed: if name not in self.tracker.changed:
if name in self.__dict__: if name in self.__dict__:
old_value = getattr(self, name) old_value = getattr(self, name)