Merge pull request #3654 from netbox-community/3538-scripts-api

3538: Add custom script API endpoints
This commit is contained in:
Jeremy Stretch 2019-10-30 09:35:22 -04:00 committed by GitHub
commit 56a248e601
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 195 additions and 8 deletions

View File

@ -200,6 +200,52 @@ class ReportDetailSerializer(ReportSerializer):
result = ReportResultSerializer() result = ReportResultSerializer()
#
# Scripts
#
class ScriptSerializer(serializers.Serializer):
id = serializers.SerializerMethodField(read_only=True)
name = serializers.SerializerMethodField(read_only=True)
description = serializers.SerializerMethodField(read_only=True)
vars = serializers.SerializerMethodField(read_only=True)
def get_id(self, instance):
return '{}.{}'.format(instance.__module__, instance.__name__)
def get_name(self, instance):
return getattr(instance.Meta, 'name', instance.__name__)
def get_description(self, instance):
return getattr(instance.Meta, 'description', '')
def get_vars(self, instance):
return {
k: v.__class__.__name__ for k, v in instance._get_vars().items()
}
class ScriptInputSerializer(serializers.Serializer):
data = serializers.JSONField()
commit = serializers.BooleanField()
class ScriptLogMessageSerializer(serializers.Serializer):
status = serializers.SerializerMethodField(read_only=True)
message = serializers.SerializerMethodField(read_only=True)
def get_status(self, instance):
return LOG_LEVEL_CODES.get(instance[0])
def get_message(self, instance):
return instance[1]
class ScriptOutputSerializer(serializers.Serializer):
log = ScriptLogMessageSerializer(many=True, read_only=True)
output = serializers.CharField(read_only=True)
# #
# Change logging # Change logging
# #

View File

@ -38,6 +38,9 @@ router.register(r'config-contexts', views.ConfigContextViewSet)
# Reports # Reports
router.register(r'reports', views.ReportViewSet, basename='report') router.register(r'reports', views.ReportViewSet, basename='report')
# Scripts
router.register(r'scripts', views.ScriptViewSet, basename='script')
# Change logging # Change logging
router.register(r'object-changes', views.ObjectChangeViewSet) router.register(r'object-changes', views.ObjectChangeViewSet)

View File

@ -3,6 +3,7 @@ from collections import OrderedDict
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db.models import Count from django.db.models import Count
from django.http import Http404 from django.http import Http404
from rest_framework import status
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import PermissionDenied from rest_framework.exceptions import PermissionDenied
from rest_framework.response import Response from rest_framework.response import Response
@ -13,6 +14,7 @@ from extras.models import (
ConfigContext, CustomFieldChoice, ExportTemplate, Graph, ImageAttachment, ObjectChange, ReportResult, Tag, ConfigContext, CustomFieldChoice, ExportTemplate, Graph, ImageAttachment, ObjectChange, ReportResult, Tag,
) )
from extras.reports import get_report, get_reports from extras.reports import get_report, get_reports
from extras.scripts import get_script, get_scripts
from utilities.api import FieldChoicesViewSet, IsAuthenticatedOrLoginNotRequired, ModelViewSet from utilities.api import FieldChoicesViewSet, IsAuthenticatedOrLoginNotRequired, ModelViewSet
from . import serializers from . import serializers
@ -222,6 +224,56 @@ class ReportViewSet(ViewSet):
return Response(serializer.data) return Response(serializer.data)
#
# Scripts
#
class ScriptViewSet(ViewSet):
permission_classes = [IsAuthenticatedOrLoginNotRequired]
_ignore_model_permissions = True
exclude_from_schema = True
lookup_value_regex = '[^/]+' # Allow dots
def _get_script(self, pk):
module_name, script_name = pk.split('.')
script = get_script(module_name, script_name)
if script is None:
raise Http404
return script
def list(self, request):
flat_list = []
for script_list in get_scripts().values():
flat_list.extend(script_list.values())
serializer = serializers.ScriptSerializer(flat_list, many=True, context={'request': request})
return Response(serializer.data)
def retrieve(self, request, pk):
script = self._get_script(pk)
serializer = serializers.ScriptSerializer(script, context={'request': request})
return Response(serializer.data)
def post(self, request, pk):
"""
Run a Script identified as "<module>.<script>".
"""
script = self._get_script(pk)()
input_serializer = serializers.ScriptInputSerializer(data=request.data)
if input_serializer.is_valid():
output = script.run(input_serializer.data['data'])
script.output = output
output_serializer = serializers.ScriptOutputSerializer(script)
return Response(output_serializer.data)
return Response(input_serializer.errors, status=status.HTTP_400_BAD_REQUEST)
# #
# Change logging # Change logging
# #

View File

