Remove get_modules() utility function

This commit is contained in:
jeremystretch 2023-03-24 11:07:53 -04:00
parent f5830c1cd8
commit 107c46cb7a
8 changed files with 170 additions and 261 deletions

View File

@ -16,8 +16,8 @@ from extras import filtersets
from extras.choices import JobResultStatusChoices from extras.choices import JobResultStatusChoices
from extras.models import * from extras.models import *
from extras.models import CustomField from extras.models import CustomField
from extras.reports import get_report, get_reports, run_report from extras.reports import get_report, run_report
from extras.scripts import get_script, get_scripts, run_script from extras.scripts import get_script, run_script
from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired
from netbox.api.features import SyncedDataMixin from netbox.api.features import SyncedDataMixin
from netbox.api.metadata import ContentTypeMetadata from netbox.api.metadata import ContentTypeMetadata
@ -27,7 +27,6 @@ from utilities.exceptions import RQWorkerNotRunningException
from utilities.utils import copy_safe_request, count_related from utilities.utils import copy_safe_request, count_related
from . import serializers from . import serializers
from .mixins import ConfigTemplateRenderMixin from .mixins import ConfigTemplateRenderMixin
from .nested_serializers import NestedConfigTemplateSerializer
class ExtrasRootView(APIRootView): class ExtrasRootView(APIRootView):
@ -189,7 +188,6 @@ class ReportViewSet(ViewSet):
""" """
Compile all reports and their related results (if any). Result data is deferred in the list view. Compile all reports and their related results (if any). Result data is deferred in the list view.
""" """
report_list = []
report_content_type = ContentType.objects.get(app_label='extras', model='report') report_content_type = ContentType.objects.get(app_label='extras', model='report')
results = { results = {
r.name: r r.name: r
@ -199,13 +197,13 @@ class ReportViewSet(ViewSet):
).order_by('name', '-created').distinct('name').defer('data') ).order_by('name', '-created').distinct('name').defer('data')
} }
# Iterate through all available Reports. report_list = []
for module_name, reports in get_reports().items(): for report_module in ReportModule.objects.restrict(request.user):
for report in reports.values(): report_list.extend([report() for report in report_module.reports.values()])
# Attach the relevant JobResult (if any) to each Report. # Attach JobResult objects to each report (if any)
for report in report_list:
report.result = results.get(report.full_name, None) report.result = results.get(report.full_name, None)
report_list.append(report)
serializer = serializers.ReportSerializer(report_list, many=True, context={ serializer = serializers.ReportSerializer(report_list, many=True, context={
'request': request, 'request': request,
@ -296,15 +294,15 @@ class ScriptViewSet(ViewSet):
).order_by('name', '-created').distinct('name').defer('data') ).order_by('name', '-created').distinct('name').defer('data')
} }
flat_list = [] script_list = []
for script_list in get_scripts().values(): for script_module in ScriptModule.objects.restrict(request.user):
flat_list.extend(script_list.values()) script_list.extend(script_module.scripts.values())
# Attach JobResult objects to each script (if any) # Attach JobResult objects to each script (if any)
for script in flat_list: for script in script_list:
script.result = results.get(script.full_name, None) script.result = results.get(script.full_name, None)
serializer = serializers.ScriptSerializer(flat_list, many=True, context={'request': request}) serializer = serializers.ScriptSerializer(script_list, many=True, context={'request': request})
return Response(serializer.data) return Response(serializer.data)

View File

