diff --git a/netbox/utilities/testing/testcases.py b/netbox/utilities/testing/testcases.py index ef9660fa7..90138c3d9 100644 --- a/netbox/utilities/testing/testcases.py +++ b/netbox/utilities/testing/testcases.py @@ -1,7 +1,7 @@ from django.contrib.auth.models import Permission, User from django.core.exceptions import ObjectDoesNotExist from django.test import Client, TestCase as _TestCase -from django.urls import reverse +from django.urls import reverse, NoReverseMatch from rest_framework.test import APIClient from users.models import Token @@ -88,15 +88,44 @@ class StandardTestCases: form_data = {} csv_data = {} + maxDiff = None + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if self.model is not None: - self.base_url_name = '{}:{}_{{}}'.format(self.model._meta.app_label, self.model._meta.model_name) + if self.model is None: + raise Exception("Test case requires model to be defined") + + def _get_url(self, action, instance=None): + """ + Return the URL name for a specific action. An instance must be specified for + get/edit/delete views. + """ + url_format = '{}:{}_{{}}'.format( + self.model._meta.app_label, + self.model._meta.model_name + ) + + if action in ('list', 'add', 'import'): + return reverse(url_format.format(action)) + + elif action in ('get', 'edit', 'delete'): + if instance is None: + raise Exception("Resolving {} URL requires specifying an instance".format(action)) + # Attempt to resolve using slug first + if hasattr(self.model, 'slug'): + try: + return reverse(url_format.format(action), kwargs={'slug': instance.slug}) + except NoReverseMatch: + pass + return reverse(url_format.format(action), kwargs={'pk': instance.pk}) + + else: + raise Exception("Invalid action for URL resolution: {}".format(action)) def test_list_objects(self): - response = self.client.get(reverse(self.base_url_name.format('list'))) + response = self.client.get(self._get_url('list')) self.assertHttpStatus(response, 200) def test_get_object(self): @@ -107,7 +136,7 @@ class StandardTestCases: def test_create_object(self): initial_count = self.model.objects.count() request = { - 'path': reverse(self.base_url_name.format('add')), + 'path': self._get_url('add'), 'data': post_data(self.form_data), 'follow': True, } @@ -128,14 +157,8 @@ class StandardTestCases: def test_edit_object(self): instance = self.model.objects.first() - # Determine the proper kwargs to pass to the edit URL - if hasattr(instance, 'slug'): - kwargs = {'slug': instance.slug} - else: - kwargs = {'pk': instance.pk} - request = { - 'path': reverse(self.base_url_name.format('edit'), kwargs=kwargs), + 'path': self._get_url('edit', instance), 'data': post_data(self.form_data), 'follow': True, } @@ -155,14 +178,8 @@ class StandardTestCases: def test_delete_object(self): instance = self.model.objects.first() - # Determine the proper kwargs to pass to the deletion URL - if hasattr(instance, 'slug'): - kwargs = {'slug': instance.slug} - else: - kwargs = {'pk': instance.pk} - request = { - 'path': reverse(self.base_url_name.format('delete'), kwargs=kwargs), + 'path': self._get_url('delete', instance), 'data': {'confirm': True}, 'follow': True, } @@ -182,7 +199,7 @@ class StandardTestCases: def test_import_objects(self): initial_count = self.model.objects.count() request = { - 'path': reverse(self.base_url_name.format('import')), + 'path': self._get_url('import'), 'data': { 'csv': '\n'.join(self.csv_data) }