diff --git a/netbox/extras/scripts.py b/netbox/extras/scripts.py index ad9e5bcc4..3ba7792a1 100644 --- a/netbox/extras/scripts.py +++ b/netbox/extras/scripts.py @@ -61,7 +61,7 @@ class ScriptVariable: self.field_attrs['label'] = label if description: self.field_attrs['help_text'] = description - if default: + if default is not None: self.field_attrs['initial'] = default if widget: self.field_attrs['widget'] = widget diff --git a/netbox/extras/tests/test_views.py b/netbox/extras/tests/test_views.py index 91444e2ce..727f0f803 100644 --- a/netbox/extras/tests/test_views.py +++ b/netbox/extras/tests/test_views.py @@ -1,6 +1,7 @@ from django.contrib.contenttypes.models import ContentType from django.urls import reverse from django.test import tag +from unittest.mock import patch, PropertyMock from core.choices import ManagedFileRootPathChoices from core.events import * @@ -906,7 +907,7 @@ class ScriptValidationErrorTest(TestCase): user_permissions = ['extras.view_script', 'extras.run_script'] class TestScriptMixin: - bar = IntegerVar(min_value=0, max_value=30, default=30) + bar = IntegerVar(min_value=0, max_value=30) class TestScriptClass(TestScriptMixin, PythonClass): class Meta: @@ -930,8 +931,6 @@ class ScriptValidationErrorTest(TestCase): @tag('regression') def test_script_validation_error_displays_message(self): - from unittest.mock import patch - url = reverse('extras:script', kwargs={'pk': self.script.pk}) with patch('extras.views.get_workers_for_queue', return_value=['worker']): @@ -944,8 +943,6 @@ class ScriptValidationErrorTest(TestCase): @tag('regression') def test_script_validation_error_no_toast_for_fieldset_fields(self): - from unittest.mock import patch, PropertyMock - class FieldsetScript(PythonClass): class Meta: name = 'Fieldset test' @@ -967,3 +964,42 @@ class ScriptValidationErrorTest(TestCase): self.assertEqual(response.status_code, 200) messages = list(response.context['messages']) self.assertEqual(len(messages), 0) + + +class ScriptDefaultValuesTest(TestCase): + user_permissions = ['extras.view_script', 'extras.run_script'] + + class TestScriptClass(PythonClass): + class Meta: + name = 'Test script' + commit_default = False + + bool_default_true = BooleanVar(default=True) + bool_default_false = BooleanVar(default=False) + int_with_default = IntegerVar(default=0) + int_without_default = IntegerVar(required=False) + + def run(self, data, commit): + return "Complete" + + @classmethod + def setUpTestData(cls): + module = ScriptModule.objects.create(file_root=ManagedFileRootPathChoices.SCRIPTS, file_path='test_script.py') + cls.script = Script.objects.create(module=module, name='Test script', is_executable=True) + + def setUp(self): + super().setUp() + Script.python_class = property(lambda self: ScriptDefaultValuesTest.TestScriptClass) + + def test_default_values_are_used(self): + url = reverse('extras:script', kwargs={'pk': self.script.pk}) + + with patch('extras.views.get_workers_for_queue', return_value=['worker']): + with patch('extras.jobs.ScriptJob.enqueue') as mock_enqueue: + mock_enqueue.return_value.pk = 1 + self.client.post(url, {}) + call_kwargs = mock_enqueue.call_args.kwargs + self.assertEqual(call_kwargs['data']['bool_default_true'], True) + self.assertEqual(call_kwargs['data']['bool_default_false'], False) + self.assertEqual(call_kwargs['data']['int_with_default'], 0) + self.assertIsNone(call_kwargs['data']['int_without_default']) diff --git a/netbox/extras/views.py b/netbox/extras/views.py index 3c1fc395d..461ed423f 100644 --- a/netbox/extras/views.py +++ b/netbox/extras/views.py @@ -1511,7 +1511,13 @@ class ScriptView(BaseScriptView): 'script': script, }) - form = script_class.as_form(request.POST, request.FILES) + # Populate missing variables with their default values, if defined + post_data = request.POST.copy() + for name, var in script_class._get_vars().items(): + if name not in post_data and (initial := var.field_attrs.get('initial')) is not None: + post_data[name] = initial + + form = script_class.as_form(post_data, request.FILES) # Allow execution only if RQ worker process is running if not get_workers_for_queue('default'):