@ -220,16 +220,21 @@ class BaseScript:
def __str__(self): def __str__(self):
return getattr(self.Meta, 'name', self.__class__.__name__) return getattr(self.Meta, 'name', self.__class__.__name__)
def _get_vars(self): @classmethod
def module(cls):
return cls.__module__
@classmethod
def _get_vars(cls):
vars = OrderedDict() vars = OrderedDict()
# Infer order from Meta.field_order (Python 3.5 and lower) # Infer order from Meta.field_order (Python 3.5 and lower)
field_order = getattr(self.Meta, 'field_order', []) field_order = getattr(cls.Meta, 'field_order', [])
for name in field_order: for name in field_order:
vars[name] = getattr(self, name) vars[name] = getattr(cls, name)
# Default to order of declaration on class # Default to order of declaration on class
for name, attr in self.__class__.__dict__.items(): for name, attr in cls.__dict__.items():
if name not in vars and issubclass(attr.__class__, ScriptVariable): if name not in vars and issubclass(attr.__class__, ScriptVariable):
vars[name] = attr vars[name] = attr
@ -360,14 +365,18 @@ def run_script(script, data, files, commit=True):
return output, execution_time return output, execution_time
def get_scripts(): def get_scripts(use_names=False):
"""
Return a dict of dicts mapping all scripts to their modules. Set use_names to True to use each module's human-
defined name in place of the actual module name.
"""
scripts = OrderedDict() scripts = OrderedDict()
# Iterate through all modules within the reports path. These are the user-created files in which reports are # Iterate through all modules within the reports path. These are the user-created files in which reports are
# defined. # defined.
for importer, module_name, _ in pkgutil.iter_modules([settings.SCRIPTS_ROOT]): for importer, module_name, _ in pkgutil.iter_modules([settings.SCRIPTS_ROOT]):
module = importer.find_module(module_name).load_module(module_name) module = importer.find_module(module_name).load_module(module_name)
if hasattr(module, 'name'): if use_names and hasattr(module, 'name'):
module_name = module.name module_name = module.name
module_scripts = OrderedDict() module_scripts = OrderedDict()
for name, cls in inspect.getmembers(module, is_script): for name, cls in inspect.getmembers(module, is_script):
@ -375,3 +384,13 @@ def get_scripts():
scripts[module_name] = module_scripts scripts[module_name] = module_scripts
return scripts return scripts
def get_script(module_name, script_name):
"""
Retrieve a script class by module and name. Returns None if the script does not exist.
"""
scripts = get_scripts()
module = scripts.get(module_name)
if module:
return module.get(script_name)

View File

@ -3,8 +3,10 @@ from django.urls import reverse
from rest_framework import status from rest_framework import status
from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Platform, Region, Site from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Platform, Region, Site
from extras.api.views import ScriptViewSet
from extras.constants import GRAPH_TYPE_SITE from extras.constants import GRAPH_TYPE_SITE
from extras.models import ConfigContext, Graph, ExportTemplate, Tag from extras.models import ConfigContext, Graph, ExportTemplate, Tag
from extras.scripts import BooleanVar, IntegerVar, Script, StringVar
from tenancy.models import Tenant, TenantGroup from tenancy.models import Tenant, TenantGroup
from utilities.testing import APITestCase from utilities.testing import APITestCase
@ -520,3 +522,68 @@ class ConfigContextTest(APITestCase):
configcontext6.sites.add(site2) configcontext6.sites.add(site2)
rendered_context = device.get_config_context() rendered_context = device.get_config_context()
self.assertEqual(rendered_context['bar'], 456) self.assertEqual(rendered_context['bar'], 456)
class ScriptTest(APITestCase):
class TestScript(Script):
class Meta:
name = "Test script"
var1 = StringVar()
var2 = IntegerVar()
var3 = BooleanVar()
def run(self, data):
self.log_info(data['var1'])
self.log_success(data['var2'])
self.log_failure(data['var3'])
return 'Script complete'
def get_test_script(self, *args):
return self.TestScript
def setUp(self):
super().setUp()
# Monkey-patch the API viewset's _get_script method to return our test script above
ScriptViewSet._get_script = self.get_test_script
def test_get_script(self):
url = reverse('extras-api:script-detail', kwargs={'pk': None})
response = self.client.get(url, **self.header)
self.assertEqual(response.data['name'], self.TestScript.Meta.name)
self.assertEqual(response.data['vars']['var1'], 'StringVar')
self.assertEqual(response.data['vars']['var2'], 'IntegerVar')
self.assertEqual(response.data['vars']['var3'], 'BooleanVar')
def test_run_script(self):
script_data = {
'var1': 'FooBar',
'var2': 123,
'var3': False,
}
data = {
'data': script_data,
'commit': True,
}
url = reverse('extras-api:script-detail', kwargs={'pk': None})
response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(response.data['log'][0]['status'], 'info')
self.assertEqual(response.data['log'][0]['message'], script_data['var1'])
self.assertEqual(response.data['log'][1]['status'], 'success')
self.assertEqual(response.data['log'][1]['message'], script_data['var2'])
self.assertEqual(response.data['log'][2]['status'], 'failure')
self.assertEqual(response.data['log'][2]['message'], script_data['var3'])
self.assertEqual(response.data['output'], 'Script complete')

View File

@ -375,7 +375,7 @@ class ScriptListView(PermissionRequiredMixin, View):
def get(self, request): def get(self, request):
return render(request, 'extras/script_list.html', { return render(request, 'extras/script_list.html', {
'scripts': get_scripts(), 'scripts': get_scripts(use_names=True),
}) })

View File

@ -19,7 +19,7 @@
{% for class_name, script in module_scripts.items %} {% for class_name, script in module_scripts.items %}
<tr> <tr>
<td> <td>
<a href="{% url 'extras:script' module=module name=class_name %}" name="script.{{ class_name }}"><strong>{{ script }}</strong></a> <a href="{% url 'extras:script' module=script.module name=class_name %}" name="script.{{ class_name }}"><strong>{{ script }}</strong></a>
</td> </td>
<td>{{ script.Meta.description }}</td> <td>{{ script.Meta.description }}</td>
</tr> </tr>