Closes #20003: Introduce mechanism to register callbacks for webhook context

This commit is contained in:
Jeremy Stretch 2025-08-05 14:46:36 -04:00
parent 0c70e9e140
commit b24d30261c
8 changed files with 76 additions and 23 deletions

View File

@ -136,7 +136,7 @@ def handle_changed_object(sender, instance, **kwargs):
# Enqueue the object for event processing # Enqueue the object for event processing
queue = events_queue.get() queue = events_queue.get()
enqueue_event(queue, instance, request.user, request.id, event_type) enqueue_event(queue, instance, request, event_type)
events_queue.set(queue) events_queue.set(queue)
# Increment metric counters # Increment metric counters
@ -220,7 +220,7 @@ def handle_deleted_object(sender, instance, **kwargs):
# Enqueue the object for event processing # Enqueue the object for event processing
queue = events_queue.get() queue = events_queue.get()
enqueue_event(queue, instance, request.user, request.id, OBJECT_DELETED) enqueue_event(queue, instance, request, OBJECT_DELETED)
events_queue.set(queue) events_queue.set(queue)
# Increment metric counters # Increment metric counters

View File

@ -14,6 +14,7 @@ from netbox.constants import RQ_QUEUE_DEFAULT
from netbox.models.features import has_feature from netbox.models.features import has_feature
from users.models import User from users.models import User
from utilities.api import get_serializer_for_model from utilities.api import get_serializer_for_model
from utilities.request import copy_safe_request
from utilities.rqworker import get_rq_retry from utilities.rqworker import get_rq_retry
from utilities.serialization import serialize_object from utilities.serialization import serialize_object
from .choices import EventRuleActionChoices from .choices import EventRuleActionChoices
@ -50,7 +51,7 @@ def get_snapshots(instance, event_type):
return snapshots return snapshots
def enqueue_event(queue, instance, user, request_id, event_type): def enqueue_event(queue, instance, request, event_type):
""" """
Enqueue a serialized representation of a created/updated/deleted object for the processing of Enqueue a serialized representation of a created/updated/deleted object for the processing of
events once the request has completed. events once the request has completed.
@ -77,12 +78,14 @@ def enqueue_event(queue, instance, user, request_id, event_type):
'event_type': event_type, 'event_type': event_type,
'data': serialize_for_event(instance), 'data': serialize_for_event(instance),
'snapshots': get_snapshots(instance, event_type), 'snapshots': get_snapshots(instance, event_type),
'username': user.username, 'request': request,
'request_id': request_id # Legacy request attributes for backward compatibility
'username': request.user.username,
'request_id': request.id,
} }
def process_event_rules(event_rules, object_type, event_type, data, username=None, snapshots=None, request_id=None): def process_event_rules(event_rules, object_type, event_type, data, username=None, snapshots=None, request=None):
user = User.objects.get(username=username) if username else None user = User.objects.get(username=username) if username else None
for event_rule in event_rules: for event_rule in event_rules:
@ -105,7 +108,7 @@ def process_event_rules(event_rules, object_type, event_type, data, username=Non
# Compile the task parameters # Compile the task parameters
params = { params = {
"event_rule": event_rule, "event_rule": event_rule,
"model_name": object_type.model, "object_type": object_type,
"event_type": event_type, "event_type": event_type,
"data": event_data, "data": event_data,
"snapshots": snapshots, "snapshots": snapshots,
@ -115,8 +118,8 @@ def process_event_rules(event_rules, object_type, event_type, data, username=Non
} }
if snapshots: if snapshots:
params["snapshots"] = snapshots params["snapshots"] = snapshots
if request_id: if request:
params["request_id"] = request_id params["request"] = copy_safe_request(request)
# Enqueue the task # Enqueue the task
rq_queue.enqueue( rq_queue.enqueue(
@ -180,7 +183,7 @@ def process_event_queue(events):
data=event['data'], data=event['data'],
username=event['username'], username=event['username'],
snapshots=event['snapshots'], snapshots=event['snapshots'],
request_id=event['request_id'] request=event['request'],
) )

View File

@ -3,6 +3,7 @@ import uuid
from unittest.mock import patch from unittest.mock import patch
import django_rq import django_rq
from django.contrib.contenttypes.models import ContentType
from django.http import HttpResponse from django.http import HttpResponse
from django.test import RequestFactory from django.test import RequestFactory
from django.urls import reverse from django.urls import reverse
@ -135,7 +136,7 @@ class EventRuleTest(APITestCase):
job = self.queue.jobs[0] job = self.queue.jobs[0]
self.assertEqual(job.kwargs['event_rule'], EventRule.objects.get(name='Event Rule 1')) self.assertEqual(job.kwargs['event_rule'], EventRule.objects.get(name='Event Rule 1'))
self.assertEqual(job.kwargs['event_type'], OBJECT_CREATED) self.assertEqual(job.kwargs['event_type'], OBJECT_CREATED)
self.assertEqual(job.kwargs['model_name'], 'site') self.assertEqual(job.kwargs['object_type'], ContentType.objects.get_for_model(Site))
self.assertEqual(job.kwargs['data']['id'], response.data['id']) self.assertEqual(job.kwargs['data']['id'], response.data['id'])
self.assertEqual(job.kwargs['data']['foo'], 1) self.assertEqual(job.kwargs['data']['foo'], 1)
self.assertEqual(len(job.kwargs['data']['tags']), len(response.data['tags'])) self.assertEqual(len(job.kwargs['data']['tags']), len(response.data['tags']))
@ -186,7 +187,7 @@ class EventRuleTest(APITestCase):
for i, job in enumerate(self.queue.jobs): for i, job in enumerate(self.queue.jobs):
self.assertEqual(job.kwargs['event_rule'], EventRule.objects.get(name='Event Rule 1')) self.assertEqual(job.kwargs['event_rule'], EventRule.objects.get(name='Event Rule 1'))
self.assertEqual(job.kwargs['event_type'], OBJECT_CREATED) self.assertEqual(job.kwargs['event_type'], OBJECT_CREATED)
self.assertEqual(job.kwargs['model_name'], 'site') self.assertEqual(job.kwargs['object_type'], ContentType.objects.get_for_model(Site))
self.assertEqual(job.kwargs['data']['id'], response.data[i]['id']) self.assertEqual(job.kwargs['data']['id'], response.data[i]['id'])
self.assertEqual(job.kwargs['data']['foo'], 1) self.assertEqual(job.kwargs['data']['foo'], 1)
self.assertEqual(len(job.kwargs['data']['tags']), len(response.data[i]['tags'])) self.assertEqual(len(job.kwargs['data']['tags']), len(response.data[i]['tags']))
@ -218,7 +219,7 @@ class EventRuleTest(APITestCase):
job = self.queue.jobs[0] job = self.queue.jobs[0]
self.assertEqual(job.kwargs['event_rule'], EventRule.objects.get(name='Event Rule 2')) self.assertEqual(job.kwargs['event_rule'], EventRule.objects.get(name='Event Rule 2'))
self.assertEqual(job.kwargs['event_type'], OBJECT_UPDATED) self.assertEqual(job.kwargs['event_type'], OBJECT_UPDATED)
self.assertEqual(job.kwargs['model_name'], 'site') self.assertEqual(job.kwargs['object_type'], ContentType.objects.get_for_model(Site))
self.assertEqual(job.kwargs['data']['id'], site.pk) self.assertEqual(job.kwargs['data']['id'], site.pk)
self.assertEqual(job.kwargs['data']['foo'], 2) self.assertEqual(job.kwargs['data']['foo'], 2)
self.assertEqual(len(job.kwargs['data']['tags']), len(response.data['tags'])) self.assertEqual(len(job.kwargs['data']['tags']), len(response.data['tags']))
@ -275,7 +276,7 @@ class EventRuleTest(APITestCase):
for i, job in enumerate(self.queue.jobs): for i, job in enumerate(self.queue.jobs):
self.assertEqual(job.kwargs['event_rule'], EventRule.objects.get(name='Event Rule 2')) self.assertEqual(job.kwargs['event_rule'], EventRule.objects.get(name='Event Rule 2'))
self.assertEqual(job.kwargs['event_type'], OBJECT_UPDATED) self.assertEqual(job.kwargs['event_type'], OBJECT_UPDATED)
self.assertEqual(job.kwargs['model_name'], 'site') self.assertEqual(job.kwargs['object_type'], ContentType.objects.get_for_model(Site))
self.assertEqual(job.kwargs['data']['id'], data[i]['id']) self.assertEqual(job.kwargs['data']['id'], data[i]['id'])
self.assertEqual(job.kwargs['data']['foo'], 2) self.assertEqual(job.kwargs['data']['foo'], 2)
self.assertEqual(len(job.kwargs['data']['tags']), len(response.data[i]['tags'])) self.assertEqual(len(job.kwargs['data']['tags']), len(response.data[i]['tags']))
@ -302,7 +303,7 @@ class EventRuleTest(APITestCase):
job = self.queue.jobs[0] job = self.queue.jobs[0]
self.assertEqual(job.kwargs['event_rule'], EventRule.objects.get(name='Event Rule 3')) self.assertEqual(job.kwargs['event_rule'], EventRule.objects.get(name='Event Rule 3'))
self.assertEqual(job.kwargs['event_type'], OBJECT_DELETED) self.assertEqual(job.kwargs['event_type'], OBJECT_DELETED)
self.assertEqual(job.kwargs['model_name'], 'site') self.assertEqual(job.kwargs['object_type'], ContentType.objects.get_for_model(Site))
self.assertEqual(job.kwargs['data']['id'], site.pk) self.assertEqual(job.kwargs['data']['id'], site.pk)
self.assertEqual(job.kwargs['data']['foo'], 3) self.assertEqual(job.kwargs['data']['foo'], 3)
self.assertEqual(job.kwargs['snapshots']['prechange']['name'], 'Site 1') self.assertEqual(job.kwargs['snapshots']['prechange']['name'], 'Site 1')
@ -336,7 +337,7 @@ class EventRuleTest(APITestCase):
for i, job in enumerate(self.queue.jobs): for i, job in enumerate(self.queue.jobs):
self.assertEqual(job.kwargs['event_rule'], EventRule.objects.get(name='Event Rule 3')) self.assertEqual(job.kwargs['event_rule'], EventRule.objects.get(name='Event Rule 3'))
self.assertEqual(job.kwargs['event_type'], OBJECT_DELETED) self.assertEqual(job.kwargs['event_type'], OBJECT_DELETED)
self.assertEqual(job.kwargs['model_name'], 'site') self.assertEqual(job.kwargs['object_type'], ContentType.objects.get_for_model(Site))
self.assertEqual(job.kwargs['data']['id'], sites[i].pk) self.assertEqual(job.kwargs['data']['id'], sites[i].pk)
self.assertEqual(job.kwargs['data']['foo'], 3) self.assertEqual(job.kwargs['data']['foo'], 3)
self.assertEqual(job.kwargs['snapshots']['prechange']['name'], sites[i].name) self.assertEqual(job.kwargs['snapshots']['prechange']['name'], sites[i].name)
@ -368,18 +369,23 @@ class EventRuleTest(APITestCase):
self.assertEqual(body['request_id'], str(request_id)) self.assertEqual(body['request_id'], str(request_id))
self.assertEqual(body['data']['name'], 'Site 1') self.assertEqual(body['data']['name'], 'Site 1')
self.assertEqual(body['data']['foo'], 1) self.assertEqual(body['data']['foo'], 1)
self.assertEqual(body['context']['foo'], 123) # From netbox.tests.dummy_plugin
return HttpResponse() return HttpResponse()
# Create a dummy request
request = RequestFactory().get(reverse('dcim:site_add'))
request.id = request_id
request.user = self.user
# Enqueue a webhook for processing # Enqueue a webhook for processing
webhooks_queue = {} webhooks_queue = {}
site = Site.objects.create(name='Site 1', slug='site-1') site = Site.objects.create(name='Site 1', slug='site-1')
enqueue_event( enqueue_event(
webhooks_queue, webhooks_queue,
instance=site, instance=site,
user=self.user, request=request,
request_id=request_id, event_type=OBJECT_CREATED,
event_type=OBJECT_CREATED
) )
flush_events(list(webhooks_queue.values())) flush_events(list(webhooks_queue.values()))

View File

@ -6,12 +6,28 @@ import requests
from django_rq import job from django_rq import job
from jinja2.exceptions import TemplateError from jinja2.exceptions import TemplateError
from netbox.registry import registry
from utilities.proxy import resolve_proxies from utilities.proxy import resolve_proxies
from .constants import WEBHOOK_EVENT_TYPES from .constants import WEBHOOK_EVENT_TYPES
__all__ = (
'generate_signature',
'register_webhook_callback',
'send_webhook',
)
logger = logging.getLogger('netbox.webhooks') logger = logging.getLogger('netbox.webhooks')
def register_webhook_callback(func):
"""
Register a function as a webhook callback.
"""
registry['webhook_callbacks'].append(func)
logger.debug(f'Registered webhook callback {func.__module__}.{func.__name__}')
return func
def generate_signature(request_body, secret): def generate_signature(request_body, secret):
""" """
Return a cryptographic signature that can be used to verify the authenticity of webhook data. Return a cryptographic signature that can be used to verify the authenticity of webhook data.
@ -25,7 +41,7 @@ def generate_signature(request_body, secret):
@job('default') @job('default')
def send_webhook(event_rule, model_name, event_type, data, timestamp, username, request_id=None, snapshots=None): def send_webhook(event_rule, object_type, event_type, data, timestamp, username, request=None, snapshots=None):
""" """
Make a POST request to the defined Webhook Make a POST request to the defined Webhook
""" """
@ -35,9 +51,9 @@ def send_webhook(event_rule, model_name, event_type, data, timestamp, username,
context = { context = {
'event': WEBHOOK_EVENT_TYPES.get(event_type, event_type), 'event': WEBHOOK_EVENT_TYPES.get(event_type, event_type),
'timestamp': timestamp, 'timestamp': timestamp,
'model': model_name, 'model': object_type.model,
'username': username, 'username': username,
'request_id': request_id, 'request_id': request.id if request else None,
'data': data, 'data': data,
} }
if snapshots: if snapshots:
@ -45,6 +61,18 @@ def send_webhook(event_rule, model_name, event_type, data, timestamp, username,
'snapshots': snapshots 'snapshots': snapshots
}) })
# Add any additional context from plugins
callback_data = {}
for callback in registry['webhook_callbacks']:
try:
if ret := callback(object_type, event_type, data, request):
callback_data.update(**ret)
except Exception as e:
logger.warning(f"Caught exception when processing callback {callback}: {e}")
pass
if callback_data:
context['context'] = callback_data
# Build the headers for the HTTP request # Build the headers for the HTTP request
headers = { headers = {
'Content-Type': webhook.http_content_type, 'Content-Type': webhook.http_content_type,

View File

@ -34,5 +34,6 @@ registry = Registry({
'system_jobs': dict(), 'system_jobs': dict(),
'tables': collections.defaultdict(dict), 'tables': collections.defaultdict(dict),
'views': collections.defaultdict(dict), 'views': collections.defaultdict(dict),
'webhook_callbacks': list(),
'widgets': dict(), 'widgets': dict(),
}) })

