Reference database object by GFK when running scripts & reports via API

This commit is contained in:
jeremystretch 2023-03-27 16:46:47 -04:00
parent ccb09b0f7b
commit 69fd138533
4 changed files with 64 additions and 35 deletions

View File

@ -17,8 +17,8 @@ from core.choices import JobStatusChoices
from core.models import Job from core.models import Job
from extras import filtersets from extras import filtersets
from extras.models import * from extras.models import *
from extras.reports import get_report, run_report from extras.reports import get_module_and_report, get_report, run_report
from extras.scripts import get_script, run_script from extras.scripts import get_module_and_script, get_script, run_script
from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired
from netbox.api.features import SyncedDataMixin from netbox.api.features import SyncedDataMixin
from netbox.api.metadata import ContentTypeMetadata from netbox.api.metadata import ContentTypeMetadata
@ -171,19 +171,17 @@ class ReportViewSet(ViewSet):
exclude_from_schema = True exclude_from_schema = True
lookup_value_regex = '[^/]+' # Allow dots lookup_value_regex = '[^/]+' # Allow dots
def _retrieve_report(self, pk): def _get_report(self, pk):
try:
# Read the PK as "<module>.<report>"
if '.' not in pk:
raise Http404
module_name, report_name = pk.split('.', maxsplit=1) module_name, report_name = pk.split('.', maxsplit=1)
except ValueError:
raise Http404
# Raise a 404 on an invalid Report module/name module, report = get_module_and_report(module_name, report_name)
report = get_report(module_name, report_name)
if report is None: if report is None:
raise Http404 raise Http404
return report return module, report
def list(self, request): def list(self, request):
""" """
@ -216,13 +214,13 @@ class ReportViewSet(ViewSet):
""" """
Retrieve a single Report identified as "<module>.<report>". Retrieve a single Report identified as "<module>.<report>".
""" """
module, report = self._get_report(pk)
# Retrieve the Report and Job, if any. # Retrieve the Report and Job, if any.
report = self._retrieve_report(pk) object_type = ContentType.objects.get(app_label='extras', model='reportmodule')
report_content_type = ContentType.objects.get(app_label='extras', model='report')
report.result = Job.objects.filter( report.result = Job.objects.filter(
object_type=report_content_type, object_type=object_type,
name=report.full_name, name=report.name,
status__in=JobStatusChoices.TERMINAL_STATE_CHOICES status__in=JobStatusChoices.TERMINAL_STATE_CHOICES
).first() ).first()
@ -246,14 +244,14 @@ class ReportViewSet(ViewSet):
raise RQWorkerNotRunningException() raise RQWorkerNotRunningException()
# Retrieve and run the Report. This will create a new Job. # Retrieve and run the Report. This will create a new Job.
report = self._retrieve_report(pk) module, report = self._get_report(pk)
input_serializer = serializers.ReportInputSerializer(data=request.data) input_serializer = serializers.ReportInputSerializer(data=request.data)
if input_serializer.is_valid(): if input_serializer.is_valid():
report.result = Job.enqueue_job( report.result = Job.enqueue(
run_report, run_report,
name=report.full_name, instance=module,
obj_type=ContentType.objects.get_for_model(Report), name=report.class_name,
user=request.user, user=request.user,
job_timeout=report.job_timeout, job_timeout=report.job_timeout,
schedule_at=input_serializer.validated_data.get('schedule_at'), schedule_at=input_serializer.validated_data.get('schedule_at'),
@ -276,11 +274,16 @@ class ScriptViewSet(ViewSet):
lookup_value_regex = '[^/]+' # Allow dots lookup_value_regex = '[^/]+' # Allow dots
def _get_script(self, pk): def _get_script(self, pk):
try:
module_name, script_name = pk.split('.', maxsplit=1) module_name, script_name = pk.split('.', maxsplit=1)
script = get_script(module_name, script_name) except ValueError:
raise Http404
module, script = get_module_and_script(module_name, script_name)
if script is None: if script is None:
raise Http404 raise Http404
return script
return module, script
def list(self, request): def list(self, request):
@ -306,11 +309,11 @@ class ScriptViewSet(ViewSet):
return Response(serializer.data) return Response(serializer.data)
def retrieve(self, request, pk): def retrieve(self, request, pk):
script = self._get_script(pk) module, script = self._get_script(pk)
script_content_type = ContentType.objects.get(app_label='extras', model='script') object_type = ContentType.objects.get(app_label='extras', model='scriptmodule')
script.result = Job.objects.filter( script.result = Job.objects.filter(
object_type=script_content_type, object_type=object_type,
name=script.full_name, name=script.name,
status__in=JobStatusChoices.TERMINAL_STATE_CHOICES status__in=JobStatusChoices.TERMINAL_STATE_CHOICES
).first() ).first()
serializer = serializers.ScriptDetailSerializer(script, context={'request': request}) serializer = serializers.ScriptDetailSerializer(script, context={'request': request})
@ -325,7 +328,7 @@ class ScriptViewSet(ViewSet):
if not request.user.has_perm('extras.run_script'): if not request.user.has_perm('extras.run_script'):
raise PermissionDenied("This user does not have permission to run scripts.") raise PermissionDenied("This user does not have permission to run scripts.")
script = self._get_script(pk)() module, script = self._get_script(pk)
input_serializer = serializers.ScriptInputSerializer(data=request.data) input_serializer = serializers.ScriptInputSerializer(data=request.data)
# Check that at least one RQ worker is running # Check that at least one RQ worker is running
@ -333,10 +336,10 @@ class ScriptViewSet(ViewSet):
raise RQWorkerNotRunningException() raise RQWorkerNotRunningException()
if input_serializer.is_valid(): if input_serializer.is_valid():
script.result = Job.enqueue_job( script.result = Job.enqueue(
run_script, run_script,
name=script.full_name, instance=module,
obj_type=ContentType.objects.get_for_model(Script), name=script.class_name,
user=request.user, user=request.user,
data=input_serializer.data['data'], data=input_serializer.data['data'],
request=copy_safe_request(request), request=copy_safe_request(request),

View File

@ -22,6 +22,12 @@ def get_report(module_name, report_name):
return module.reports.get(report_name) return module.reports.get(report_name)
def get_module_and_report(module_name, report_name):
module = ReportModule.objects.get(file_path=f'{module_name}.py')
report = module.reports.get(report_name)
return module, report
@job('default') @job('default')
def run_report(job_result, *args, **kwargs): def run_report(job_result, *args, **kwargs):
""" """

View File

@ -520,3 +520,9 @@ def get_script(module_name, script_name):
""" """
module = ScriptModule.objects.get(file_path=f'{module_name}.py') module = ScriptModule.objects.get(file_path=f'{module_name}.py')
return module.scripts.get(script_name) return module.scripts.get(script_name)
def get_module_and_script(module_name, script_name):
module = ScriptModule.objects.get(file_path=f'{module_name}.py')
script = module.scripts.get(script_name)
return module, script

View File

@ -9,6 +9,7 @@ from django_rq.queues import get_connection
from rest_framework import status from rest_framework import status
from rq import Worker from rq import Worker
from core.choices import ManagedFileRootPathChoices
from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Rack, Location, RackRole, Site from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Rack, Location, RackRole, Site
from extras.api.views import ReportViewSet, ScriptViewSet from extras.api.views import ReportViewSet, ScriptViewSet
from extras.models import * from extras.models import *
@ -524,14 +525,21 @@ class ReportTest(APITestCase):
def test_foo(self): def test_foo(self):
self.log_success(None, "Report completed") self.log_success(None, "Report completed")
@classmethod
def setUpTestData(cls):
ReportModule.objects.create(
file_root=ManagedFileRootPathChoices.REPORTS,
file_path='/var/tmp/report.py'
)
def get_test_report(self, *args): def get_test_report(self, *args):
return self.TestReport() return ReportModule.objects.first(), self.TestReport()
def setUp(self): def setUp(self):
super().setUp() super().setUp()
# Monkey-patch the API viewset's _get_script method to return our test script above # Monkey-patch the API viewset's _get_report() method to return our test Report above
ReportViewSet._retrieve_report = self.get_test_report ReportViewSet._get_report = self.get_test_report
def test_get_report(self): def test_get_report(self):
url = reverse('extras-api:report-detail', kwargs={'pk': None}) url = reverse('extras-api:report-detail', kwargs={'pk': None})
@ -569,14 +577,20 @@ class ScriptTest(APITestCase):
return 'Script complete' return 'Script complete'
@classmethod
def setUpTestData(cls):
ScriptModule.objects.create(
file_root=ManagedFileRootPathChoices.SCRIPTS,
file_path='/var/tmp/script.py'
)
def get_test_script(self, *args): def get_test_script(self, *args):
return self.TestScript return ScriptModule.objects.first(), self.TestScript
def setUp(self): def setUp(self):
super().setUp() super().setUp()
# Monkey-patch the API viewset's _get_script method to return our test script above # Monkey-patch the API viewset's _get_script() method to return our test Script above
ScriptViewSet._get_script = self.get_test_script ScriptViewSet._get_script = self.get_test_script
def test_get_script(self): def test_get_script(self):