Merge pull request #3654 from netbox-community/3538-scripts-api

3538: Add custom script API endpoints
This commit is contained in:
Jeremy Stretch 2019-10-30 09:35:22 -04:00 committed by GitHub
commit 56a248e601
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 195 additions and 8 deletions

View File

@ -200,6 +200,52 @@ class ReportDetailSerializer(ReportSerializer):
result = ReportResultSerializer()
#
# Scripts
#
class ScriptSerializer(serializers.Serializer):
id = serializers.SerializerMethodField(read_only=True)
name = serializers.SerializerMethodField(read_only=True)
description = serializers.SerializerMethodField(read_only=True)
vars = serializers.SerializerMethodField(read_only=True)
def get_id(self, instance):
return '{}.{}'.format(instance.__module__, instance.__name__)
def get_name(self, instance):
return getattr(instance.Meta, 'name', instance.__name__)
def get_description(self, instance):
return getattr(instance.Meta, 'description', '')
def get_vars(self, instance):
return {
k: v.__class__.__name__ for k, v in instance._get_vars().items()
}
class ScriptInputSerializer(serializers.Serializer):
data = serializers.JSONField()
commit = serializers.BooleanField()
class ScriptLogMessageSerializer(serializers.Serializer):
status = serializers.SerializerMethodField(read_only=True)
message = serializers.SerializerMethodField(read_only=True)
def get_status(self, instance):
return LOG_LEVEL_CODES.get(instance[0])
def get_message(self, instance):
return instance[1]
class ScriptOutputSerializer(serializers.Serializer):
log = ScriptLogMessageSerializer(many=True, read_only=True)
output = serializers.CharField(read_only=True)
#
# Change logging
#

View File

@ -38,6 +38,9 @@ router.register(r'config-contexts', views.ConfigContextViewSet)
# Reports
router.register(r'reports', views.ReportViewSet, basename='report')
# Scripts
router.register(r'scripts', views.ScriptViewSet, basename='script')
# Change logging
router.register(r'object-changes', views.ObjectChangeViewSet)

View File

@ -3,6 +3,7 @@ from collections import OrderedDict
from django.contrib.contenttypes.models import ContentType
from django.db.models import Count
from django.http import Http404
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.exceptions import PermissionDenied
from rest_framework.response import Response
@ -13,6 +14,7 @@ from extras.models import (
ConfigContext, CustomFieldChoice, ExportTemplate, Graph, ImageAttachment, ObjectChange, ReportResult, Tag,
)
from extras.reports import get_report, get_reports
from extras.scripts import get_script, get_scripts
from utilities.api import FieldChoicesViewSet, IsAuthenticatedOrLoginNotRequired, ModelViewSet
from . import serializers
@ -222,6 +224,56 @@ class ReportViewSet(ViewSet):
return Response(serializer.data)
#
# Scripts
#
class ScriptViewSet(ViewSet):
permission_classes = [IsAuthenticatedOrLoginNotRequired]
_ignore_model_permissions = True
exclude_from_schema = True
lookup_value_regex = '[^/]+' # Allow dots
def _get_script(self, pk):
module_name, script_name = pk.split('.')
script = get_script(module_name, script_name)
if script is None:
raise Http404
return script
def list(self, request):
flat_list = []
for script_list in get_scripts().values():
flat_list.extend(script_list.values())
serializer = serializers.ScriptSerializer(flat_list, many=True, context={'request': request})
return Response(serializer.data)
def retrieve(self, request, pk):
script = self._get_script(pk)
serializer = serializers.ScriptSerializer(script, context={'request': request})
return Response(serializer.data)
def post(self, request, pk):
"""
Run a Script identified as "<module>.<script>".
"""
script = self._get_script(pk)()
input_serializer = serializers.ScriptInputSerializer(data=request.data)
if input_serializer.is_valid():
output = script.run(input_serializer.data['data'])
script.output = output
output_serializer = serializers.ScriptOutputSerializer(script)
return Response(output_serializer.data)
return Response(input_serializer.errors, status=status.HTTP_400_BAD_REQUEST)
#
# Change logging
#

View File

