Simplify retrieval logic

This commit is contained in:
Jeremy Stretch 2024-05-22 09:02:47 -04:00
parent 1253d9188f
commit 1d5d283cb9

View File

@ -14,7 +14,7 @@ from rq import Worker
from core.models import Job, ObjectType from core.models import Job, ObjectType
from extras import filtersets from extras import filtersets
from extras.models import * from extras.models import *
from extras.scripts import get_module_and_script, run_script from extras.scripts import run_script
from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired
from netbox.api.features import SyncedDataMixin from netbox.api.features import SyncedDataMixin
from netbox.api.metadata import ContentTypeMetadata from netbox.api.metadata import ContentTypeMetadata
@ -217,37 +217,31 @@ class ScriptViewSet(ModelViewSet):
lookup_value_regex = '[^/]+' # Allow dots lookup_value_regex = '[^/]+' # Allow dots
def _get_script(self, pk): def _get_script(self, pk):
# If pk is numeric, retrieve script by ID
if pk.isnumeric(): if pk.isnumeric():
script = get_object_or_404(self.queryset, pk=pk) return get_object_or_404(self.queryset, pk=pk)
module = script.module
else:
try:
module_name, script_name = pk.split('.', maxsplit=1)
except ValueError:
raise Http404
module, script = get_module_and_script(module_name, script_name) # Default to retrieval by module & name
try:
if script is None: module_name, script_name = pk.split('.', maxsplit=1)
except ValueError:
raise Http404 raise Http404
return get_object_or_404(self.queryset, module__file_path=f'{module_name}.py', name=script_name)
return module, script
def retrieve(self, request, pk): def retrieve(self, request, pk):
module, script = self._get_script(pk) script = self._get_script(pk)
serializer = serializers.ScriptDetailSerializer(script, context={'request': request}) serializer = serializers.ScriptDetailSerializer(script, context={'request': request})
return Response(serializer.data) return Response(serializer.data)
def post(self, request, pk): def post(self, request, pk):
""" """
Run a Script identified by the name or pk and return the pending Job as the result Run a Script identified by its numeric PK or module & name and return the pending Job as the result
""" """
if not request.user.has_perm('extras.run_script'): if not request.user.has_perm('extras.run_script'):
raise PermissionDenied("This user does not have permission to run scripts.") raise PermissionDenied("This user does not have permission to run scripts.")
module, script = self._get_script(pk) script = self._get_script(pk)
input_serializer = serializers.ScriptInputSerializer( input_serializer = serializers.ScriptInputSerializer(
data=request.data, data=request.data,
context={'script': script} context={'script': script}