mirror of
https://github.com/netbox-community/netbox.git
synced 2025-07-17 04:32:51 -06:00
Merge pull request #3654 from netbox-community/3538-scripts-api
3538: Add custom script API endpoints
This commit is contained in:
commit
56a248e601
@ -200,6 +200,52 @@ class ReportDetailSerializer(ReportSerializer):
|
|||||||
result = ReportResultSerializer()
|
result = ReportResultSerializer()
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Scripts
|
||||||
|
#
|
||||||
|
|
||||||
|
class ScriptSerializer(serializers.Serializer):
|
||||||
|
id = serializers.SerializerMethodField(read_only=True)
|
||||||
|
name = serializers.SerializerMethodField(read_only=True)
|
||||||
|
description = serializers.SerializerMethodField(read_only=True)
|
||||||
|
vars = serializers.SerializerMethodField(read_only=True)
|
||||||
|
|
||||||
|
def get_id(self, instance):
|
||||||
|
return '{}.{}'.format(instance.__module__, instance.__name__)
|
||||||
|
|
||||||
|
def get_name(self, instance):
|
||||||
|
return getattr(instance.Meta, 'name', instance.__name__)
|
||||||
|
|
||||||
|
def get_description(self, instance):
|
||||||
|
return getattr(instance.Meta, 'description', '')
|
||||||
|
|
||||||
|
def get_vars(self, instance):
|
||||||
|
return {
|
||||||
|
k: v.__class__.__name__ for k, v in instance._get_vars().items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptInputSerializer(serializers.Serializer):
|
||||||
|
data = serializers.JSONField()
|
||||||
|
commit = serializers.BooleanField()
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptLogMessageSerializer(serializers.Serializer):
|
||||||
|
status = serializers.SerializerMethodField(read_only=True)
|
||||||
|
message = serializers.SerializerMethodField(read_only=True)
|
||||||
|
|
||||||
|
def get_status(self, instance):
|
||||||
|
return LOG_LEVEL_CODES.get(instance[0])
|
||||||
|
|
||||||
|
def get_message(self, instance):
|
||||||
|
return instance[1]
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptOutputSerializer(serializers.Serializer):
|
||||||
|
log = ScriptLogMessageSerializer(many=True, read_only=True)
|
||||||
|
output = serializers.CharField(read_only=True)
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Change logging
|
# Change logging
|
||||||
#
|
#
|
||||||
|
@ -38,6 +38,9 @@ router.register(r'config-contexts', views.ConfigContextViewSet)
|
|||||||
# Reports
|
# Reports
|
||||||
router.register(r'reports', views.ReportViewSet, basename='report')
|
router.register(r'reports', views.ReportViewSet, basename='report')
|
||||||
|
|
||||||
|
# Scripts
|
||||||
|
router.register(r'scripts', views.ScriptViewSet, basename='script')
|
||||||
|
|
||||||
# Change logging
|
# Change logging
|
||||||
router.register(r'object-changes', views.ObjectChangeViewSet)
|
router.register(r'object-changes', views.ObjectChangeViewSet)
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ from collections import OrderedDict
|
|||||||
from django.contrib.contenttypes.models import ContentType
|
from django.contrib.contenttypes.models import ContentType
|
||||||
from django.db.models import Count
|
from django.db.models import Count
|
||||||
from django.http import Http404
|
from django.http import Http404
|
||||||
|
from rest_framework import status
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.exceptions import PermissionDenied
|
from rest_framework.exceptions import PermissionDenied
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
@ -13,6 +14,7 @@ from extras.models import (
|
|||||||
ConfigContext, CustomFieldChoice, ExportTemplate, Graph, ImageAttachment, ObjectChange, ReportResult, Tag,
|
ConfigContext, CustomFieldChoice, ExportTemplate, Graph, ImageAttachment, ObjectChange, ReportResult, Tag,
|
||||||
)
|
)
|
||||||
from extras.reports import get_report, get_reports
|
from extras.reports import get_report, get_reports
|
||||||
|
from extras.scripts import get_script, get_scripts
|
||||||
from utilities.api import FieldChoicesViewSet, IsAuthenticatedOrLoginNotRequired, ModelViewSet
|
from utilities.api import FieldChoicesViewSet, IsAuthenticatedOrLoginNotRequired, ModelViewSet
|
||||||
from . import serializers
|
from . import serializers
|
||||||
|
|
||||||
@ -222,6 +224,56 @@ class ReportViewSet(ViewSet):
|
|||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Scripts
|
||||||
|
#
|
||||||
|
|
||||||
|
class ScriptViewSet(ViewSet):
|
||||||
|
permission_classes = [IsAuthenticatedOrLoginNotRequired]
|
||||||
|
_ignore_model_permissions = True
|
||||||
|
exclude_from_schema = True
|
||||||
|
lookup_value_regex = '[^/]+' # Allow dots
|
||||||
|
|
||||||
|
def _get_script(self, pk):
|
||||||
|
module_name, script_name = pk.split('.')
|
||||||
|
script = get_script(module_name, script_name)
|
||||||
|
if script is None:
|
||||||
|
raise Http404
|
||||||
|
return script
|
||||||
|
|
||||||
|
def list(self, request):
|
||||||
|
|
||||||
|
flat_list = []
|
||||||
|
for script_list in get_scripts().values():
|
||||||
|
flat_list.extend(script_list.values())
|
||||||
|
|
||||||
|
serializer = serializers.ScriptSerializer(flat_list, many=True, context={'request': request})
|
||||||
|
|
||||||
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
def retrieve(self, request, pk):
|
||||||
|
script = self._get_script(pk)
|
||||||
|
serializer = serializers.ScriptSerializer(script, context={'request': request})
|
||||||
|
|
||||||
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
def post(self, request, pk):
|
||||||
|
"""
|
||||||
|
Run a Script identified as "<module>.<script>".
|
||||||
|
"""
|
||||||
|
script = self._get_script(pk)()
|
||||||
|
input_serializer = serializers.ScriptInputSerializer(data=request.data)
|
||||||
|
|
||||||
|
if input_serializer.is_valid():
|
||||||
|
output = script.run(input_serializer.data['data'])
|
||||||
|
script.output = output
|
||||||
|
output_serializer = serializers.ScriptOutputSerializer(script)
|
||||||
|
|
||||||
|
return Response(output_serializer.data)
|
||||||
|
|
||||||
|
return Response(input_serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Change logging
|
# Change logging
|
||||||
#
|
#
|
||||||
|
@ -220,16 +220,21 @@ class BaseScript:
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return getattr(self.Meta, 'name', self.__class__.__name__)
|
return getattr(self.Meta, 'name', self.__class__.__name__)
|
||||||
|
|
||||||
def _get_vars(self):
|
@classmethod
|
||||||
|
def module(cls):
|
||||||
|
return cls.__module__
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_vars(cls):
|
||||||
vars = OrderedDict()
|
vars = OrderedDict()
|
||||||
|
|
||||||
# Infer order from Meta.field_order (Python 3.5 and lower)
|
# Infer order from Meta.field_order (Python 3.5 and lower)
|
||||||
field_order = getattr(self.Meta, 'field_order', [])
|
field_order = getattr(cls.Meta, 'field_order', [])
|
||||||
for name in field_order:
|
for name in field_order:
|
||||||
vars[name] = getattr(self, name)
|
vars[name] = getattr(cls, name)
|
||||||
|
|
||||||
# Default to order of declaration on class
|
# Default to order of declaration on class
|
||||||
for name, attr in self.__class__.__dict__.items():
|
for name, attr in cls.__dict__.items():
|
||||||
if name not in vars and issubclass(attr.__class__, ScriptVariable):
|
if name not in vars and issubclass(attr.__class__, ScriptVariable):
|
||||||
vars[name] = attr
|
vars[name] = attr
|
||||||
|
|
||||||
@ -360,14 +365,18 @@ def run_script(script, data, files, commit=True):
|
|||||||
return output, execution_time
|
return output, execution_time
|
||||||
|
|
||||||
|
|
||||||
def get_scripts():
|
def get_scripts(use_names=False):
|
||||||
|
"""
|
||||||
|
Return a dict of dicts mapping all scripts to their modules. Set use_names to True to use each module's human-
|
||||||
|
defined name in place of the actual module name.
|
||||||
|
"""
|
||||||
scripts = OrderedDict()
|
scripts = OrderedDict()
|
||||||
|
|
||||||
# Iterate through all modules within the reports path. These are the user-created files in which reports are
|
# Iterate through all modules within the reports path. These are the user-created files in which reports are
|
||||||
# defined.
|
# defined.
|
||||||
for importer, module_name, _ in pkgutil.iter_modules([settings.SCRIPTS_ROOT]):
|
for importer, module_name, _ in pkgutil.iter_modules([settings.SCRIPTS_ROOT]):
|
||||||
module = importer.find_module(module_name).load_module(module_name)
|
module = importer.find_module(module_name).load_module(module_name)
|
||||||
if hasattr(module, 'name'):
|
if use_names and hasattr(module, 'name'):
|
||||||
module_name = module.name
|
module_name = module.name
|
||||||
module_scripts = OrderedDict()
|
module_scripts = OrderedDict()
|
||||||
for name, cls in inspect.getmembers(module, is_script):
|
for name, cls in inspect.getmembers(module, is_script):
|
||||||
@ -375,3 +384,13 @@ def get_scripts():
|
|||||||
scripts[module_name] = module_scripts
|
scripts[module_name] = module_scripts
|
||||||
|
|
||||||
return scripts
|
return scripts
|
||||||
|
|
||||||
|
|
||||||
|
def get_script(module_name, script_name):
|
||||||
|
"""
|
||||||
|
Retrieve a script class by module and name. Returns None if the script does not exist.
|
||||||
|
"""
|
||||||
|
scripts = get_scripts()
|
||||||
|
module = scripts.get(module_name)
|
||||||
|
if module:
|
||||||
|
return module.get(script_name)
|
||||||
|
@ -3,8 +3,10 @@ from django.urls import reverse
|
|||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
|
|
||||||
from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Platform, Region, Site
|
from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Platform, Region, Site
|
||||||
|
from extras.api.views import ScriptViewSet
|
||||||
from extras.constants import GRAPH_TYPE_SITE
|
from extras.constants import GRAPH_TYPE_SITE
|
||||||
from extras.models import ConfigContext, Graph, ExportTemplate, Tag
|
from extras.models import ConfigContext, Graph, ExportTemplate, Tag
|
||||||
|
from extras.scripts import BooleanVar, IntegerVar, Script, StringVar
|
||||||
from tenancy.models import Tenant, TenantGroup
|
from tenancy.models import Tenant, TenantGroup
|
||||||
from utilities.testing import APITestCase
|
from utilities.testing import APITestCase
|
||||||
|
|
||||||
@ -520,3 +522,68 @@ class ConfigContextTest(APITestCase):
|
|||||||
configcontext6.sites.add(site2)
|
configcontext6.sites.add(site2)
|
||||||
rendered_context = device.get_config_context()
|
rendered_context = device.get_config_context()
|
||||||
self.assertEqual(rendered_context['bar'], 456)
|
self.assertEqual(rendered_context['bar'], 456)
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptTest(APITestCase):
|
||||||
|
|
||||||
|
class TestScript(Script):
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
name = "Test script"
|
||||||
|
|
||||||
|
var1 = StringVar()
|
||||||
|
var2 = IntegerVar()
|
||||||
|
var3 = BooleanVar()
|
||||||
|
|
||||||
|
def run(self, data):
|
||||||
|
|
||||||
|
self.log_info(data['var1'])
|
||||||
|
self.log_success(data['var2'])
|
||||||
|
self.log_failure(data['var3'])
|
||||||
|
|
||||||
|
return 'Script complete'
|
||||||
|
|
||||||
|
def get_test_script(self, *args):
|
||||||
|
return self.TestScript
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
# Monkey-patch the API viewset's _get_script method to return our test script above
|
||||||
|
ScriptViewSet._get_script = self.get_test_script
|
||||||
|
|
||||||
|
def test_get_script(self):
|
||||||
|
|
||||||
|
url = reverse('extras-api:script-detail', kwargs={'pk': None})
|
||||||
|
response = self.client.get(url, **self.header)
|
||||||
|
|
||||||
|
self.assertEqual(response.data['name'], self.TestScript.Meta.name)
|
||||||
|
self.assertEqual(response.data['vars']['var1'], 'StringVar')
|
||||||
|
self.assertEqual(response.data['vars']['var2'], 'IntegerVar')
|
||||||
|
self.assertEqual(response.data['vars']['var3'], 'BooleanVar')
|
||||||
|
|
||||||
|
def test_run_script(self):
|
||||||
|
|
||||||
|
script_data = {
|
||||||
|
'var1': 'FooBar',
|
||||||
|
'var2': 123,
|
||||||
|
'var3': False,
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'data': script_data,
|
||||||
|
'commit': True,
|
||||||
|
}
|
||||||
|
|
||||||
|
url = reverse('extras-api:script-detail', kwargs={'pk': None})
|
||||||
|
response = self.client.post(url, data, format='json', **self.header)
|
||||||
|
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||||
|
|
||||||
|
self.assertEqual(response.data['log'][0]['status'], 'info')
|
||||||
|
self.assertEqual(response.data['log'][0]['message'], script_data['var1'])
|
||||||
|
self.assertEqual(response.data['log'][1]['status'], 'success')
|
||||||
|
self.assertEqual(response.data['log'][1]['message'], script_data['var2'])
|
||||||
|
self.assertEqual(response.data['log'][2]['status'], 'failure')
|
||||||
|
self.assertEqual(response.data['log'][2]['message'], script_data['var3'])
|
||||||
|
self.assertEqual(response.data['output'], 'Script complete')
|
||||||
|
@ -375,7 +375,7 @@ class ScriptListView(PermissionRequiredMixin, View):
|
|||||||
def get(self, request):
|
def get(self, request):
|
||||||
|
|
||||||
return render(request, 'extras/script_list.html', {
|
return render(request, 'extras/script_list.html', {
|
||||||
'scripts': get_scripts(),
|
'scripts': get_scripts(use_names=True),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@
|
|||||||
{% for class_name, script in module_scripts.items %}
|
{% for class_name, script in module_scripts.items %}
|
||||||
<tr>
|
<tr>
|
||||||
<td>
|
<td>
|
||||||
<a href="{% url 'extras:script' module=module name=class_name %}" name="script.{{ class_name }}"><strong>{{ script }}</strong></a>
|
<a href="{% url 'extras:script' module=script.module name=class_name %}" name="script.{{ class_name }}"><strong>{{ script }}</strong></a>
|
||||||
</td>
|
</td>
|
||||||
<td>{{ script.Meta.description }}</td>
|
<td>{{ script.Meta.description }}</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
Loading…
Reference in New Issue
Block a user