From ea184b66b6786fe107f5350f2b339dba058d55b6 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 15 Jun 2023 09:29:06 -0700 Subject: [PATCH] 12794 call get_user_model once in tests --- netbox/dcim/tests/test_filtersets.py | 13 ++++--- netbox/dcim/tests/test_views.py | 7 +++- .../extras/management/commands/runscript.py | 12 +++--- netbox/extras/tests/test_filtersets.py | 33 +++++++++-------- netbox/extras/tests/test_views.py | 17 +++++---- netbox/netbox/tests/test_authentication.py | 21 ++++++----- netbox/users/tests/test_api.py | 25 +++++++------ netbox/users/tests/test_filtersets.py | 37 ++++++++++--------- netbox/users/tests/test_models.py | 13 ++++--- 9 files changed, 102 insertions(+), 76 deletions(-) diff --git a/netbox/dcim/tests/test_filtersets.py b/netbox/dcim/tests/test_filtersets.py index f55be2f3f..a1e684cb9 100644 --- a/netbox/dcim/tests/test_filtersets.py +++ b/netbox/dcim/tests/test_filtersets.py @@ -12,6 +12,9 @@ from virtualization.models import Cluster, ClusterType from wireless.choices import WirelessChannelChoices, WirelessRoleChoices +User = get_user_model() + + class DeviceComponentFilterSetTests: def test_device_type(self): @@ -593,11 +596,11 @@ class RackReservationTestCase(TestCase, ChangeLoggedFilterSetTests): Rack.objects.bulk_create(racks) users = ( - get_user_model()(username='User 1'), - get_user_model()(username='User 2'), - get_user_model()(username='User 3'), + User(username='User 1'), + User(username='User 2'), + User(username='User 3'), ) - get_user_model().objects.bulk_create(users) + User.objects.bulk_create(users) tenant_groups = ( TenantGroup(name='Tenant group 1', slug='tenant-group-1'), @@ -650,7 +653,7 @@ class RackReservationTestCase(TestCase, ChangeLoggedFilterSetTests): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) def test_user(self): - users = get_user_model().objects.all()[:2] + users = User.objects.all()[:2] params = {'user_id': [users[0].pk, users[1].pk]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) params = {'user': [users[0].username, users[1].username]} diff --git a/netbox/dcim/tests/test_views.py b/netbox/dcim/tests/test_views.py index f5f14ba07..23683ddce 100644 --- a/netbox/dcim/tests/test_views.py +++ b/netbox/dcim/tests/test_views.py @@ -22,6 +22,9 @@ from utilities.testing import ViewTestCases, create_tags, create_test_device, po from wireless.models import WirelessLAN +User = get_user_model() + + class RegionTestCase(ViewTestCases.OrganizationalObjectViewTestCase): model = Region @@ -288,8 +291,8 @@ class RackReservationTestCase(ViewTestCases.PrimaryObjectViewTestCase): @classmethod def setUpTestData(cls): - user2 = get_user_model().objects.create_user(username='testuser2') - user3 = get_user_model().objects.create_user(username='testuser3') + user2 = User.objects.create_user(username='testuser2') + user3 = User.objects.create_user(username='testuser3') site = Site.objects.create(name='Site 1', slug='site-1') diff --git a/netbox/extras/management/commands/runscript.py b/netbox/extras/management/commands/runscript.py index b086b542e..d9a9f41ae 100644 --- a/netbox/extras/management/commands/runscript.py +++ b/netbox/extras/management/commands/runscript.py @@ -63,6 +63,8 @@ class Command(BaseCommand): logger.info(f"Script completed in {job.duration}") + User = get_user_model() + # Params script = options['script'] loglevel = options['loglevel'] @@ -78,11 +80,11 @@ class Command(BaseCommand): # Take user from command line if provided and exists, other if options['user']: try: - user = get_user_model().objects.get(username=options['user']) - except get_user_model().DoesNotExist: - user = get_user_model().objects.filter(is_superuser=True).order_by('pk')[0] + user = User.objects.get(username=options['user']) + except User.DoesNotExist: + user = User.objects.filter(is_superuser=True).order_by('pk')[0] else: - user = get_user_model().objects.filter(is_superuser=True).order_by('pk')[0] + user = User.objects.filter(is_superuser=True).order_by('pk')[0] # Setup logging to Stdout formatter = logging.Formatter(f'[%(asctime)s][%(levelname)s] - %(message)s') @@ -113,7 +115,7 @@ class Command(BaseCommand): job = Job.objects.create( object=module, name=script.name, - user=get_user_model().objects.filter(is_superuser=True).order_by('pk')[0], + user=User.objects.filter(is_superuser=True).order_by('pk')[0], job_id=uuid.uuid4() ) diff --git a/netbox/extras/tests/test_filtersets.py b/netbox/extras/tests/test_filtersets.py index 8c0173727..992643530 100644 --- a/netbox/extras/tests/test_filtersets.py +++ b/netbox/extras/tests/test_filtersets.py @@ -18,6 +18,9 @@ from utilities.testing import BaseFilterSetTests, ChangeLoggedFilterSetTests, cr from virtualization.models import Cluster, ClusterGroup, ClusterType +User = get_user_model() + + class CustomFieldTestCase(TestCase, BaseFilterSetTests): queryset = CustomField.objects.all() filterset = CustomFieldFilterSet @@ -278,11 +281,11 @@ class SavedFilterTestCase(TestCase, BaseFilterSetTests): content_types = ContentType.objects.filter(model__in=['site', 'rack', 'device']) users = ( - get_user_model()(username='User 1'), - get_user_model()(username='User 2'), - get_user_model()(username='User 3'), + User(username='User 1'), + User(username='User 2'), + User(username='User 3'), ) - get_user_model().objects.bulk_create(users) + User.objects.bulk_create(users) saved_filters = ( SavedFilter( @@ -332,7 +335,7 @@ class SavedFilterTestCase(TestCase, BaseFilterSetTests): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) def test_user(self): - users = get_user_model().objects.filter(username__startswith='User') + users = User.objects.filter(username__startswith='User') params = {'user': [users[0].username, users[1].username]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) params = {'user_id': [users[0].pk, users[1].pk]} @@ -493,11 +496,11 @@ class JournalEntryTestCase(TestCase, ChangeLoggedFilterSetTests): Rack.objects.bulk_create(racks) users = ( - get_user_model()(username='Alice'), - get_user_model()(username='Bob'), - get_user_model()(username='Charlie'), + User(username='Alice'), + User(username='Bob'), + User(username='Charlie'), ) - get_user_model().objects.bulk_create(users) + User.objects.bulk_create(users) journal_entries = ( JournalEntry( @@ -540,7 +543,7 @@ class JournalEntryTestCase(TestCase, ChangeLoggedFilterSetTests): JournalEntry.objects.bulk_create(journal_entries) def test_created_by(self): - users = get_user_model().objects.filter(username__in=['Alice', 'Bob']) + users = User.objects.filter(username__in=['Alice', 'Bob']) params = {'created_by': [users[0].username, users[1].username]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) params = {'created_by_id': [users[0].pk, users[1].pk]} @@ -865,11 +868,11 @@ class ObjectChangeTestCase(TestCase, BaseFilterSetTests): @classmethod def setUpTestData(cls): users = ( - get_user_model()(username='user1'), - get_user_model()(username='user2'), - get_user_model()(username='user3'), + User(username='user1'), + User(username='user2'), + User(username='user3'), ) - get_user_model().objects.bulk_create(users) + User.objects.bulk_create(users) site = Site.objects.create(name='Test Site 1', slug='test-site-1') ipaddress = IPAddress.objects.create(address='192.0.2.1/24') @@ -933,7 +936,7 @@ class ObjectChangeTestCase(TestCase, BaseFilterSetTests): ObjectChange.objects.bulk_create(object_changes) def test_user(self): - params = {'user_id': get_user_model().objects.filter(username__in=['user1', 'user2']).values_list('pk', flat=True)} + params = {'user_id': User.objects.filter(username__in=['user1', 'user2']).values_list('pk', flat=True)} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) params = {'user': ['user1', 'user2']} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) diff --git a/netbox/extras/tests/test_views.py b/netbox/extras/tests/test_views.py index 1a3ce8923..3dcb90875 100644 --- a/netbox/extras/tests/test_views.py +++ b/netbox/extras/tests/test_views.py @@ -11,6 +11,9 @@ from extras.models import * from utilities.testing import ViewTestCases, TestCase +User = get_user_model() + + class CustomFieldTestCase(ViewTestCases.PrimaryObjectViewTestCase): model = CustomField @@ -115,11 +118,11 @@ class SavedFilterTestCase(ViewTestCases.PrimaryObjectViewTestCase): site_ct = ContentType.objects.get_for_model(Site) users = ( - get_user_model()(username='User 1'), - get_user_model()(username='User 2'), - get_user_model()(username='User 3'), + User(username='User 1'), + User(username='User 2'), + User(username='User 3'), ) - get_user_model().objects.bulk_create(users) + User.objects.bulk_create(users) saved_filters = ( SavedFilter( @@ -412,7 +415,7 @@ class ObjectChangeTestCase(TestCase): site.save() # Create three ObjectChanges - user = get_user_model().objects.create_user(username='testuser2') + user = User.objects.create_user(username='testuser2') for i in range(1, 4): oc = site.to_objectchange(action=ObjectChangeActionChoices.ACTION_UPDATE) oc.user = user @@ -423,7 +426,7 @@ class ObjectChangeTestCase(TestCase): url = reverse('extras:objectchange_list') params = { - "user": get_user_model().objects.first().pk, + "user": User.objects.first().pk, } response = self.client.get('{}?{}'.format(url, urllib.parse.urlencode(params))) @@ -452,7 +455,7 @@ class JournalEntryTestCase( site_ct = ContentType.objects.get_for_model(Site) site = Site.objects.create(name='Site 1', slug='site-1') - user = get_user_model().objects.create(username='User 1') + user = User.objects.create(username='User 1') JournalEntry.objects.bulk_create(( JournalEntry(assigned_object=site, created_by=user, comments='First entry'), diff --git a/netbox/netbox/tests/test_authentication.py b/netbox/netbox/tests/test_authentication.py index d14cc6973..1804087d1 100644 --- a/netbox/netbox/tests/test_authentication.py +++ b/netbox/netbox/tests/test_authentication.py @@ -17,6 +17,9 @@ from utilities.testing import TestCase from utilities.testing.api import APITestCase +User = get_user_model() + + class TokenAuthenticationTestCase(APITestCase): @override_settings(LOGIN_REQUIRED=True, EXEMPT_VIEW_PERMISSIONS=['*']) @@ -88,7 +91,7 @@ class ExternalAuthenticationTestCase(TestCase): @classmethod def setUpTestData(cls): - cls.user = get_user_model().objects.create(username='remoteuser1') + cls.user = User.objects.create(username='remoteuser1') def setUp(self): self.client = Client() @@ -170,7 +173,7 @@ class ExternalAuthenticationTestCase(TestCase): response = self.client.get(reverse('home'), follow=True, **headers) self.assertEqual(response.status_code, 200) - self.user = get_user_model().objects.get(username='remoteuser1') + self.user = User.objects.get(username='remoteuser1') self.assertEqual(self.user.first_name, "John", msg='User first name was not updated') self.assertEqual(self.user.last_name, "Smith", msg='User last name was not updated') self.assertEqual(self.user.email, "johnsmith@example.com", msg='User email was not updated') @@ -196,7 +199,7 @@ class ExternalAuthenticationTestCase(TestCase): self.assertEqual(response.status_code, 200) # Local user should have been automatically created - new_user = get_user_model().objects.get(username='remoteuser2') + new_user = User.objects.get(username='remoteuser2') self.assertEqual(int(self.client.session.get( '_auth_user_id')), new_user.pk, msg='Authentication failed') @@ -231,7 +234,7 @@ class ExternalAuthenticationTestCase(TestCase): response = self.client.get(reverse('home'), follow=True, **headers) self.assertEqual(response.status_code, 200) - new_user = get_user_model().objects.get(username='remoteuser2') + new_user = User.objects.get(username='remoteuser2') self.assertEqual(int(self.client.session.get( '_auth_user_id')), new_user.pk, msg='Authentication failed') self.assertListEqual( @@ -263,7 +266,7 @@ class ExternalAuthenticationTestCase(TestCase): response = self.client.get(reverse('home'), follow=True, **headers) self.assertEqual(response.status_code, 200) - new_user = get_user_model().objects.get(username='remoteuser2') + new_user = User.objects.get(username='remoteuser2') self.assertEqual(int(self.client.session.get( '_auth_user_id')), new_user.pk, msg='Authentication failed') self.assertTrue(new_user.has_perms( @@ -303,7 +306,7 @@ class ExternalAuthenticationTestCase(TestCase): response = self.client.get(reverse('home'), follow=True, **headers) self.assertEqual(response.status_code, 200) - new_user = get_user_model().objects.get(username='remoteuser2') + new_user = User.objects.get(username='remoteuser2') self.assertEqual(int(self.client.session.get( '_auth_user_id')), new_user.pk, msg='Authentication failed') self.assertListEqual( @@ -344,7 +347,7 @@ class ExternalAuthenticationTestCase(TestCase): response = self.client.get(reverse("home"), follow=True, **headers) self.assertEqual(response.status_code, 200) - new_user = get_user_model().objects.get(username="remoteuser2") + new_user = User.objects.get(username="remoteuser2") self.assertEqual( int(self.client.session.get("_auth_user_id")), new_user.pk, @@ -390,7 +393,7 @@ class ExternalAuthenticationTestCase(TestCase): response = self.client.get(reverse('home'), follow=True, **headers) self.assertEqual(response.status_code, 200) - new_user = get_user_model().objects.get(username='remoteuser2') + new_user = User.objects.get(username='remoteuser2') self.assertEqual(int(self.client.session.get( '_auth_user_id')), new_user.pk, msg='Authentication failed') self.assertListEqual( @@ -429,7 +432,7 @@ class ObjectPermissionAPIViewTestCase(TestCase): """ Create a test user and token for API calls. """ - self.user = get_user_model().objects.create(username='testuser') + self.user = User.objects.create(username='testuser') self.token = Token.objects.create(user=self.user) self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)} diff --git a/netbox/users/tests/test_api.py b/netbox/users/tests/test_api.py index 7005db04e..2de243775 100644 --- a/netbox/users/tests/test_api.py +++ b/netbox/users/tests/test_api.py @@ -8,6 +8,9 @@ from utilities.testing import APIViewTestCases, APITestCase from utilities.utils import deepmerge +User = get_user_model() + + class AppTest(APITestCase): def test_root(self): @@ -19,7 +22,7 @@ class AppTest(APITestCase): class UserTest(APIViewTestCases.APIViewTestCase): - model = get_user_model() + model = User view_namespace = 'users' brief_fields = ['display', 'id', 'url', 'username'] validation_excluded_fields = ['password'] @@ -45,11 +48,11 @@ class UserTest(APIViewTestCases.APIViewTestCase): def setUpTestData(cls): users = ( - get_user_model()(username='User_1', password='password1'), - get_user_model()(username='User_2', password='password2'), - get_user_model()(username='User_3', password='password3'), + User(username='User_1', password='password1'), + User(username='User_2', password='password2'), + User(username='User_3', password='password3'), ) - get_user_model().objects.bulk_create(users) + User.objects.bulk_create(users) class GroupTest(APIViewTestCases.APIViewTestCase): @@ -131,7 +134,7 @@ class TokenTest( 'username': 'user1', 'password': 'abc123', } - user = get_user_model().objects.create_user(**data) + user = User.objects.create_user(**data) url = reverse('users-api:token_provision') response = self.client.post(url, data, format='json', **self.header) @@ -159,7 +162,7 @@ class TokenTest( Test provisioning a Token for a different User with & without the grant_token permission. """ self.add_permissions('users.add_token') - user2 = get_user_model().objects.create_user(username='testuser2') + user2 = User.objects.create_user(username='testuser2') data = { 'user': user2.id, } @@ -197,11 +200,11 @@ class ObjectPermissionTest( Group.objects.bulk_create(groups) users = ( - get_user_model()(username='User 1', is_active=True), - get_user_model()(username='User 2', is_active=True), - get_user_model()(username='User 3', is_active=True), + User(username='User 1', is_active=True), + User(username='User 2', is_active=True), + User(username='User 3', is_active=True), ) - get_user_model().objects.bulk_create(users) + User.objects.bulk_create(users) object_type = ContentType.objects.get(app_label='dcim', model='device') diff --git a/netbox/users/tests/test_filtersets.py b/netbox/users/tests/test_filtersets.py index aacba90e9..d632687ef 100644 --- a/netbox/users/tests/test_filtersets.py +++ b/netbox/users/tests/test_filtersets.py @@ -11,8 +11,11 @@ from users.models import ObjectPermission, Token from utilities.testing import BaseFilterSetTests +User = get_user_model() + + class UserTestCase(TestCase, BaseFilterSetTests): - queryset = get_user_model().objects.all() + queryset = User.objects.all() filterset = filtersets.UserFilterSet @classmethod @@ -26,39 +29,39 @@ class UserTestCase(TestCase, BaseFilterSetTests): Group.objects.bulk_create(groups) users = ( - get_user_model()( + User( username='User1', first_name='Hank', last_name='Hill', email='hank@stricklandpropane.com', is_staff=True ), - get_user_model()( + User( username='User2', first_name='Dale', last_name='Gribble', email='dale@dalesdeadbug.com' ), - get_user_model()( + User( username='User3', first_name='Bill', last_name='Dauterive', email='bill.dauterive@army.mil' ), - get_user_model()( + User( username='User4', first_name='Jeff', last_name='Boomhauer', email='boomhauer@dangolemail.com' ), - get_user_model()( + User( username='User5', first_name='Debbie', last_name='Grund', is_active=False ) ) - get_user_model().objects.bulk_create(users) + User.objects.bulk_create(users) users[0].groups.set([groups[0]]) users[1].groups.set([groups[1]]) @@ -130,11 +133,11 @@ class ObjectPermissionTestCase(TestCase, BaseFilterSetTests): Group.objects.bulk_create(groups) users = ( - get_user_model()(username='User1'), - get_user_model()(username='User2'), - get_user_model()(username='User3'), + User(username='User1'), + User(username='User2'), + User(username='User3'), ) - get_user_model().objects.bulk_create(users) + User.objects.bulk_create(users) object_types = ( ContentType.objects.get(app_label='dcim', model='site'), @@ -173,7 +176,7 @@ class ObjectPermissionTestCase(TestCase, BaseFilterSetTests): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) def test_user(self): - users = get_user_model().objects.filter(username__in=['User1', 'User2']) + users = User.objects.filter(username__in=['User1', 'User2']) params = {'user_id': [users[0].pk, users[1].pk]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) params = {'user': [users[0].username, users[1].username]} @@ -197,11 +200,11 @@ class TokenTestCase(TestCase, BaseFilterSetTests): def setUpTestData(cls): users = ( - get_user_model()(username='User1'), - get_user_model()(username='User2'), - get_user_model()(username='User3'), + User(username='User1'), + User(username='User2'), + User(username='User3'), ) - get_user_model().objects.bulk_create(users) + User.objects.bulk_create(users) future_date = make_aware(datetime.datetime(3000, 1, 1)) past_date = make_aware(datetime.datetime(2000, 1, 1)) @@ -213,7 +216,7 @@ class TokenTestCase(TestCase, BaseFilterSetTests): Token.objects.bulk_create(tokens) def test_user(self): - users = get_user_model().objects.order_by('id')[:2] + users = User.objects.order_by('id')[:2] params = {'user_id': [users[0].pk, users[1].pk]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) params = {'user': [users[0].username, users[1].username]} diff --git a/netbox/users/tests/test_models.py b/netbox/users/tests/test_models.py index 27146c71b..791ea8fb4 100644 --- a/netbox/users/tests/test_models.py +++ b/netbox/users/tests/test_models.py @@ -2,12 +2,15 @@ from django.contrib.auth import get_user_model from django.test import TestCase +User = get_user_model() + + class UserConfigTest(TestCase): @classmethod def setUpTestData(cls): - user = get_user_model().objects.create_user(username='testuser') + user = User.objects.create_user(username='testuser') user.config.data = { 'a': True, 'b': { @@ -29,7 +32,7 @@ class UserConfigTest(TestCase): user.config.save() def test_get(self): - userconfig = get_user_model().objects.get(username='testuser').config + userconfig = User.objects.get(username='testuser').config # Retrieve root and nested values self.assertEqual(userconfig.get('a'), True) @@ -49,7 +52,7 @@ class UserConfigTest(TestCase): self.assertEqual(userconfig.get('b.foo.x.invalid', 'DEFAULT'), 'DEFAULT') def test_all(self): - userconfig = get_user_model().objects.get(username='testuser').config + userconfig = User.objects.get(username='testuser').config flattened_data = { 'a': True, 'b.foo': 101, @@ -63,7 +66,7 @@ class UserConfigTest(TestCase): self.assertEqual(userconfig.all(), flattened_data) def test_set(self): - userconfig = get_user_model().objects.get(username='testuser').config + userconfig = User.objects.get(username='testuser').config # Overwrite existing values userconfig.set('a', 'abc') @@ -92,7 +95,7 @@ class UserConfigTest(TestCase): userconfig.set('a.x', 1) def test_clear(self): - userconfig = get_user_model().objects.get(username='testuser').config + userconfig = User.objects.get(username='testuser').config # Clear existing values userconfig.clear('a')