mirror of
https://github.com/netbox-community/netbox.git
synced 2025-08-26 17:26:10 -06:00
Consolidate get_scripts() and get_reports() functions
This commit is contained in:
parent
ba7fab3a5d
commit
ab4a38d32d
@ -1,16 +1,13 @@
|
||||
import inspect
|
||||
import logging
|
||||
import pkgutil
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
|
||||
from django.conf import settings
|
||||
from django.utils import timezone
|
||||
from django_rq import job
|
||||
|
||||
from core.models import ManagedFile
|
||||
from .choices import JobResultStatusChoices, LogLevelChoices
|
||||
from .models import JobResult
|
||||
from .models import JobResult, ReportModule
|
||||
from .utils import get_modules
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -22,6 +19,10 @@ def is_report(obj):
|
||||
return obj in Report.__subclasses__()
|
||||
|
||||
|
||||
def get_reports():
|
||||
return get_modules(ReportModule.objects.all(), is_report, 'report_order')
|
||||
|
||||
|
||||
def get_report(module_name, report_name):
|
||||
"""
|
||||
Return a specific report from within a module.
|
||||
@ -40,41 +41,6 @@ def get_report(module_name, report_name):
|
||||
return report
|
||||
|
||||
|
||||
def get_reports():
|
||||
"""
|
||||
Compile a list of all reports available across all modules in the reports path. Returns a list of tuples:
|
||||
|
||||
[
|
||||
(module_name, (report, report, report, ...)),
|
||||
(module_name, (report, report, report, ...)),
|
||||
...
|
||||
]
|
||||
"""
|
||||
module_list = {}
|
||||
|
||||
# Iterate through all modules within the reports path. These are the user-created files in which reports are
|
||||
# defined.
|
||||
# modules = pkgutil.iter_modules([settings.REPORTS_ROOT])
|
||||
modules = [mf.get_module_info() for mf in ManagedFile.objects.filter(file_root='reports')]
|
||||
for importer, module_name, _ in modules:
|
||||
module = importer.find_module(module_name).load_module(module_name)
|
||||
report_order = getattr(module, "report_order", ())
|
||||
ordered_reports = [cls() for cls in report_order if is_report(cls)]
|
||||
unordered_reports = [cls() for _, cls in inspect.getmembers(module, is_report) if cls not in report_order]
|
||||
|
||||
module_reports = {}
|
||||
|
||||
for cls in [*ordered_reports, *unordered_reports]:
|
||||
# For reports in submodules use the full import path w/o the root module as the name
|
||||
report_name = cls.full_name.split(".", maxsplit=1)[1]
|
||||
module_reports[report_name] = cls
|
||||
|
||||
if module_reports:
|
||||
module_list[module_name] = module_reports
|
||||
|
||||
return module_list
|
||||
|
||||
|
||||
@job('default')
|
||||
def run_report(job_result, *args, **kwargs):
|
||||
"""
|
||||
|
@ -2,9 +2,6 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
|
||||
@ -15,10 +12,9 @@ from django.core.validators import RegexValidator
|
||||
from django.db import transaction
|
||||
from django.utils.functional import classproperty
|
||||
|
||||
from core.models import ManagedFile
|
||||
from extras.api.serializers import ScriptOutputSerializer
|
||||
from extras.choices import JobResultStatusChoices, LogLevelChoices
|
||||
from extras.models import JobResult
|
||||
from extras.models import JobResult, ScriptModule
|
||||
from extras.signals import clear_webhooks
|
||||
from ipam.formfields import IPAddressFormField, IPNetworkFormField
|
||||
from ipam.validators import MaxPrefixLengthValidator, MinPrefixLengthValidator, prefix_validator
|
||||
@ -26,6 +22,7 @@ from utilities.exceptions import AbortScript, AbortTransaction
|
||||
from utilities.forms import add_blank_choice, DynamicModelChoiceField, DynamicModelMultipleChoiceField
|
||||
from .context_managers import change_logging
|
||||
from .forms import ScriptForm
|
||||
from .utils import get_modules
|
||||
|
||||
__all__ = [
|
||||
'BaseScript',
|
||||
@ -44,8 +41,6 @@ __all__ = [
|
||||
'TextVar',
|
||||
]
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
|
||||
#
|
||||
# Script variables
|
||||
@ -445,6 +440,10 @@ def is_variable(obj):
|
||||
return isinstance(obj, ScriptVariable)
|
||||
|
||||
|
||||
def get_scripts():
|
||||
return get_modules(ScriptModule.objects.all(), is_script, 'script_order')
|
||||
|
||||
|
||||
def run_script(data, request, commit=True, *args, **kwargs):
|
||||
"""
|
||||
A wrapper for calling Script.run(). This performs error handling and provides a hook for committing changes. It
|
||||
@ -523,52 +522,6 @@ def run_script(data, request, commit=True, *args, **kwargs):
|
||||
)
|
||||
|
||||
|
||||
def get_scripts(use_names=False):
|
||||
"""
|
||||
Return a dict of dicts mapping all scripts to their modules. Set use_names to True to use each module's human-
|
||||
defined name in place of the actual module name.
|
||||
"""
|
||||
scripts = {}
|
||||
|
||||
# Get all modules within the scripts path. These are the user-created files in which scripts are
|
||||
# defined.
|
||||
# modules = list(pkgutil.iter_modules([settings.SCRIPTS_ROOT]))
|
||||
modules = [mf.get_module_info() for mf in ManagedFile.objects.filter(file_root='scripts')]
|
||||
modules_bases = set([name.split(".")[0] for _, name, _ in modules])
|
||||
|
||||
# Deleting from sys.modules needs to done behind a lock to prevent race conditions where a module is
|
||||
# removed from sys.modules while another thread is importing
|
||||
with lock:
|
||||
for module_name in list(sys.modules.keys()):
|
||||
# Everything sharing a base module path with a module in the script folder is removed.
|
||||
# We also remove all modules with a base module called "scripts". This allows modifying imported
|
||||
# non-script modules without having to reload the RQ worker.
|
||||
module_base = module_name.split(".")[0]
|
||||
if module_base == "scripts" or module_base in modules_bases:
|
||||
del sys.modules[module_name]
|
||||
|
||||
for importer, module_name, _ in modules:
|
||||
module = importer.find_module(module_name).load_module(module_name)
|
||||
|
||||
if use_names and hasattr(module, 'name'):
|
||||
module_name = module.name
|
||||
|
||||
module_scripts = {}
|
||||
script_order = getattr(module, "script_order", ())
|
||||
ordered_scripts = [cls for cls in script_order if is_script(cls)]
|
||||
unordered_scripts = [cls for _, cls in inspect.getmembers(module, is_script) if cls not in script_order]
|
||||
|
||||
for cls in [*ordered_scripts, *unordered_scripts]:
|
||||
# For scripts in submodules use the full import path w/o the root module as the name
|
||||
script_name = cls.full_name.split(".", maxsplit=1)[1]
|
||||
module_scripts[script_name] = cls
|
||||
|
||||
if module_scripts:
|
||||
scripts[module_name] = module_scripts
|
||||
|
||||
return scripts
|
||||
|
||||
|
||||
def get_script(module_name, script_name):
|
||||
"""
|
||||
Retrieve a script class by module and name. Returns None if the script does not exist.
|
||||
|
@ -1,9 +1,15 @@
|
||||
import inspect
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from django.db.models import Q
|
||||
from django.utils.deconstruct import deconstructible
|
||||
from taggit.managers import _TaggableManager
|
||||
|
||||
from netbox.registry import registry
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
|
||||
def is_taggable(obj):
|
||||
"""
|
||||
@ -66,3 +72,48 @@ def register_features(model, features):
|
||||
raise KeyError(
|
||||
f"{feature} is not a valid model feature! Valid keys are: {registry['model_features'].keys()}"
|
||||
)
|
||||
|
||||
|
||||
def get_modules(queryset, litmus_func, ordering_attr):
|
||||
"""
|
||||
Returns a list of tuples:
|
||||
|
||||
[
|
||||
(module_name, (child, child, ...)),
|
||||
(module_name, (child, child, ...)),
|
||||
...
|
||||
]
|
||||
"""
|
||||
results = {}
|
||||
|
||||
modules = [mf.get_module_info() for mf in queryset]
|
||||
modules_bases = set([name.split(".")[0] for _, name, _ in modules])
|
||||
|
||||
# Deleting from sys.modules needs to done behind a lock to prevent race conditions where a module is
|
||||
# removed from sys.modules while another thread is importing
|
||||
with lock:
|
||||
for module_name in list(sys.modules.keys()):
|
||||
# Everything sharing a base module path with a module in the script folder is removed.
|
||||
# We also remove all modules with a base module called "scripts". This allows modifying imported
|
||||
# non-script modules without having to reload the RQ worker.
|
||||
module_base = module_name.split(".")[0]
|
||||
if module_base in ('reports', 'scripts', *modules_bases):
|
||||
del sys.modules[module_name]
|
||||
|
||||
for importer, module_name, _ in modules:
|
||||
module = importer.find_module(module_name).load_module(module_name)
|
||||
child_order = getattr(module, ordering_attr, ())
|
||||
ordered_children = [cls() for cls in child_order if litmus_func(cls)]
|
||||
unordered_children = [cls() for _, cls in inspect.getmembers(module, litmus_func) if cls not in child_order]
|
||||
|
||||
children = {}
|
||||
|
||||
for cls in [*ordered_children, *unordered_children]:
|
||||
# For child objects in submodules use the full import path w/o the root module as the name
|
||||
child_name = cls.full_name.split(".", maxsplit=1)[1]
|
||||
children[child_name] = cls
|
||||
|
||||
if children:
|
||||
results[module_name] = children
|
||||
|
||||
return results
|
||||
|
@ -942,7 +942,7 @@ class ScriptListView(ContentTypePermissionRequiredMixin, View):
|
||||
|
||||
def get(self, request):
|
||||
|
||||
scripts = get_scripts(use_names=True)
|
||||
scripts = get_scripts()
|
||||
script_content_type = ContentType.objects.get(app_label='extras', model='script')
|
||||
results = {
|
||||
r.name: r
|
||||
|
Loading…
Reference in New Issue
Block a user