diff --git a/netbox/extras/api/views.py b/netbox/extras/api/views.py index 6b8309b28..e0dc6b100 100644 --- a/netbox/extras/api/views.py +++ b/netbox/extras/api/views.py @@ -17,8 +17,8 @@ from core.choices import JobStatusChoices from core.models import Job from extras import filtersets from extras.models import * -from extras.reports import get_report, run_report -from extras.scripts import get_script, run_script +from extras.reports import get_module_and_report, get_report, run_report +from extras.scripts import get_module_and_script, get_script, run_script from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired from netbox.api.features import SyncedDataMixin from netbox.api.metadata import ContentTypeMetadata @@ -171,19 +171,17 @@ class ReportViewSet(ViewSet): exclude_from_schema = True lookup_value_regex = '[^/]+' # Allow dots - def _retrieve_report(self, pk): - - # Read the PK as "." - if '.' not in pk: + def _get_report(self, pk): + try: + module_name, report_name = pk.split('.', maxsplit=1) + except ValueError: raise Http404 - module_name, report_name = pk.split('.', maxsplit=1) - # Raise a 404 on an invalid Report module/name - report = get_report(module_name, report_name) + module, report = get_module_and_report(module_name, report_name) if report is None: raise Http404 - return report + return module, report def list(self, request): """ @@ -216,13 +214,13 @@ class ReportViewSet(ViewSet): """ Retrieve a single Report identified as ".". """ + module, report = self._get_report(pk) # Retrieve the Report and Job, if any. - report = self._retrieve_report(pk) - report_content_type = ContentType.objects.get(app_label='extras', model='report') + object_type = ContentType.objects.get(app_label='extras', model='reportmodule') report.result = Job.objects.filter( - object_type=report_content_type, - name=report.full_name, + object_type=object_type, + name=report.name, status__in=JobStatusChoices.TERMINAL_STATE_CHOICES ).first() @@ -246,14 +244,14 @@ class ReportViewSet(ViewSet): raise RQWorkerNotRunningException() # Retrieve and run the Report. This will create a new Job. - report = self._retrieve_report(pk) + module, report = self._get_report(pk) input_serializer = serializers.ReportInputSerializer(data=request.data) if input_serializer.is_valid(): - report.result = Job.enqueue_job( + report.result = Job.enqueue( run_report, - name=report.full_name, - obj_type=ContentType.objects.get_for_model(Report), + instance=module, + name=report.class_name, user=request.user, job_timeout=report.job_timeout, schedule_at=input_serializer.validated_data.get('schedule_at'), @@ -276,11 +274,16 @@ class ScriptViewSet(ViewSet): lookup_value_regex = '[^/]+' # Allow dots def _get_script(self, pk): - module_name, script_name = pk.split('.', maxsplit=1) - script = get_script(module_name, script_name) + try: + module_name, script_name = pk.split('.', maxsplit=1) + except ValueError: + raise Http404 + + module, script = get_module_and_script(module_name, script_name) if script is None: raise Http404 - return script + + return module, script def list(self, request): @@ -306,11 +309,11 @@ class ScriptViewSet(ViewSet): return Response(serializer.data) def retrieve(self, request, pk): - script = self._get_script(pk) - script_content_type = ContentType.objects.get(app_label='extras', model='script') + module, script = self._get_script(pk) + object_type = ContentType.objects.get(app_label='extras', model='scriptmodule') script.result = Job.objects.filter( - object_type=script_content_type, - name=script.full_name, + object_type=object_type, + name=script.name, status__in=JobStatusChoices.TERMINAL_STATE_CHOICES ).first() serializer = serializers.ScriptDetailSerializer(script, context={'request': request}) @@ -325,7 +328,7 @@ class ScriptViewSet(ViewSet): if not request.user.has_perm('extras.run_script'): raise PermissionDenied("This user does not have permission to run scripts.") - script = self._get_script(pk)() + module, script = self._get_script(pk) input_serializer = serializers.ScriptInputSerializer(data=request.data) # Check that at least one RQ worker is running @@ -333,10 +336,10 @@ class ScriptViewSet(ViewSet): raise RQWorkerNotRunningException() if input_serializer.is_valid(): - script.result = Job.enqueue_job( + script.result = Job.enqueue( run_script, - name=script.full_name, - obj_type=ContentType.objects.get_for_model(Script), + instance=module, + name=script.class_name, user=request.user, data=input_serializer.data['data'], request=copy_safe_request(request), diff --git a/netbox/extras/reports.py b/netbox/extras/reports.py index 086bd0977..4dd62ed29 100644 --- a/netbox/extras/reports.py +++ b/netbox/extras/reports.py @@ -22,6 +22,12 @@ def get_report(module_name, report_name): return module.reports.get(report_name) +def get_module_and_report(module_name, report_name): + module = ReportModule.objects.get(file_path=f'{module_name}.py') + report = module.reports.get(report_name) + return module, report + + @job('default') def run_report(job_result, *args, **kwargs): """ diff --git a/netbox/extras/scripts.py b/netbox/extras/scripts.py index 69fd88aaa..9526e1cf8 100644 --- a/netbox/extras/scripts.py +++ b/netbox/extras/scripts.py @@ -520,3 +520,9 @@ def get_script(module_name, script_name): """ module = ScriptModule.objects.get(file_path=f'{module_name}.py') return module.scripts.get(script_name) + + +def get_module_and_script(module_name, script_name): + module = ScriptModule.objects.get(file_path=f'{module_name}.py') + script = module.scripts.get(script_name) + return module, script diff --git a/netbox/extras/tests/test_api.py b/netbox/extras/tests/test_api.py index 81a607eec..aa293e318 100644 --- a/netbox/extras/tests/test_api.py +++ b/netbox/extras/tests/test_api.py @@ -9,6 +9,7 @@ from django_rq.queues import get_connection from rest_framework import status from rq import Worker +from core.choices import ManagedFileRootPathChoices from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Rack, Location, RackRole, Site from extras.api.views import ReportViewSet, ScriptViewSet from extras.models import * @@ -524,14 +525,21 @@ class ReportTest(APITestCase): def test_foo(self): self.log_success(None, "Report completed") + @classmethod + def setUpTestData(cls): + ReportModule.objects.create( + file_root=ManagedFileRootPathChoices.REPORTS, + file_path='/var/tmp/report.py' + ) + def get_test_report(self, *args): - return self.TestReport() + return ReportModule.objects.first(), self.TestReport() def setUp(self): super().setUp() - # Monkey-patch the API viewset's _get_script method to return our test script above - ReportViewSet._retrieve_report = self.get_test_report + # Monkey-patch the API viewset's _get_report() method to return our test Report above + ReportViewSet._get_report = self.get_test_report def test_get_report(self): url = reverse('extras-api:report-detail', kwargs={'pk': None}) @@ -569,14 +577,20 @@ class ScriptTest(APITestCase): return 'Script complete' + @classmethod + def setUpTestData(cls): + ScriptModule.objects.create( + file_root=ManagedFileRootPathChoices.SCRIPTS, + file_path='/var/tmp/script.py' + ) + def get_test_script(self, *args): - return self.TestScript + return ScriptModule.objects.first(), self.TestScript def setUp(self): - super().setUp() - # Monkey-patch the API viewset's _get_script method to return our test script above + # Monkey-patch the API viewset's _get_script() method to return our test Script above ScriptViewSet._get_script = self.get_test_script def test_get_script(self):