@ -5,8 +5,8 @@ from django.core.management.base import BaseCommand
from django.utils import timezone from django.utils import timezone
from extras.choices import JobResultStatusChoices from extras.choices import JobResultStatusChoices
from extras.models import JobResult from extras.models import JobResult, ReportModule
from extras.reports import get_reports, run_report from extras.reports import run_report
class Command(BaseCommand): class Command(BaseCommand):
@ -17,13 +17,9 @@ class Command(BaseCommand):
def handle(self, *args, **options): def handle(self, *args, **options):
# Gather all available reports for module in ReportModule.objects.all():
reports = get_reports() for report in module.reports.values():
if module.name in options['reports'] or report.full_name in options['reports']:
# Run reports
for module_name, report_list in reports.items():
for report in report_list.values():
if module_name in options['reports'] or report.full_name in options['reports']:
# Run the report and create a new JobResult # Run the report and create a new JobResult
self.stdout.write( self.stdout.write(

View File

@ -7,32 +7,16 @@ from django_rq import job
from .choices import JobResultStatusChoices, LogLevelChoices from .choices import JobResultStatusChoices, LogLevelChoices
from .models import JobResult, ReportModule from .models import JobResult, ReportModule
from .temp import is_report
from .utils import get_modules
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_reports():
return get_modules(ReportModule.objects.all(), is_report, 'report_order')
def get_report(module_name, report_name): def get_report(module_name, report_name):
""" """
Return a specific report from within a module. Return a specific report from within a module.
""" """
reports = get_reports() module = ReportModule.objects.get(file_path=f'{module_name}.py')
module = reports.get(module_name) return module.scripts.get(report_name)
if module is None:
return None
report = module.get(report_name)
if report is None:
return None
return report
@job('default') @job('default')

View File

@ -22,8 +22,6 @@ from utilities.exceptions import AbortScript, AbortTransaction
from utilities.forms import add_blank_choice, DynamicModelChoiceField, DynamicModelMultipleChoiceField from utilities.forms import add_blank_choice, DynamicModelChoiceField, DynamicModelMultipleChoiceField
from .context_managers import change_logging from .context_managers import change_logging
from .forms import ScriptForm from .forms import ScriptForm
from .temp import is_script
from .utils import get_modules
__all__ = [ __all__ = [
'BaseScript', 'BaseScript',
@ -432,10 +430,6 @@ def is_variable(obj):
return isinstance(obj, ScriptVariable) return isinstance(obj, ScriptVariable)
def get_scripts():
return get_modules(ScriptModule.objects.all(), is_script, 'script_order')
def run_script(data, request, commit=True, *args, **kwargs): def run_script(data, request, commit=True, *args, **kwargs):
""" """
A wrapper for calling Script.run(). This performs error handling and provides a hook for committing changes. It A wrapper for calling Script.run(). This performs error handling and provides a hook for committing changes. It
@ -444,10 +438,10 @@ def run_script(data, request, commit=True, *args, **kwargs):
job_result = kwargs.pop('job_result') job_result = kwargs.pop('job_result')
job_result.start() job_result.start()
module, script_name = job_result.name.split('.', 1) module_name, script_name = job_result.name.split('.', 1)
script = get_script(module, script_name)() script = get_script(module_name, script_name)()
logger = logging.getLogger(f"netbox.scripts.{module}.{script_name}") logger = logging.getLogger(f"netbox.scripts.{module_name}.{script_name}")
logger.info(f"Running script (commit={commit})") logger.info(f"Running script (commit={commit})")
# Add files to form data # Add files to form data
@ -518,7 +512,5 @@ def get_script(module_name, script_name):
""" """
Retrieve a script class by module and name. Returns None if the script does not exist. Retrieve a script class by module and name. Returns None if the script does not exist.
""" """
scripts = get_scripts() module = ScriptModule.objects.get(file_path=f'{module_name}.py')
module = scripts.get(module_name) return module.scripts.get(script_name)
if module:
return module.get(script_name)

View File

@ -1,5 +1,3 @@
import inspect
import sys
import threading import threading
from django.db.models import Q from django.db.models import Q
@ -72,48 +70,3 @@ def register_features(model, features):
raise KeyError( raise KeyError(
f"{feature} is not a valid model feature! Valid keys are: {registry['model_features'].keys()}" f"{feature} is not a valid model feature! Valid keys are: {registry['model_features'].keys()}"
) )
def get_modules(queryset, litmus_func, ordering_attr):
"""
Returns a list of tuples:
[
(module, (child, child, ...)),
(module, (child, child, ...)),
...
]
"""
results = {}
modules = [(mf, *mf.get_module_info()) for mf in queryset]
modules_bases = set([name.split(".")[0] for _, _, name, _ in modules])
# Deleting from sys.modules needs to done behind a lock to prevent race conditions where a module is
# removed from sys.modules while another thread is importing
with lock:
for module_name in list(sys.modules.keys()):
# Everything sharing a base module path with a module in the script folder is removed.
# We also remove all modules with a base module called "scripts". This allows modifying imported
# non-script modules without having to reload the RQ worker.
module_base = module_name.split(".")[0]
if module_base in ('reports', 'scripts', *modules_bases):
del sys.modules[module_name]
for mf, importer, module_name, _ in modules:
module = importer.find_module(module_name).load_module(module_name)
child_order = getattr(module, ordering_attr, ())
ordered_children = [cls() for cls in child_order if litmus_func(cls)]
unordered_children = [cls() for _, cls in inspect.getmembers(module, litmus_func) if cls not in child_order]
children = {}
for cls in [*ordered_children, *unordered_children]:
# For child objects in submodules use the full import path w/o the root module as the name
child_name = cls.full_name.split(".", maxsplit=1)[1]
children[child_name] = cls
if children:
results[mf] = children
return results

View File

@ -22,8 +22,8 @@ from .choices import JobResultStatusChoices
from .constants import SCRIPTS_ROOT_NAME, REPORTS_ROOT_NAME from .constants import SCRIPTS_ROOT_NAME, REPORTS_ROOT_NAME
from .forms.reports import ReportForm from .forms.reports import ReportForm
from .models import * from .models import *
from .reports import get_report, get_reports, run_report from .reports import get_report, run_report
from .scripts import get_scripts, run_script from .scripts import run_script
# #
@ -809,16 +809,16 @@ class ReportModuleDeleteView(generic.ObjectDeleteView):
class ReportListView(ContentTypePermissionRequiredMixin, View): class ReportListView(ContentTypePermissionRequiredMixin, View):
""" """
Retrieve all of the available reports from disk and the recorded JobResult (if any) for each. Retrieve all the available reports from disk and the recorded JobResult (if any) for each.
""" """
def get_required_permission(self): def get_required_permission(self):
return 'extras.view_report' return 'extras.view_report'
def get(self, request): def get(self, request):
report_modules = ReportModule.objects.restrict(request.user)
reports = get_reports()
report_content_type = ContentType.objects.get(app_label='extras', model='report') report_content_type = ContentType.objects.get(app_label='extras', model='report')
results = { job_results = {
r.name: r r.name: r
for r in JobResult.objects.filter( for r in JobResult.objects.filter(
obj_type=report_content_type, obj_type=report_content_type,
@ -826,18 +826,17 @@ class ReportListView(ContentTypePermissionRequiredMixin, View):
).order_by('name', '-created').distinct('name').defer('data') ).order_by('name', '-created').distinct('name').defer('data')
} }
ret = [] # for module, report_list in reports.items():
# module_reports = []
for module, report_list in reports.items(): # for report in report_list.values():
module_reports = [] # report.result = results.get(report.full_name, None)
for report in report_list.values(): # module_reports.append(report)
report.result = results.get(report.full_name, None) # ret.append((module, module_reports))
module_reports.append(report)
ret.append((module, module_reports))
return render(request, 'extras/report_list.html', { return render(request, 'extras/report_list.html', {
'model': ReportModule, 'model': ScriptModule,
'reports': ret, 'report_modules': report_modules,
'job_results': job_results,
}) })
@ -955,27 +954,16 @@ class ScriptModuleDeleteView(generic.ObjectDeleteView):
queryset = ScriptModule.objects.all() queryset = ScriptModule.objects.all()
class GetScriptMixin:
def _get_script(self, name, module=None):
if module is None:
module, name = name.split('.', 1)
scripts = get_scripts()
try:
return scripts[module][name]()
except KeyError:
raise Http404
class ScriptListView(ContentTypePermissionRequiredMixin, View): class ScriptListView(ContentTypePermissionRequiredMixin, View):
def get_required_permission(self): def get_required_permission(self):
return 'extras.view_script' return 'extras.view_script'
def get(self, request): def get(self, request):
script_modules = ScriptModule.objects.restrict(request.user)
scripts = get_scripts()
script_content_type = ContentType.objects.get(app_label='extras', model='script') script_content_type = ContentType.objects.get(app_label='extras', model='script')
results = { job_results = {
r.name: r r.name: r
for r in JobResult.objects.filter( for r in JobResult.objects.filter(
obj_type=script_content_type, obj_type=script_content_type,
@ -983,17 +971,18 @@ class ScriptListView(ContentTypePermissionRequiredMixin, View):
).order_by('name', '-created').distinct('name').defer('data') ).order_by('name', '-created').distinct('name').defer('data')
} }
for _scripts in scripts.values(): # for _scripts in scripts.values():
for script in _scripts.values(): # for script in _scripts.values():
script.result = results.get(script.full_name) # script.result = results.get(script.full_name)
return render(request, 'extras/script_list.html', { return render(request, 'extras/script_list.html', {
'model': ScriptModule, 'model': ScriptModule,
'scripts': scripts, 'script_modules': script_modules,
'job_results': job_results,
}) })
class ScriptView(ContentTypePermissionRequiredMixin, GetScriptMixin, View): class ScriptView(ContentTypePermissionRequiredMixin, View):
def get_required_permission(self): def get_required_permission(self):
return 'extras.view_script' return 'extras.view_script'
@ -1018,12 +1007,11 @@ class ScriptView(ContentTypePermissionRequiredMixin, GetScriptMixin, View):
}) })
def post(self, request, module, name): def post(self, request, module, name):
# Permissions check
if not request.user.has_perm('extras.run_script'): if not request.user.has_perm('extras.run_script'):
return HttpResponseForbidden() return HttpResponseForbidden()
script = self._get_script(name, module) module = get_object_or_404(ScriptModule.objects.restrict(request.user), file_path=f'{module}.py')
script = module.scripts[name]()
form = script.as_form(request.POST, request.FILES) form = script.as_form(request.POST, request.FILES)
# Allow execution only if RQ worker process is running # Allow execution only if RQ worker process is running
@ -1053,7 +1041,7 @@ class ScriptView(ContentTypePermissionRequiredMixin, GetScriptMixin, View):
}) })
class ScriptResultView(ContentTypePermissionRequiredMixin, GetScriptMixin, View): class ScriptResultView(ContentTypePermissionRequiredMixin, View):
def get_required_permission(self): def get_required_permission(self):
return 'extras.view_script' return 'extras.view_script'
@ -1064,7 +1052,9 @@ class ScriptResultView(ContentTypePermissionRequiredMixin, GetScriptMixin, View)
if result.obj_type != script_content_type: if result.obj_type != script_content_type:
raise Http404 raise Http404
script = self._get_script(result.name) module_name, script_name = result.name.split('.', 1)
module = get_object_or_404(ScriptModule.objects.restrict(request.user), file_path=f'{module_name}.py')
script = module.scripts[script_name]()
# If this is an HTMX request, return only the result HTML # If this is an HTMX request, return only the result HTML
if is_htmx(request): if is_htmx(request):

