diff --git a/netbox/extras/models/mixins.py b/netbox/extras/models/mixins.py index 3a7273f93..eb017302a 100644 --- a/netbox/extras/models/mixins.py +++ b/netbox/extras/models/mixins.py @@ -131,7 +131,7 @@ class RenderTemplateMixin(models.Model): """ context = self.get_context(context=context, queryset=queryset) env_params = self.environment_params or {} - output = render_jinja2(self.template_code, context, env_params) + output = render_jinja2(self.template_code, context, env_params, getattr(self, 'data_file', None)) # Replace CRLF-style line terminators output = output.replace('\r\n', '\n') diff --git a/netbox/extras/tests/test_models.py b/netbox/extras/tests/test_models.py index 089e47c02..6b718569c 100644 --- a/netbox/extras/tests/test_models.py +++ b/netbox/extras/tests/test_models.py @@ -1,9 +1,12 @@ -from django.forms import ValidationError -from django.test import TestCase +import tempfile +from pathlib import Path -from core.models import ObjectType +from django.forms import ValidationError +from django.test import tag, TestCase + +from core.models import DataSource, ObjectType from dcim.models import Device, DeviceRole, DeviceType, Location, Manufacturer, Platform, Region, Site, SiteGroup -from extras.models import ConfigContext, Tag +from extras.models import ConfigContext, ConfigTemplate, Tag from tenancy.models import Tenant, TenantGroup from utilities.exceptions import AbortRequest from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine @@ -33,8 +36,8 @@ class TagTest(TestCase): ] site = Site.objects.create(name='Site 1') - for tag in tags: - site.tags.add(tag) + for _tag in tags: + site.tags.add(_tag) site.save() site = Site.objects.first() @@ -540,3 +543,66 @@ class ConfigContextTest(TestCase): device.local_context_data = 'foo' with self.assertRaises(ValidationError): device.clean() + + +class ConfigTemplateTest(TestCase): + """ + TODO: These test cases deal with the weighting, ordering, and deep merge logic of config context data. + """ + MAIN_TEMPLATE = """ + {%- include 'base.j2' %} + """.strip() + BASE_TEMPLATE = """ + Hi + """.strip() + + @classmethod + def _create_template_file(cls, templates_dir, file_name, content): + template_file_name = file_name + if not template_file_name.endswith('j2'): + template_file_name += '.j2' + temp_file_path = templates_dir / template_file_name + + with open(temp_file_path, 'w') as f: + f.write(content) + + @classmethod + def setUpTestData(cls): + temp_dir = tempfile.TemporaryDirectory() + templates_dir = Path(temp_dir.name) / "templates" + templates_dir.mkdir(parents=True, exist_ok=True) + + cls._create_template_file(templates_dir, 'base.j2', cls.BASE_TEMPLATE) + cls._create_template_file(templates_dir, 'main.j2', cls.MAIN_TEMPLATE) + + data_source = DataSource( + name="Test DataSource", + type="local", + source_url=str(templates_dir), + ) + data_source.save() + data_source.sync() + + base_config_template = ConfigTemplate( + name="BaseTemplate", + data_file=data_source.datafiles.filter(path__endswith='base.j2').first() + ) + base_config_template.clean() + base_config_template.save() + cls.base_config_template = base_config_template + + main_config_template = ConfigTemplate( + name="MainTemplate", + data_file=data_source.datafiles.filter(path__endswith='main.j2').first() + ) + main_config_template.clean() + main_config_template.save() + cls.main_config_template = main_config_template + + @tag('regression') + def test_config_template_with_data_source(self): + self.assertEqual(self.BASE_TEMPLATE, self.base_config_template.render({})) + + @tag('regression') + def test_config_template_with_data_source_nested_templates(self): + self.assertEqual(self.BASE_TEMPLATE, self.main_config_template.render({})) diff --git a/netbox/utilities/jinja2.py b/netbox/utilities/jinja2.py index 37b3b2dfb..362bc2393 100644 --- a/netbox/utilities/jinja2.py +++ b/netbox/utilities/jinja2.py @@ -49,11 +49,27 @@ class DataFileLoader(BaseLoader): # Utility functions # -def render_jinja2(template_code, context, environment_params=None): +def render_jinja2(template_code, context, environment_params=None, data_file=None): """ Render a Jinja2 template with the provided context. Return the rendered content. """ environment_params = environment_params or {} + + if 'loader' not in environment_params: + if data_file: + loader = DataFileLoader(data_file.source) + loader.cache_templates({ + data_file.path: template_code + }) + else: + loader = BaseLoader() + environment_params['loader'] = loader + environment = SandboxedEnvironment(**environment_params) environment.filters.update(get_config().JINJA2_FILTERS) - return environment.from_string(source=template_code).render(**context) + + if data_file: + template = environment.get_template(data_file.path) + else: + template = environment.from_string(source=template_code) + return template.render(**context)