mirror of
https://github.com/netbox-community/netbox.git
synced 2025-08-26 17:26:10 -06:00
Reference database object by GFK when running scripts & reports via API
This commit is contained in:
parent
ccb09b0f7b
commit
69fd138533
@ -17,8 +17,8 @@ from core.choices import JobStatusChoices
|
|||||||
from core.models import Job
|
from core.models import Job
|
||||||
from extras import filtersets
|
from extras import filtersets
|
||||||
from extras.models import *
|
from extras.models import *
|
||||||
from extras.reports import get_report, run_report
|
from extras.reports import get_module_and_report, get_report, run_report
|
||||||
from extras.scripts import get_script, run_script
|
from extras.scripts import get_module_and_script, get_script, run_script
|
||||||
from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired
|
from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired
|
||||||
from netbox.api.features import SyncedDataMixin
|
from netbox.api.features import SyncedDataMixin
|
||||||
from netbox.api.metadata import ContentTypeMetadata
|
from netbox.api.metadata import ContentTypeMetadata
|
||||||
@ -171,19 +171,17 @@ class ReportViewSet(ViewSet):
|
|||||||
exclude_from_schema = True
|
exclude_from_schema = True
|
||||||
lookup_value_regex = '[^/]+' # Allow dots
|
lookup_value_regex = '[^/]+' # Allow dots
|
||||||
|
|
||||||
def _retrieve_report(self, pk):
|
def _get_report(self, pk):
|
||||||
|
try:
|
||||||
# Read the PK as "<module>.<report>"
|
module_name, report_name = pk.split('.', maxsplit=1)
|
||||||
if '.' not in pk:
|
except ValueError:
|
||||||
raise Http404
|
raise Http404
|
||||||
module_name, report_name = pk.split('.', maxsplit=1)
|
|
||||||
|
|
||||||
# Raise a 404 on an invalid Report module/name
|
module, report = get_module_and_report(module_name, report_name)
|
||||||
report = get_report(module_name, report_name)
|
|
||||||
if report is None:
|
if report is None:
|
||||||
raise Http404
|
raise Http404
|
||||||
|
|
||||||
return report
|
return module, report
|
||||||
|
|
||||||
def list(self, request):
|
def list(self, request):
|
||||||
"""
|
"""
|
||||||
@ -216,13 +214,13 @@ class ReportViewSet(ViewSet):
|
|||||||
"""
|
"""
|
||||||
Retrieve a single Report identified as "<module>.<report>".
|
Retrieve a single Report identified as "<module>.<report>".
|
||||||
"""
|
"""
|
||||||
|
module, report = self._get_report(pk)
|
||||||
|
|
||||||
# Retrieve the Report and Job, if any.
|
# Retrieve the Report and Job, if any.
|
||||||
report = self._retrieve_report(pk)
|
object_type = ContentType.objects.get(app_label='extras', model='reportmodule')
|
||||||
report_content_type = ContentType.objects.get(app_label='extras', model='report')
|
|
||||||
report.result = Job.objects.filter(
|
report.result = Job.objects.filter(
|
||||||
object_type=report_content_type,
|
object_type=object_type,
|
||||||
name=report.full_name,
|
name=report.name,
|
||||||
status__in=JobStatusChoices.TERMINAL_STATE_CHOICES
|
status__in=JobStatusChoices.TERMINAL_STATE_CHOICES
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
@ -246,14 +244,14 @@ class ReportViewSet(ViewSet):
|
|||||||
raise RQWorkerNotRunningException()
|
raise RQWorkerNotRunningException()
|
||||||
|
|
||||||
# Retrieve and run the Report. This will create a new Job.
|
# Retrieve and run the Report. This will create a new Job.
|
||||||
report = self._retrieve_report(pk)
|
module, report = self._get_report(pk)
|
||||||
input_serializer = serializers.ReportInputSerializer(data=request.data)
|
input_serializer = serializers.ReportInputSerializer(data=request.data)
|
||||||
|
|
||||||
if input_serializer.is_valid():
|
if input_serializer.is_valid():
|
||||||
report.result = Job.enqueue_job(
|
report.result = Job.enqueue(
|
||||||
run_report,
|
run_report,
|
||||||
name=report.full_name,
|
instance=module,
|
||||||
obj_type=ContentType.objects.get_for_model(Report),
|
name=report.class_name,
|
||||||
user=request.user,
|
user=request.user,
|
||||||
job_timeout=report.job_timeout,
|
job_timeout=report.job_timeout,
|
||||||
schedule_at=input_serializer.validated_data.get('schedule_at'),
|
schedule_at=input_serializer.validated_data.get('schedule_at'),
|
||||||
@ -276,11 +274,16 @@ class ScriptViewSet(ViewSet):
|
|||||||
lookup_value_regex = '[^/]+' # Allow dots
|
lookup_value_regex = '[^/]+' # Allow dots
|
||||||
|
|
||||||
def _get_script(self, pk):
|
def _get_script(self, pk):
|
||||||
module_name, script_name = pk.split('.', maxsplit=1)
|
try:
|
||||||
script = get_script(module_name, script_name)
|
module_name, script_name = pk.split('.', maxsplit=1)
|
||||||
|
except ValueError:
|
||||||
|
raise Http404
|
||||||
|
|
||||||
|
module, script = get_module_and_script(module_name, script_name)
|
||||||
if script is None:
|
if script is None:
|
||||||
raise Http404
|
raise Http404
|
||||||
return script
|
|
||||||
|
return module, script
|
||||||
|
|
||||||
def list(self, request):
|
def list(self, request):
|
||||||
|
|
||||||
@ -306,11 +309,11 @@ class ScriptViewSet(ViewSet):
|
|||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
|
|
||||||
def retrieve(self, request, pk):
|
def retrieve(self, request, pk):
|
||||||
script = self._get_script(pk)
|
module, script = self._get_script(pk)
|
||||||
script_content_type = ContentType.objects.get(app_label='extras', model='script')
|
object_type = ContentType.objects.get(app_label='extras', model='scriptmodule')
|
||||||
script.result = Job.objects.filter(
|
script.result = Job.objects.filter(
|
||||||
object_type=script_content_type,
|
object_type=object_type,
|
||||||
name=script.full_name,
|
name=script.name,
|
||||||
status__in=JobStatusChoices.TERMINAL_STATE_CHOICES
|
status__in=JobStatusChoices.TERMINAL_STATE_CHOICES
|
||||||
).first()
|
).first()
|
||||||
serializer = serializers.ScriptDetailSerializer(script, context={'request': request})
|
serializer = serializers.ScriptDetailSerializer(script, context={'request': request})
|
||||||
@ -325,7 +328,7 @@ class ScriptViewSet(ViewSet):
|
|||||||
if not request.user.has_perm('extras.run_script'):
|
if not request.user.has_perm('extras.run_script'):
|
||||||
raise PermissionDenied("This user does not have permission to run scripts.")
|
raise PermissionDenied("This user does not have permission to run scripts.")
|
||||||
|
|
||||||
script = self._get_script(pk)()
|
module, script = self._get_script(pk)
|
||||||
input_serializer = serializers.ScriptInputSerializer(data=request.data)
|
input_serializer = serializers.ScriptInputSerializer(data=request.data)
|
||||||
|
|
||||||
# Check that at least one RQ worker is running
|
# Check that at least one RQ worker is running
|
||||||
@ -333,10 +336,10 @@ class ScriptViewSet(ViewSet):
|
|||||||
raise RQWorkerNotRunningException()
|
raise RQWorkerNotRunningException()
|
||||||
|
|
||||||
if input_serializer.is_valid():
|
if input_serializer.is_valid():
|
||||||
script.result = Job.enqueue_job(
|
script.result = Job.enqueue(
|
||||||
run_script,
|
run_script,
|
||||||
name=script.full_name,
|
instance=module,
|
||||||
obj_type=ContentType.objects.get_for_model(Script),
|
name=script.class_name,
|
||||||
user=request.user,
|
user=request.user,
|
||||||
data=input_serializer.data['data'],
|
data=input_serializer.data['data'],
|
||||||
request=copy_safe_request(request),
|
request=copy_safe_request(request),
|
||||||
|
@ -22,6 +22,12 @@ def get_report(module_name, report_name):
|
|||||||
return module.reports.get(report_name)
|
return module.reports.get(report_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_and_report(module_name, report_name):
|
||||||
|
module = ReportModule.objects.get(file_path=f'{module_name}.py')
|
||||||
|
report = module.reports.get(report_name)
|
||||||
|
return module, report
|
||||||
|
|
||||||
|
|
||||||
@job('default')
|
@job('default')
|
||||||
def run_report(job_result, *args, **kwargs):
|
def run_report(job_result, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -520,3 +520,9 @@ def get_script(module_name, script_name):
|
|||||||
"""
|
"""
|
||||||
module = ScriptModule.objects.get(file_path=f'{module_name}.py')
|
module = ScriptModule.objects.get(file_path=f'{module_name}.py')
|
||||||
return module.scripts.get(script_name)
|
return module.scripts.get(script_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_and_script(module_name, script_name):
|
||||||
|
module = ScriptModule.objects.get(file_path=f'{module_name}.py')
|
||||||
|
script = module.scripts.get(script_name)
|
||||||
|
return module, script
|
||||||
|
@ -9,6 +9,7 @@ from django_rq.queues import get_connection
|
|||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rq import Worker
|
from rq import Worker
|
||||||
|
|
||||||
|
from core.choices import ManagedFileRootPathChoices
|
||||||
from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Rack, Location, RackRole, Site
|
from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Rack, Location, RackRole, Site
|
||||||
from extras.api.views import ReportViewSet, ScriptViewSet
|
from extras.api.views import ReportViewSet, ScriptViewSet
|
||||||
from extras.models import *
|
from extras.models import *
|
||||||
@ -524,14 +525,21 @@ class ReportTest(APITestCase):
|
|||||||
def test_foo(self):
|
def test_foo(self):
|
||||||
self.log_success(None, "Report completed")
|
self.log_success(None, "Report completed")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpTestData(cls):
|
||||||
|
ReportModule.objects.create(
|
||||||
|
file_root=ManagedFileRootPathChoices.REPORTS,
|
||||||
|
file_path='/var/tmp/report.py'
|
||||||
|
)
|
||||||
|
|
||||||
def get_test_report(self, *args):
|
def get_test_report(self, *args):
|
||||||
return self.TestReport()
|
return ReportModule.objects.first(), self.TestReport()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
# Monkey-patch the API viewset's _get_script method to return our test script above
|
# Monkey-patch the API viewset's _get_report() method to return our test Report above
|
||||||
ReportViewSet._retrieve_report = self.get_test_report
|
ReportViewSet._get_report = self.get_test_report
|
||||||
|
|
||||||
def test_get_report(self):
|
def test_get_report(self):
|
||||||
url = reverse('extras-api:report-detail', kwargs={'pk': None})
|
url = reverse('extras-api:report-detail', kwargs={'pk': None})
|
||||||
@ -569,14 +577,20 @@ class ScriptTest(APITestCase):
|
|||||||
|
|
||||||
return 'Script complete'
|
return 'Script complete'
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpTestData(cls):
|
||||||
|
ScriptModule.objects.create(
|
||||||
|
file_root=ManagedFileRootPathChoices.SCRIPTS,
|
||||||
|
file_path='/var/tmp/script.py'
|
||||||
|
)
|
||||||
|
|
||||||
def get_test_script(self, *args):
|
def get_test_script(self, *args):
|
||||||
return self.TestScript
|
return ScriptModule.objects.first(), self.TestScript
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
# Monkey-patch the API viewset's _get_script method to return our test script above
|
# Monkey-patch the API viewset's _get_script() method to return our test Script above
|
||||||
ScriptViewSet._get_script = self.get_test_script
|
ScriptViewSet._get_script = self.get_test_script
|
||||||
|
|
||||||
def test_get_script(self):
|
def test_get_script(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user