View File

@ -24,8 +24,7 @@
{% block content-wrapper %} {% block content-wrapper %}
<div class="tab-content"> <div class="tab-content">
{% if reports %} {% for module in report_modules %}
{% for module, module_reports in reports %}
<div class="card"> <div class="card">
<h5 class="card-header"> <h5 class="card-header">
{% if perms.extras.delete_reportmodule %} {% if perms.extras.delete_reportmodule %}
@ -50,7 +49,7 @@
</tr> </tr>
</thead> </thead>
<tbody> <tbody>
{% for report in module_reports %} {% for report_name, report in module.reports.items %}
<tr> <tr>
<td> <td>
<a href="{% url 'extras:report' module=report.module name=report.class_name %}" id="{{ report.module }}.{{ report.class_name }}">{{ report.name }}</a> <a href="{% url 'extras:report' module=report.module name=report.class_name %}" id="{{ report.module }}.{{ report.class_name }}">{{ report.name }}</a>
@ -101,14 +100,13 @@
</table> </table>
</div> </div>
</div> </div>
{% endfor %} {% empty %}
{% else %}
<div class="alert alert-info" role="alert"> <div class="alert alert-info" role="alert">
<h4 class="alert-heading">No Reports Found</h4> <h4 class="alert-heading">No Reports Found</h4>
Reports should be saved to <code>{{ settings.REPORTS_ROOT }}</code>. Reports should be saved to <code>{{ settings.REPORTS_ROOT }}</code>.
<hr/> <hr/>
<small>This path can be changed by setting <code>REPORTS_ROOT</code> in NetBox's configuration.</small> <small>This path can be changed by setting <code>REPORTS_ROOT</code> in NetBox's configuration.</small>
</div> </div>
{% endif %} {% endfor %}
</div> </div>
{% endblock content-wrapper %} {% endblock content-wrapper %}

