diff --git a/netbox/extras/reports.py b/netbox/extras/reports.py index c476e2673..fafe04634 100644 --- a/netbox/extras/reports.py +++ b/netbox/extras/reports.py @@ -1,16 +1,13 @@ -import inspect import logging -import pkgutil import traceback from datetime import timedelta -from django.conf import settings from django.utils import timezone from django_rq import job -from core.models import ManagedFile from .choices import JobResultStatusChoices, LogLevelChoices -from .models import JobResult +from .models import JobResult, ReportModule +from .utils import get_modules logger = logging.getLogger(__name__) @@ -22,6 +19,10 @@ def is_report(obj): return obj in Report.__subclasses__() +def get_reports(): + return get_modules(ReportModule.objects.all(), is_report, 'report_order') + + def get_report(module_name, report_name): """ Return a specific report from within a module. @@ -40,41 +41,6 @@ def get_report(module_name, report_name): return report -def get_reports(): - """ - Compile a list of all reports available across all modules in the reports path. Returns a list of tuples: - - [ - (module_name, (report, report, report, ...)), - (module_name, (report, report, report, ...)), - ... - ] - """ - module_list = {} - - # Iterate through all modules within the reports path. These are the user-created files in which reports are - # defined. - # modules = pkgutil.iter_modules([settings.REPORTS_ROOT]) - modules = [mf.get_module_info() for mf in ManagedFile.objects.filter(file_root='reports')] - for importer, module_name, _ in modules: - module = importer.find_module(module_name).load_module(module_name) - report_order = getattr(module, "report_order", ()) - ordered_reports = [cls() for cls in report_order if is_report(cls)] - unordered_reports = [cls() for _, cls in inspect.getmembers(module, is_report) if cls not in report_order] - - module_reports = {} - - for cls in [*ordered_reports, *unordered_reports]: - # For reports in submodules use the full import path w/o the root module as the name - report_name = cls.full_name.split(".", maxsplit=1)[1] - module_reports[report_name] = cls - - if module_reports: - module_list[module_name] = module_reports - - return module_list - - @job('default') def run_report(job_result, *args, **kwargs): """ diff --git a/netbox/extras/scripts.py b/netbox/extras/scripts.py index fa7a76cb1..ad51280b6 100644 --- a/netbox/extras/scripts.py +++ b/netbox/extras/scripts.py @@ -2,9 +2,6 @@ import inspect import json import logging import os -import pkgutil -import sys -import threading import traceback from datetime import timedelta @@ -15,10 +12,9 @@ from django.core.validators import RegexValidator from django.db import transaction from django.utils.functional import classproperty -from core.models import ManagedFile from extras.api.serializers import ScriptOutputSerializer from extras.choices import JobResultStatusChoices, LogLevelChoices -from extras.models import JobResult +from extras.models import JobResult, ScriptModule from extras.signals import clear_webhooks from ipam.formfields import IPAddressFormField, IPNetworkFormField from ipam.validators import MaxPrefixLengthValidator, MinPrefixLengthValidator, prefix_validator @@ -26,6 +22,7 @@ from utilities.exceptions import AbortScript, AbortTransaction from utilities.forms import add_blank_choice, DynamicModelChoiceField, DynamicModelMultipleChoiceField from .context_managers import change_logging from .forms import ScriptForm +from .utils import get_modules __all__ = [ 'BaseScript', @@ -44,8 +41,6 @@ __all__ = [ 'TextVar', ] -lock = threading.Lock() - # # Script variables @@ -445,6 +440,10 @@ def is_variable(obj): return isinstance(obj, ScriptVariable) +def get_scripts(): + return get_modules(ScriptModule.objects.all(), is_script, 'script_order') + + def run_script(data, request, commit=True, *args, **kwargs): """ A wrapper for calling Script.run(). This performs error handling and provides a hook for committing changes. It @@ -523,52 +522,6 @@ def run_script(data, request, commit=True, *args, **kwargs): ) -def get_scripts(use_names=False): - """ - Return a dict of dicts mapping all scripts to their modules. Set use_names to True to use each module's human- - defined name in place of the actual module name. - """ - scripts = {} - - # Get all modules within the scripts path. These are the user-created files in which scripts are - # defined. - # modules = list(pkgutil.iter_modules([settings.SCRIPTS_ROOT])) - modules = [mf.get_module_info() for mf in ManagedFile.objects.filter(file_root='scripts')] - modules_bases = set([name.split(".")[0] for _, name, _ in modules]) - - # Deleting from sys.modules needs to done behind a lock to prevent race conditions where a module is - # removed from sys.modules while another thread is importing - with lock: - for module_name in list(sys.modules.keys()): - # Everything sharing a base module path with a module in the script folder is removed. - # We also remove all modules with a base module called "scripts". This allows modifying imported - # non-script modules without having to reload the RQ worker. - module_base = module_name.split(".")[0] - if module_base == "scripts" or module_base in modules_bases: - del sys.modules[module_name] - - for importer, module_name, _ in modules: - module = importer.find_module(module_name).load_module(module_name) - - if use_names and hasattr(module, 'name'): - module_name = module.name - - module_scripts = {} - script_order = getattr(module, "script_order", ()) - ordered_scripts = [cls for cls in script_order if is_script(cls)] - unordered_scripts = [cls for _, cls in inspect.getmembers(module, is_script) if cls not in script_order] - - for cls in [*ordered_scripts, *unordered_scripts]: - # For scripts in submodules use the full import path w/o the root module as the name - script_name = cls.full_name.split(".", maxsplit=1)[1] - module_scripts[script_name] = cls - - if module_scripts: - scripts[module_name] = module_scripts - - return scripts - - def get_script(module_name, script_name): """ Retrieve a script class by module and name. Returns None if the script does not exist. diff --git a/netbox/extras/utils.py b/netbox/extras/utils.py index f90858bcf..3a45d20db 100644 --- a/netbox/extras/utils.py +++ b/netbox/extras/utils.py @@ -1,9 +1,15 @@ +import inspect +import sys +import threading + from django.db.models import Q from django.utils.deconstruct import deconstructible from taggit.managers import _TaggableManager from netbox.registry import registry +lock = threading.Lock() + def is_taggable(obj): """ @@ -66,3 +72,48 @@ def register_features(model, features): raise KeyError( f"{feature} is not a valid model feature! Valid keys are: {registry['model_features'].keys()}" ) + + +def get_modules(queryset, litmus_func, ordering_attr): + """ + Returns a list of tuples: + + [ + (module_name, (child, child, ...)), + (module_name, (child, child, ...)), + ... + ] + """ + results = {} + + modules = [mf.get_module_info() for mf in queryset] + modules_bases = set([name.split(".")[0] for _, name, _ in modules]) + + # Deleting from sys.modules needs to done behind a lock to prevent race conditions where a module is + # removed from sys.modules while another thread is importing + with lock: + for module_name in list(sys.modules.keys()): + # Everything sharing a base module path with a module in the script folder is removed. + # We also remove all modules with a base module called "scripts". This allows modifying imported + # non-script modules without having to reload the RQ worker. + module_base = module_name.split(".")[0] + if module_base in ('reports', 'scripts', *modules_bases): + del sys.modules[module_name] + + for importer, module_name, _ in modules: + module = importer.find_module(module_name).load_module(module_name) + child_order = getattr(module, ordering_attr, ()) + ordered_children = [cls() for cls in child_order if litmus_func(cls)] + unordered_children = [cls() for _, cls in inspect.getmembers(module, litmus_func) if cls not in child_order] + + children = {} + + for cls in [*ordered_children, *unordered_children]: + # For child objects in submodules use the full import path w/o the root module as the name + child_name = cls.full_name.split(".", maxsplit=1)[1] + children[child_name] = cls + + if children: + results[module_name] = children + + return results diff --git a/netbox/extras/views.py b/netbox/extras/views.py index 8677c505e..091237cfc 100644 --- a/netbox/extras/views.py +++ b/netbox/extras/views.py @@ -942,7 +942,7 @@ class ScriptListView(ContentTypePermissionRequiredMixin, View): def get(self, request): - scripts = get_scripts(use_names=True) + scripts = get_scripts() script_content_type = ContentType.objects.get(app_label='extras', model='script') results = { r.name: r