View File

@ -24,7 +24,7 @@ class DummyPluginConfig(PluginConfig):
def ready(self): def ready(self):
super().ready() super().ready()
from . import jobs # noqa: F401 from . import jobs, webhook_callbacks # noqa: F401
config = DummyPluginConfig config = DummyPluginConfig

View File

@ -0,0 +1,8 @@
from extras.webhooks import register_webhook_callback
@register_webhook_callback
def set_context(object_type, event_type, data, request):
return {
'foo': 123,
}

View File

@ -10,6 +10,7 @@ from core.models import ObjectType
from netbox.tests.dummy_plugin import config as dummy_config from netbox.tests.dummy_plugin import config as dummy_config
from netbox.tests.dummy_plugin.data_backends import DummyBackend from netbox.tests.dummy_plugin.data_backends import DummyBackend
from netbox.tests.dummy_plugin.jobs import DummySystemJob from netbox.tests.dummy_plugin.jobs import DummySystemJob
from netbox.tests.dummy_plugin.webhook_callbacks import set_context
from netbox.plugins.navigation import PluginMenu from netbox.plugins.navigation import PluginMenu
from netbox.plugins.utils import get_plugin_config from netbox.plugins.utils import get_plugin_config
from netbox.graphql.schema import Query from netbox.graphql.schema import Query
@ -220,3 +221,9 @@ class PluginTest(TestCase):
Check that events pipeline is registered. Check that events pipeline is registered.
""" """
self.assertIn('netbox.tests.dummy_plugin.events.process_events_queue', settings.EVENTS_PIPELINE) self.assertIn('netbox.tests.dummy_plugin.events.process_events_queue', settings.EVENTS_PIPELINE)
def test_webhook_callbacks(self):
"""
Test the registration of webhook callbacks.
"""
self.assertIn(set_context, registry['webhook_callbacks'])