Compare commits

..

2 Commits

Author          SHA1        Message                            Date
Jeremy Stretch  e69fc9a4b4  Pass user object to EventContext   2026-01-23 15:22:52 -05:00
Jeremy Stretch  a6c6a58fb9  Initial work on #21260             2026-01-23 14:55:58 -05:00
6 changed files with 89 additions and 146 deletions

View File

@@ -1,32 +0,0 @@
-from django.db import migrations
-
-import mptt.managers
-import mptt.models
-
-
-def rebuild_mptt(apps, schema_editor):
-    """
-    Rebuild the MPTT tree for ModuleBay to apply new ordering.
-    """
-    ModuleBay = apps.get_model('dcim', 'ModuleBay')
-
-    # Set MPTTMeta with the correct order_insertion_by
-    class MPTTMeta:
-        order_insertion_by = ('module', 'name',)
-
-    ModuleBay.MPTTMeta = MPTTMeta
-    ModuleBay._mptt_meta = mptt.models.MPTTOptions(MPTTMeta)
-
-    manager = mptt.managers.TreeManager()
-    manager.model = ModuleBay
-    manager.contribute_to_class(ModuleBay, 'objects')
-    manager.rebuild()
-
-
-class Migration(migrations.Migration):
-    dependencies = [
-        ('dcim', '0225_gfk_indexes'),
-    ]
-
-    operations = [
-        migrations.RunPython(code=rebuild_mptt, reverse_code=migrations.RunPython.noop),
-    ]
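Note: this data migration is removed outright (-1,32 +0,0) rather than replaced. Per its docstring, it existed only to re-sort existing ModuleBay trees for the ('module', 'name') sibling ordering, and the next file reverts that ordering to ('module',), which presumably leaves nothing for it to rebuild.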

View File

@@ -1273,7 +1273,7 @@ class ModuleBay(ModularComponentModel, TrackingModelMixin, MPTTModel):
         verbose_name_plural = _('module bays')

     class MPTTMeta:
-        order_insertion_by = ('module', 'name',)
+        order_insertion_by = ('module',)

     def clean(self):
         super().clean()
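For context, order_insertion_by is django-mptt's declarative sibling ordering: nodes are kept in sorted position among their siblings as they are inserted. A minimal illustration of the mechanism, using a hypothetical Category model that is not part of this changeset:

    from django.db import models
    from mptt.models import MPTTModel, TreeForeignKey


    class Category(MPTTModel):
        name = models.CharField(max_length=100)
        parent = TreeForeignKey('self', null=True, blank=True, related_name='children', on_delete=models.CASCADE)

        class MPTTMeta:
            # Siblings are inserted in sorted position by these fields;
            # changing the list affects only future inserts and does not
            # re-sort existing rows (hence rebuild migrations like the one
            # deleted above).
            order_insertion_by = ['name']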

View File

@@ -5,7 +5,6 @@ from django.db import models
 from django.db.models.signals import post_save
 from django.utils.translation import gettext_lazy as _
 from jsonschema.exceptions import ValidationError as JSONValidationError
-from mptt.models import MPTTModel

 from dcim.choices import *
 from dcim.utils import update_interface_bridges
@@ -330,7 +329,7 @@ class Module(TrackingModelMixin, PrimaryModel, ConfigContextModel):
                 component._location = self.device.location
                 component._rack = self.device.rack

-            if not issubclass(component_model, MPTTModel):
+            if component_model is not ModuleBay:
                 component_model.objects.bulk_create(create_instances)
                 # Emit the post_save signal for each newly created object
                 for component in create_instances:
@@ -343,12 +342,11 @@ class Module(TrackingModelMixin, PrimaryModel, ConfigContextModel):
                         update_fields=None
                     )
             else:
-                # MPTT models must be saved individually to maintain tree structure
+                # ModuleBays must be saved individually for MPTT
                 for instance in create_instances:
                     instance.save()

         update_fields = ['module']
         component_model.objects.bulk_update(update_instances, update_fields)

         # Emit the post_save signal for each updated object
         for component in update_instances:
@@ -361,9 +359,5 @@ class Module(TrackingModelMixin, PrimaryModel, ConfigContextModel):
                 update_fields=update_fields
             )

-        # Rebuild MPTT tree if needed (bulk_update bypasses model save)
-        if issubclass(component_model, MPTTModel) and update_instances:
-            component_model.objects.rebuild()
-
         # Interface bridges have to be set after interface instantiation
         update_interface_bridges(self.device, self.module_type.interfacetemplates, self)
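The branch above exists because QuerySet.bulk_create() performs a single INSERT without calling each instance's save() or sending post_save, while django-mptt relies on save() to maintain the lft/rght/tree_id columns; ModuleBay is evidently the only component model needing the slow path, which is why the generic issubclass check could be narrowed. A rough sketch of the trade-off, with hypothetical PlainComponent/TreeComponent models:

    # Hypothetical models, for illustration only
    plain = [PlainComponent(name=f'c{i}') for i in range(3)]
    PlainComponent.objects.bulk_create(plain)   # one INSERT; save() and post_save are skipped

    nodes = [TreeComponent(name=f'n{i}') for i in range(3)]
    for node in nodes:
        node.save()   # per-row save lets MPTT assign lft/rght/tree_id

This is also why the diff emits post_save manually after bulk_create(): downstream listeners (change logging, event rules) would otherwise never see the new objects.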

View File

@@ -1,5 +1,5 @@
 import logging
-from collections import defaultdict
+from collections import UserDict, defaultdict

 from django.conf import settings
 from django.utils import timezone
@@ -12,7 +12,6 @@ from core.models import ObjectType
 from netbox.config import get_config
 from netbox.constants import RQ_QUEUE_DEFAULT
 from netbox.models.features import has_feature
-from users.models import User
 from utilities.api import get_serializer_for_model
 from utilities.request import copy_safe_request
 from utilities.rqworker import get_rq_retry
@@ -23,6 +22,19 @@ from .models import EventRule
 logger = logging.getLogger('netbox.events_processor')


+class EventContext(UserDict):
+    """
+    A custom dictionary that automatically serializes its associated object on demand.
+    """
+    def __getitem__(self, item):
+        if item == 'data' and 'data' not in self:
+            data = serialize_for_event(self['object'])
+            self.__setitem__('data', data)
+            return data
+        return super().__getitem__(item)
+
+
 def serialize_for_event(instance):
     """
     Return a serialized representation of the given instance suitable for use in a queued event.
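The effect of the UserDict subclass is lazy, cached serialization: 'data' is computed only if something actually reads it (the next hunk still serializes deleted objects eagerly in enqueue_event, since they cannot be serialized after deletion). A usage sketch, where instance and OBJECT_UPDATED stand in for the real values:

    event = EventContext(object=instance, event_type=OBJECT_UPDATED)

    'data' in event           # False: nothing has been serialized yet
    payload = event['data']   # first read calls serialize_for_event(instance)
    'data' in event           # True: the result is cached for subsequent reads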
@@ -66,37 +78,42 @@ def enqueue_event(queue, instance, request, event_type):
     assert instance.pk is not None

     key = f'{app_label}.{model_name}:{instance.pk}'
     if key in queue:
-        queue[key]['data'] = serialize_for_event(instance)
         queue[key]['snapshots']['postchange'] = get_snapshots(instance, event_type)['postchange']
         # If the object is being deleted, update any prior "update" event to "delete"
         if event_type == OBJECT_DELETED:
             queue[key]['event_type'] = event_type
     else:
-        queue[key] = {
-            'object_type': ObjectType.objects.get_for_model(instance),
-            'object_id': instance.pk,
-            'event_type': event_type,
-            'data': serialize_for_event(instance),
-            'snapshots': get_snapshots(instance, event_type),
-            'request': request,
-            # Legacy request attributes for backward compatibility
-            'username': request.user.username,
-            'request_id': request.id,
-        }
+        queue[key] = EventContext(
+            object_type=ObjectType.objects.get_for_model(instance),
+            object_id=instance.pk,
+            object=instance,
+            event_type=event_type,
+            snapshots=get_snapshots(instance, event_type),
+            request=request,
+            user=request.user,
+            # Legacy request attributes for backward compatibility
+            username=request.user.username,
+            request_id=request.id,
+        )
+
+    # Force serialization of objects prior to them actually being deleted
+    if event_type == OBJECT_DELETED:
+        queue[key]['data'] = serialize_for_event(instance)


-def process_event_rules(event_rules, object_type, event_type, data, username=None, snapshots=None, request=None):
-    user = None  # To be resolved from the username if needed
+def process_event_rules(event_rules, object_type, event):
+    """
+    Process a list of EventRules against an event.
+    """
     for event_rule in event_rules:

         # Evaluate event rule conditions (if any)
-        if not event_rule.eval_conditions(data):
+        if not event_rule.eval_conditions(event['data']):
             continue

         # Compile event data
         event_data = event_rule.action_data or {}
-        event_data.update(data)
+        event_data.update(event['data'])

         # Webhooks
         if event_rule.action_type == EventRuleActionChoices.WEBHOOK:
@@ -109,50 +126,43 @@ def process_event_rules(event_rules, object_type, event_type, data, username=None, snapshots=None, request=None):
             params = {
                 "event_rule": event_rule,
                 "object_type": object_type,
-                "event_type": event_type,
+                "event_type": event['event_type'],
                 "data": event_data,
-                "snapshots": snapshots,
+                "snapshots": event['snapshots'],
                 "timestamp": timezone.now().isoformat(),
-                "username": username,
+                "username": event['username'],
                 "retry": get_rq_retry()
             }
-            if snapshots:
-                params["snapshots"] = snapshots
-            if request:
+            if 'snapshots' in event:
+                params['snapshots'] = event['snapshots']
+            if 'request' in event:
                 # Exclude FILES - webhooks don't need uploaded files,
                 # which can cause pickle errors with Pillow.
-                params["request"] = copy_safe_request(request, include_files=False)
+                params['request'] = copy_safe_request(event['request'], include_files=False)

             # Enqueue the task
-            rq_queue.enqueue(
-                "extras.webhooks.send_webhook",
-                **params
-            )
+            rq_queue.enqueue('extras.webhooks.send_webhook', **params)

         # Scripts
         elif event_rule.action_type == EventRuleActionChoices.SCRIPT:

             # Resolve the script from action parameters
             script = event_rule.action_object.python_class()

-            # Retrieve the User if not already resolved
-            if user is None:
-                user = User.objects.get(username=username)
             # Enqueue a Job to record the script's execution
             from extras.jobs import ScriptJob
             params = {
                 "instance": event_rule.action_object,
                 "name": script.name,
-                "user": user,
+                "user": event['user'],
                 "data": event_data
             }
-            if snapshots:
-                params["snapshots"] = snapshots
-            if request:
-                params["request"] = copy_safe_request(request)
+            if 'snapshots' in event:
+                params['snapshots'] = event['snapshots']
+            if 'request' in event:
+                params['request'] = copy_safe_request(event['request'])

-            ScriptJob.enqueue(
-                **params
-            )
+            # Enqueue the job
+            ScriptJob.enqueue(**params)

         # Notification groups
         elif event_rule.action_type == EventRuleActionChoices.NOTIFICATION:
@@ -161,7 +171,7 @@ def process_event_rules(event_rules, object_type, event_type, data, username=None, snapshots=None, request=None):
                 object_type=object_type,
                 object_id=event_data['id'],
                 object_repr=event_data.get('display'),
-                event_type=event_type
+                event_type=event['event_type']
             )

         else:
@@ -173,6 +183,8 @@ def process_event_rules(event_rules, object_type, event_type, data, username=None, snapshots=None, request=None):
 def process_event_queue(events):
     """
     Flush a list of object representation to RQ for EventRule processing.
+
+    This is the default processor listed in EVENTS_PIPELINE.
     """
     events_cache = defaultdict(dict)
@@ -192,11 +204,7 @@ def process_event_queue(events):
             process_event_rules(
                 event_rules=event_rules,
                 object_type=object_type,
-                event_type=event['event_type'],
-                data=event['data'],
-                username=event['username'],
-                snapshots=event['snapshots'],
-                request=event['request'],
+                event=event,
             )
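Note: taken together, these changes collapse the five loose keyword arguments of process_event_rules() into a single EventContext mapping. The second commit also carries the User object directly on the event (user=request.user) instead of re-fetching it with User.objects.get(username=...) inside the script branch; 'username' and 'request_id' survive only as the legacy compatibility keys noted in the diff.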

View File

@@ -4,7 +4,7 @@ from django.dispatch import receiver
 from core.events import *
 from core.signals import job_end, job_start
-from extras.events import process_event_rules
+from extras.events import EventContext, process_event_rules
 from extras.models import EventRule, Notification, Subscription
 from netbox.config import get_config
 from netbox.models.features import has_feature
@@ -102,14 +102,12 @@ def process_job_start_event_rules(sender, **kwargs):
         enabled=True,
         object_types=sender.object_type
     )
-    username = sender.user.username if sender.user else None
-    process_event_rules(
-        event_rules=event_rules,
-        object_type=sender.object_type,
-        event_type=JOB_STARTED,
-        data=sender.data,
-        username=username
-    )
+    event = EventContext(
+        event_type=JOB_STARTED,
+        data=sender.data,
+        user=sender.user,
+    )
+    process_event_rules(event_rules, sender.object_type, event)


 @receiver(job_end)
@@ -122,14 +120,12 @@ def process_job_end_event_rules(sender, **kwargs):
         enabled=True,
         object_types=sender.object_type
     )
-    username = sender.user.username if sender.user else None
-    process_event_rules(
-        event_rules=event_rules,
-        object_type=sender.object_type,
-        event_type=JOB_COMPLETED,
-        data=sender.data,
-        username=username
-    )
+    event = EventContext(
+        event_type=JOB_COMPLETED,
+        data=sender.data,
+        user=sender.user,
+    )
+    process_event_rules(event_rules, sender.object_type, event)
# #
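One subtlety at these two call sites: because 'data' is passed to EventContext up front, the lazy branch in __getitem__ never runs, so no 'object' key is needed. A small sketch of that behavior:

    # Sketch: with 'data' supplied eagerly, no 'object' key is required
    event = EventContext(event_type=JOB_STARTED, data=sender.data, user=sender.user)
    assert event['data'] is sender.data   # returned directly; lazy serialization skipped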

View File

@@ -438,12 +438,30 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
         """
         return object_form.save()

-    def _process_import_records(self, form, request, records, prefetched_objects):
+    def create_and_update_objects(self, form, request):
+        """
+        Process CSV import records and save objects.
+        """
         saved_objects = []
+        records = list(form.cleaned_data['data'])
+
+        # Prefetch objects to be updated, if any
+        prefetch_ids = [int(record['id']) for record in records if record.get('id')]
+
+        # check for duplicate IDs
+        duplicate_pks = [pk for pk, count in Counter(prefetch_ids).items() if count > 1]
+        if duplicate_pks:
+            error_msg = _(
+                "Duplicate objects found: {model} with ID(s) {ids} appears multiple times"
+            ).format(
+                model=title(self.queryset.model._meta.verbose_name),
+                ids=', '.join(str(pk) for pk in sorted(duplicate_pks))
+            )
+            raise ValidationError(error_msg)
+
+        prefetched_objects = {
+            obj.pk: obj
+            for obj in self.queryset.model.objects.filter(id__in=prefetch_ids)
+        } if prefetch_ids else {}

         for i, record in enumerate(records, start=1):
             object_id = int(record.pop('id')) if record.get('id') else None
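The duplicate-ID guard is the standard Counter idiom for finding repeated elements; an isolated example of what it computes:

    from collections import Counter

    prefetch_ids = [104, 7, 104, 33, 7, 104]
    duplicate_pks = [pk for pk, count in Counter(prefetch_ids).items() if count > 1]
    # duplicate_pks == [104, 7] -> the import is aborted with a ValidationError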
@@ -508,38 +526,6 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
         return saved_objects

-    def create_and_update_objects(self, form, request):
-        records = list(form.cleaned_data['data'])
-
-        # Prefetch objects to be updated, if any
-        prefetch_ids = [int(record['id']) for record in records if record.get('id')]
-
-        # check for duplicate IDs
-        duplicate_pks = [pk for pk, count in Counter(prefetch_ids).items() if count > 1]
-        if duplicate_pks:
-            error_msg = _(
-                "Duplicate objects found: {model} with ID(s) {ids} appears multiple times"
-            ).format(
-                model=title(self.queryset.model._meta.verbose_name),
-                ids=', '.join(str(pk) for pk in sorted(duplicate_pks))
-            )
-            raise ValidationError(error_msg)
-
-        prefetched_objects = {
-            obj.pk: obj
-            for obj in self.queryset.model.objects.filter(id__in=prefetch_ids)
-        } if prefetch_ids else {}
-
-        # For MPTT models, delay tree updates until all saves are complete
-        if issubclass(self.queryset.model, MPTTModel):
-            with self.queryset.model.objects.delay_mptt_updates():
-                saved_objects = self._process_import_records(form, request, records, prefetched_objects)
-            self.queryset.model.objects.rebuild()
-        else:
-            saved_objects = self._process_import_records(form, request, records, prefetched_objects)
-
-        return saved_objects
-

 #
 # Request handlers
 #
@@ -909,18 +895,9 @@ class BulkRenameView(GetReturnURLMixin, BaseMultiObjectView):
             renamed_pks = self._rename_objects(form, selected_objects)

             if '_apply' in request.POST:
-                # For MPTT models, delay tree updates until all saves are complete
-                if issubclass(self.queryset.model, MPTTModel):
-                    with self.queryset.model.objects.delay_mptt_updates():
-                        for obj in selected_objects:
-                            setattr(obj, self.field_name, obj.new_name)
-                            obj.save()
-                    self.queryset.model.objects.rebuild()
-                else:
-                    for obj in selected_objects:
-                        setattr(obj, self.field_name, obj.new_name)
-                        obj.save()
+                for obj in selected_objects:
+                    setattr(obj, self.field_name, obj.new_name)
+                    obj.save()

                 # Enforce constrained permissions
                 if self.queryset.filter(pk__in=renamed_pks).count() != len(selected_objects):
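Note: with the MPTT special-casing removed from BulkImportView and BulkRenameView, each obj.save() on a tree model still updates lft/rght/tree_id immediately through django-mptt's own save(); what is dropped is only the delay_mptt_updates()/rebuild() batching, which served as an optimization for large batches rather than a correctness requirement.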