Consolidate get_scripts() and get_reports() functions

This commit is contained in:
jeremystretch 2023-03-23 15:27:52 -04:00
parent ba7fab3a5d
commit ab4a38d32d
4 changed files with 64 additions and 94 deletions

View File

@ -1,16 +1,13 @@
import inspect
import logging
import pkgutil
import traceback
from datetime import timedelta
from django.conf import settings
from django.utils import timezone
from django_rq import job
from core.models import ManagedFile
from .choices import JobResultStatusChoices, LogLevelChoices
from .models import JobResult
from .models import JobResult, ReportModule
from .utils import get_modules
logger = logging.getLogger(__name__)
@ -22,6 +19,10 @@ def is_report(obj):
return obj in Report.__subclasses__()
def get_reports():
return get_modules(ReportModule.objects.all(), is_report, 'report_order')
def get_report(module_name, report_name):
"""
Return a specific report from within a module.
@ -40,41 +41,6 @@ def get_report(module_name, report_name):
return report
def get_reports():
"""
Compile a list of all reports available across all modules in the reports path. Returns a list of tuples:
[
(module_name, (report, report, report, ...)),
(module_name, (report, report, report, ...)),
...
]
"""
module_list = {}
# Iterate through all modules within the reports path. These are the user-created files in which reports are
# defined.
# modules = pkgutil.iter_modules([settings.REPORTS_ROOT])
modules = [mf.get_module_info() for mf in ManagedFile.objects.filter(file_root='reports')]
for importer, module_name, _ in modules:
module = importer.find_module(module_name).load_module(module_name)
report_order = getattr(module, "report_order", ())
ordered_reports = [cls() for cls in report_order if is_report(cls)]
unordered_reports = [cls() for _, cls in inspect.getmembers(module, is_report) if cls not in report_order]
module_reports = {}
for cls in [*ordered_reports, *unordered_reports]:
# For reports in submodules use the full import path w/o the root module as the name
report_name = cls.full_name.split(".", maxsplit=1)[1]
module_reports[report_name] = cls
if module_reports:
module_list[module_name] = module_reports
return module_list
@job('default')
def run_report(job_result, *args, **kwargs):
"""

View File