View File

@ -23,8 +23,7 @@
{% block content-wrapper %} {% block content-wrapper %}
<div class="tab-content"> <div class="tab-content">
{% if scripts %} {% for module in script_modules %}
{% for module, module_scripts in scripts.items %}
<div class="card"> <div class="card">
<h5 class="card-header"> <h5 class="card-header">
{% if perms.extras.delete_scriptmodule %} {% if perms.extras.delete_scriptmodule %}
@ -48,20 +47,20 @@
</tr> </tr>
</thead> </thead>
<tbody> <tbody>
{% for class_name, script in module_scripts.items %} {% for script_name, script_class in module.scripts.items %}
<tr> <tr>
<td> <td>
<a href="{% url 'extras:script' module=script.root_module name=class_name %}" name="script.{{ class_name }}">{{ script.name }}</a> <a href="{% url 'extras:script' module=script_class.root_module name=script_name %}" name="script.{{ script_name }}">{{ script_class.name }}</a>
</td> </td>
<td> <td>
{% include 'extras/inc/job_label.html' with result=script.result %} {% include 'extras/inc/job_label.html' with result=script_class.result %}
</td> </td>
<td> <td>
{{ script.Meta.description|markdown|placeholder }} {{ script_class.Meta.description|markdown|placeholder }}
</td> </td>
{% if script.result %} {% if script_class.result %}
<td class="text-end"> <td class="text-end">
<a href="{% url 'extras:script_result' job_result_pk=script.result.pk %}">{{ script.result.created|annotated_date }}</a> <a href="{% url 'extras:script_result' job_result_pk=script_class.result.pk %}">{{ script_class.result.created|annotated_date }}</a>
</td> </td>
{% else %} {% else %}
<td class="text-end text-muted">Never</td> <td class="text-end text-muted">Never</td>
@ -72,14 +71,13 @@
</table> </table>
</div> </div>
</div> </div>
{% endfor %} {% empty %}
{% else %}
<div class="alert alert-info"> <div class="alert alert-info">
<h4 class="alert-heading">No Scripts Found</h4> <h4 class="alert-heading">No Scripts Found</h4>
Scripts should be saved to <code>{{ settings.SCRIPTS_ROOT }}</code>. Scripts should be saved to <code>{{ settings.SCRIPTS_ROOT }}</code>.
<hr/> <hr/>
This path can be changed by setting <code>SCRIPTS_ROOT</code> in NetBox's configuration. This path can be changed by setting <code>SCRIPTS_ROOT</code> in NetBox's configuration.
</div> </div>
{% endif %} {% endfor %}
</div> </div>
{% endblock content-wrapper %} {% endblock content-wrapper %}