@ -220,16 +220,21 @@ class BaseScript:
def __str__(self):
return getattr(self.Meta, 'name', self.__class__.__name__)
def _get_vars(self):
@classmethod
def module(cls):
return cls.__module__
@classmethod
def _get_vars(cls):
vars = OrderedDict()
# Infer order from Meta.field_order (Python 3.5 and lower)
field_order = getattr(self.Meta, 'field_order', [])
field_order = getattr(cls.Meta, 'field_order', [])
for name in field_order:
vars[name] = getattr(self, name)
vars[name] = getattr(cls, name)
# Default to order of declaration on class
for name, attr in self.__class__.__dict__.items():
for name, attr in cls.__dict__.items():
if name not in vars and issubclass(attr.__class__, ScriptVariable):
vars[name] = attr
@ -360,14 +365,18 @@ def run_script(script, data, files, commit=True):
return output, execution_time
def get_scripts():
def get_scripts(use_names=False):
"""
Return a dict of dicts mapping all scripts to their modules. Set use_names to True to use each module's human-
defined name in place of the actual module name.
"""
scripts = OrderedDict()
# Iterate through all modules within the reports path. These are the user-created files in which reports are
# defined.
for importer, module_name, _ in pkgutil.iter_modules([settings.SCRIPTS_ROOT]):
module = importer.find_module(module_name).load_module(module_name)
if hasattr(module, 'name'):
if use_names and hasattr(module, 'name'):
module_name = module.name
module_scripts = OrderedDict()
for name, cls in inspect.getmembers(module, is_script):
@ -375,3 +384,13 @@ def get_scripts():
scripts[module_name] = module_scripts
return scripts
def get_script(module_name, script_name):
"""
Retrieve a script class by module and name. Returns None if the script does not exist.
"""
scripts = get_scripts()
module = scripts.get(module_name)
if module:
return module.get(script_name)

View File

@ -3,8 +3,10 @@ from django.urls import reverse
from rest_framework import status
from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Platform, Region, Site
from extras.api.views import ScriptViewSet
from extras.constants import GRAPH_TYPE_SITE
from extras.models import ConfigContext, Graph, ExportTemplate, Tag
from extras.scripts import BooleanVar, IntegerVar, Script, StringVar
from tenancy.models import Tenant, TenantGroup
from utilities.testing import APITestCase
@ -520,3 +522,68 @@ class ConfigContextTest(APITestCase):
configcontext6.sites.add(site2)
rendered_context = device.get_config_context()
self.assertEqual(rendered_context['bar'], 456)
class ScriptTest(APITestCase):
class TestScript(Script):
class Meta:
name = "Test script"
var1 = StringVar()
var2 = IntegerVar()
var3 = BooleanVar()
def run(self, data):
self.log_info(data['var1'])
self.log_success(data['var2'])
self.log_failure(data['var3'])
return 'Script complete'
def get_test_script(self, *args):
return self.TestScript
def setUp(self):
super().setUp()
# Monkey-patch the API viewset's _get_script method to return our test script above
ScriptViewSet._get_script = self.get_test_script
def test_get_script(self):
url = reverse('extras-api:script-detail', kwargs={'pk': None})
response = self.client.get(url, **self.header)
self.assertEqual(response.data['name'], self.TestScript.Meta.name)
self.assertEqual(response.data['vars']['var1'], 'StringVar')
self.assertEqual(response.data['vars']['var2'], 'IntegerVar')
self.assertEqual(response.data['vars']['var3'], 'BooleanVar')
def test_run_script(self):
script_data = {
'var1': 'FooBar',
'var2': 123,
'var3': False,
}
data = {
'data': script_data,
'commit': True,
}
url = reverse('extras-api:script-detail', kwargs={'pk': None})
response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(response.data['log'][0]['status'], 'info')
self.assertEqual(response.data['log'][0]['message'], script_data['var1'])
self.assertEqual(response.data['log'][1]['status'], 'success')
self.assertEqual(response.data['log'][1]['message'], script_data['var2'])
self.assertEqual(response.data['log'][2]['status'], 'failure')
self.assertEqual(response.data['log'][2]['message'], script_data['var3'])
self.assertEqual(response.data['output'], 'Script complete')

View File

@ -375,7 +375,7 @@ class ScriptListView(PermissionRequiredMixin, View):
def get(self, request):
return render(request, 'extras/script_list.html', {
'scripts': get_scripts(),
'scripts': get_scripts(use_names=True),
})

View File

@ -19,7 +19,7 @@
{% for class_name, script in module_scripts.items %}
<tr>
<td>
<a href="{% url 'extras:script' module=module name=class_name %}" name="script.{{ class_name }}"><strong>{{ script }}</strong></a>
<a href="{% url 'extras:script' module=script.module name=class_name %}" name="script.{{ class_name }}"><strong>{{ script }}</strong></a>
</td>
<td>{{ script.Meta.description }}</td>
</tr>