@ -2,9 +2,6 @@ import inspect
import json
import logging
import os
import pkgutil
import sys
import threading
import traceback
from datetime import timedelta
@ -15,10 +12,9 @@ from django.core.validators import RegexValidator
from django.db import transaction
from django.utils.functional import classproperty
from core.models import ManagedFile
from extras.api.serializers import ScriptOutputSerializer
from extras.choices import JobResultStatusChoices, LogLevelChoices
from extras.models import JobResult
from extras.models import JobResult, ScriptModule
from extras.signals import clear_webhooks
from ipam.formfields import IPAddressFormField, IPNetworkFormField
from ipam.validators import MaxPrefixLengthValidator, MinPrefixLengthValidator, prefix_validator
@ -26,6 +22,7 @@ from utilities.exceptions import AbortScript, AbortTransaction
from utilities.forms import add_blank_choice, DynamicModelChoiceField, DynamicModelMultipleChoiceField
from .context_managers import change_logging
from .forms import ScriptForm
from .utils import get_modules
__all__ = [
'BaseScript',
@ -44,8 +41,6 @@ __all__ = [
'TextVar',
]
lock = threading.Lock()
#
# Script variables
@ -445,6 +440,10 @@ def is_variable(obj):
return isinstance(obj, ScriptVariable)
def get_scripts():
return get_modules(ScriptModule.objects.all(), is_script, 'script_order')
def run_script(data, request, commit=True, *args, **kwargs):
"""
A wrapper for calling Script.run(). This performs error handling and provides a hook for committing changes. It
@ -523,52 +522,6 @@ def run_script(data, request, commit=True, *args, **kwargs):
)
def get_scripts(use_names=False):
"""
Return a dict of dicts mapping all scripts to their modules. Set use_names to True to use each module's human-
defined name in place of the actual module name.
"""
scripts = {}
# Get all modules within the scripts path. These are the user-created files in which scripts are
# defined.
# modules = list(pkgutil.iter_modules([settings.SCRIPTS_ROOT]))
modules = [mf.get_module_info() for mf in ManagedFile.objects.filter(file_root='scripts')]
modules_bases = set([name.split(".")[0] for _, name, _ in modules])
# Deleting from sys.modules needs to done behind a lock to prevent race conditions where a module is
# removed from sys.modules while another thread is importing
with lock:
for module_name in list(sys.modules.keys()):
# Everything sharing a base module path with a module in the script folder is removed.
# We also remove all modules with a base module called "scripts". This allows modifying imported
# non-script modules without having to reload the RQ worker.
module_base = module_name.split(".")[0]
if module_base == "scripts" or module_base in modules_bases:
del sys.modules[module_name]
for importer, module_name, _ in modules:
module = importer.find_module(module_name).load_module(module_name)
if use_names and hasattr(module, 'name'):
module_name = module.name
module_scripts = {}
script_order = getattr(module, "script_order", ())
ordered_scripts = [cls for cls in script_order if is_script(cls)]
unordered_scripts = [cls for _, cls in inspect.getmembers(module, is_script) if cls not in script_order]
for cls in [*ordered_scripts, *unordered_scripts]:
# For scripts in submodules use the full import path w/o the root module as the name
script_name = cls.full_name.split(".", maxsplit=1)[1]
module_scripts[script_name] = cls
if module_scripts:
scripts[module_name] = module_scripts
return scripts
def get_script(module_name, script_name):
"""
Retrieve a script class by module and name. Returns None if the script does not exist.

View File

@ -1,9 +1,15 @@
import inspect
import sys
import threading
from django.db.models import Q
from django.utils.deconstruct import deconstructible
from taggit.managers import _TaggableManager
from netbox.registry import registry
lock = threading.Lock()
def is_taggable(obj):
"""
@ -66,3 +72,48 @@ def register_features(model, features):
raise KeyError(
f"{feature} is not a valid model feature! Valid keys are: {registry['model_features'].keys()}"
)
def get_modules(queryset, litmus_func, ordering_attr):
"""
Returns a list of tuples:
[
(module_name, (child, child, ...)),
(module_name, (child, child, ...)),
...
]
"""
results = {}
modules = [mf.get_module_info() for mf in queryset]
modules_bases = set([name.split(".")[0] for _, name, _ in modules])
# Deleting from sys.modules needs to done behind a lock to prevent race conditions where a module is
# removed from sys.modules while another thread is importing
with lock:
for module_name in list(sys.modules.keys()):
# Everything sharing a base module path with a module in the script folder is removed.
# We also remove all modules with a base module called "scripts". This allows modifying imported
# non-script modules without having to reload the RQ worker.
module_base = module_name.split(".")[0]
if module_base in ('reports', 'scripts', *modules_bases):
del sys.modules[module_name]
for importer, module_name, _ in modules:
module = importer.find_module(module_name).load_module(module_name)
child_order = getattr(module, ordering_attr, ())
ordered_children = [cls() for cls in child_order if litmus_func(cls)]
unordered_children = [cls() for _, cls in inspect.getmembers(module, litmus_func) if cls not in child_order]
children = {}
for cls in [*ordered_children, *unordered_children]:
# For child objects in submodules use the full import path w/o the root module as the name
child_name = cls.full_name.split(".", maxsplit=1)[1]
children[child_name] = cls
if children:
results[module_name] = children
return results

View File

@ -942,7 +942,7 @@ class ScriptListView(ContentTypePermissionRequiredMixin, View):
def get(self, request):
scripts = get_scripts(use_names=True)
scripts = get_scripts()
script_content_type = ContentType.objects.get(app_label='extras', model='script')
results = {
r.name: r