diff --git a/netbox/core/forms/filtersets.py b/netbox/core/forms/filtersets.py index 7c3f2ab09..d8624f6b6 100644 --- a/netbox/core/forms/filtersets.py +++ b/netbox/core/forms/filtersets.py @@ -1,5 +1,5 @@ from django import forms -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.utils.translation import gettext as _ @@ -105,7 +105,7 @@ class JobFilterForm(SavedFiltersMixin, FilterForm): widget=DateTimePicker() ) user = DynamicModelMultipleChoiceField( - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), required=False, label=_('User'), widget=APISelectMultiple( diff --git a/netbox/core/management/commands/nbshell.py b/netbox/core/management/commands/nbshell.py index 04a67eb49..674a878c7 100644 --- a/netbox/core/management/commands/nbshell.py +++ b/netbox/core/management/commands/nbshell.py @@ -5,7 +5,7 @@ import sys from django import get_version from django.apps import apps from django.conf import settings -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.core.management.base import BaseCommand @@ -60,7 +60,7 @@ class Command(BaseCommand): # Additional objects to include namespace['ContentType'] = ContentType - namespace['User'] = User + namespace['User'] = get_user_model() # Load convenience commands namespace.update({ diff --git a/netbox/core/models/jobs.py b/netbox/core/models/jobs.py index a91e75e61..0715a4521 100644 --- a/netbox/core/models/jobs.py +++ b/netbox/core/models/jobs.py @@ -1,7 +1,7 @@ import uuid import django_rq -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType from django.core.validators import MinValueValidator @@ -69,7 +69,7 @@ class Job(models.Model): blank=True ) user = models.ForeignKey( - to=User, + to=get_user_model(), on_delete=models.SET_NULL, related_name='+', blank=True, diff --git a/netbox/dcim/filtersets.py b/netbox/dcim/filtersets.py index e87a37847..e53ea8079 100644 --- a/netbox/dcim/filtersets.py +++ b/netbox/dcim/filtersets.py @@ -1,5 +1,5 @@ import django_filters -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.utils.translation import gettext as _ from extras.filtersets import LocalConfigContextFilterSet @@ -395,12 +395,12 @@ class RackReservationFilterSet(NetBoxModelFilterSet, TenancyFilterSet): label=_('Location (slug)'), ) user_id = django_filters.ModelMultipleChoiceFilter( - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), label=_('User (ID)'), ) user = django_filters.ModelMultipleChoiceFilter( field_name='user__username', - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), to_field_name='username', label=_('User (name)'), ) diff --git a/netbox/dcim/forms/bulk_edit.py b/netbox/dcim/forms/bulk_edit.py index 11cfd685d..309370bfd 100644 --- a/netbox/dcim/forms/bulk_edit.py +++ b/netbox/dcim/forms/bulk_edit.py @@ -1,6 +1,6 @@ from django import forms from django.conf import settings -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.utils.translation import gettext as _ from timezone_field import TimeZoneFormField @@ -322,7 +322,7 @@ class RackBulkEditForm(NetBoxModelBulkEditForm): class RackReservationBulkEditForm(NetBoxModelBulkEditForm): user = forms.ModelChoiceField( - queryset=User.objects.order_by( + queryset=get_user_model().objects.order_by( 'username' ), required=False diff --git a/netbox/dcim/forms/filtersets.py b/netbox/dcim/forms/filtersets.py index 4edee6014..0a4a22a70 100644 --- a/netbox/dcim/forms/filtersets.py +++ b/netbox/dcim/forms/filtersets.py @@ -1,5 +1,5 @@ from django import forms -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.utils.translation import gettext as _ from dcim.choices import * @@ -376,7 +376,7 @@ class RackReservationFilterForm(TenancyFilterForm, NetBoxModelFilterSetForm): label=_('Rack') ) user_id = DynamicModelMultipleChoiceField( - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), required=False, label=_('User'), widget=APISelectMultiple( diff --git a/netbox/dcim/forms/model_forms.py b/netbox/dcim/forms/model_forms.py index 56542d70c..eda302736 100644 --- a/netbox/dcim/forms/model_forms.py +++ b/netbox/dcim/forms/model_forms.py @@ -1,5 +1,5 @@ from django import forms -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.utils.translation import gettext as _ from timezone_field import TimeZoneFormField @@ -236,7 +236,7 @@ class RackReservationForm(TenancyForm, NetBoxModelForm): help_text=_("Comma-separated list of numeric unit IDs. A range may be specified using a hyphen.") ) user = forms.ModelChoiceField( - queryset=User.objects.order_by( + queryset=get_user_model().objects.order_by( 'username' ) ) diff --git a/netbox/dcim/models/racks.py b/netbox/dcim/models/racks.py index e5412a3ab..777454dce 100644 --- a/netbox/dcim/models/racks.py +++ b/netbox/dcim/models/racks.py @@ -1,7 +1,7 @@ import decimal from functools import cached_property -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.fields import GenericRelation from django.contrib.postgres.fields import ArrayField from django.core.exceptions import ValidationError @@ -505,7 +505,7 @@ class RackReservation(PrimaryModel): null=True ) user = models.ForeignKey( - to=User, + to=get_user_model(), on_delete=models.PROTECT ) description = models.CharField( diff --git a/netbox/dcim/tests/test_api.py b/netbox/dcim/tests/test_api.py index af15e1343..235f35192 100644 --- a/netbox/dcim/tests/test_api.py +++ b/netbox/dcim/tests/test_api.py @@ -1,4 +1,4 @@ -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.test import override_settings from django.urls import reverse from rest_framework import status @@ -363,7 +363,7 @@ class RackReservationTest(APIViewTestCases.APIViewTestCase): @classmethod def setUpTestData(cls): - user = User.objects.create(username='user1', is_active=True) + user = get_user_model().objects.create(username='user1', is_active=True) site = Site.objects.create(name='Test Site 1', slug='test-site-1') racks = ( diff --git a/netbox/dcim/tests/test_filtersets.py b/netbox/dcim/tests/test_filtersets.py index aa6860a16..f55be2f3f 100644 --- a/netbox/dcim/tests/test_filtersets.py +++ b/netbox/dcim/tests/test_filtersets.py @@ -1,4 +1,4 @@ -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.test import TestCase from dcim.choices import * @@ -593,11 +593,11 @@ class RackReservationTestCase(TestCase, ChangeLoggedFilterSetTests): Rack.objects.bulk_create(racks) users = ( - User(username='User 1'), - User(username='User 2'), - User(username='User 3'), + get_user_model()(username='User 1'), + get_user_model()(username='User 2'), + get_user_model()(username='User 3'), ) - User.objects.bulk_create(users) + get_user_model().objects.bulk_create(users) tenant_groups = ( TenantGroup(name='Tenant group 1', slug='tenant-group-1'), @@ -650,7 +650,7 @@ class RackReservationTestCase(TestCase, ChangeLoggedFilterSetTests): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) def test_user(self): - users = User.objects.all()[:2] + users = get_user_model().objects.all()[:2] params = {'user_id': [users[0].pk, users[1].pk]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) params = {'user': [users[0].username, users[1].username]} diff --git a/netbox/dcim/tests/test_views.py b/netbox/dcim/tests/test_views.py index a327d6400..f5f14ba07 100644 --- a/netbox/dcim/tests/test_views.py +++ b/netbox/dcim/tests/test_views.py @@ -6,7 +6,7 @@ except ImportError: from backports.zoneinfo import ZoneInfo import yaml -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.test import override_settings from django.urls import reverse @@ -288,8 +288,8 @@ class RackReservationTestCase(ViewTestCases.PrimaryObjectViewTestCase): @classmethod def setUpTestData(cls): - user2 = User.objects.create_user(username='testuser2') - user3 = User.objects.create_user(username='testuser3') + user2 = get_user_model().objects.create_user(username='testuser2') + user3 = get_user_model().objects.create_user(username='testuser3') site = Site.objects.create(name='Site 1', slug='site-1') diff --git a/netbox/extras/api/serializers.py b/netbox/extras/api/serializers.py index cbe4ed56d..a02e933ba 100644 --- a/netbox/extras/api/serializers.py +++ b/netbox/extras/api/serializers.py @@ -1,4 +1,4 @@ -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ObjectDoesNotExist from rest_framework import serializers @@ -256,7 +256,7 @@ class JournalEntrySerializer(NetBoxModelSerializer): assigned_object = serializers.SerializerMethodField(read_only=True) created_by = serializers.PrimaryKeyRelatedField( allow_null=True, - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), required=False, default=serializers.CurrentUserDefault() ) diff --git a/netbox/extras/filtersets.py b/netbox/extras/filtersets.py index 5253ae7b0..2cbaca5f7 100644 --- a/netbox/extras/filtersets.py +++ b/netbox/extras/filtersets.py @@ -1,5 +1,5 @@ import django_filters -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.db.models import Q from django.utils.translation import gettext as _ @@ -159,12 +159,12 @@ class SavedFilterFilterSet(BaseFilterSet): ) content_types = ContentTypeFilter() user_id = django_filters.ModelMultipleChoiceFilter( - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), label=_('User (ID)'), ) user = django_filters.ModelMultipleChoiceFilter( field_name='user__username', - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), to_field_name='username', label=_('User (name)'), ) @@ -223,12 +223,12 @@ class JournalEntryFilterSet(NetBoxModelFilterSet): queryset=ContentType.objects.all() ) created_by_id = django_filters.ModelMultipleChoiceFilter( - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), label=_('User (ID)'), ) created_by = django_filters.ModelMultipleChoiceFilter( field_name='created_by__username', - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), to_field_name='username', label=_('User (name)'), ) @@ -510,12 +510,12 @@ class ObjectChangeFilterSet(BaseFilterSet): queryset=ContentType.objects.all() ) user_id = django_filters.ModelMultipleChoiceFilter( - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), label=_('User (ID)'), ) user = django_filters.ModelMultipleChoiceFilter( field_name='user__username', - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), to_field_name='username', label=_('User name'), ) diff --git a/netbox/extras/forms/filtersets.py b/netbox/extras/forms/filtersets.py index fae15d041..53de81ba2 100644 --- a/netbox/extras/forms/filtersets.py +++ b/netbox/extras/forms/filtersets.py @@ -1,5 +1,5 @@ from django import forms -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.utils.translation import gettext as _ @@ -385,7 +385,7 @@ class JournalEntryFilterForm(NetBoxModelFilterSetForm): widget=DateTimePicker() ) created_by_id = DynamicModelMultipleChoiceField( - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), required=False, label=_('User'), widget=APISelectMultiple( @@ -429,7 +429,7 @@ class ObjectChangeFilterForm(SavedFiltersMixin, FilterForm): required=False ) user_id = DynamicModelMultipleChoiceField( - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), required=False, label=_('User'), widget=APISelectMultiple( diff --git a/netbox/extras/management/commands/runscript.py b/netbox/extras/management/commands/runscript.py index b42e9b47d..b086b542e 100644 --- a/netbox/extras/management/commands/runscript.py +++ b/netbox/extras/management/commands/runscript.py @@ -4,7 +4,7 @@ import sys import traceback import uuid -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.core.management.base import BaseCommand, CommandError from django.db import transaction @@ -78,11 +78,11 @@ class Command(BaseCommand): # Take user from command line if provided and exists, other if options['user']: try: - user = User.objects.get(username=options['user']) - except User.DoesNotExist: - user = User.objects.filter(is_superuser=True).order_by('pk')[0] + user = get_user_model().objects.get(username=options['user']) + except get_user_model().DoesNotExist: + user = get_user_model().objects.filter(is_superuser=True).order_by('pk')[0] else: - user = User.objects.filter(is_superuser=True).order_by('pk')[0] + user = get_user_model().objects.filter(is_superuser=True).order_by('pk')[0] # Setup logging to Stdout formatter = logging.Formatter(f'[%(asctime)s][%(levelname)s] - %(message)s') @@ -113,7 +113,7 @@ class Command(BaseCommand): job = Job.objects.create( object=module, name=script.name, - user=User.objects.filter(is_superuser=True).order_by('pk')[0], + user=get_user_model().objects.filter(is_superuser=True).order_by('pk')[0], job_id=uuid.uuid4() ) diff --git a/netbox/extras/models/change_logging.py b/netbox/extras/models/change_logging.py index e2b118b84..b03ebf475 100644 --- a/netbox/extras/models/change_logging.py +++ b/netbox/extras/models/change_logging.py @@ -1,4 +1,4 @@ -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType from django.db import models @@ -24,7 +24,7 @@ class ObjectChange(models.Model): db_index=True ) user = models.ForeignKey( - to=User, + to=get_user_model(), on_delete=models.SET_NULL, related_name='changes', blank=True, diff --git a/netbox/extras/models/models.py b/netbox/extras/models/models.py index 969fd22e0..bfb13fc71 100644 --- a/netbox/extras/models/models.py +++ b/netbox/extras/models/models.py @@ -3,7 +3,7 @@ import urllib.parse from django.conf import settings from django.contrib import admin -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType from django.core.cache import cache @@ -419,7 +419,7 @@ class SavedFilter(CloningMixin, ExportTemplatesMixin, ChangeLoggedModel): blank=True ) user = models.ForeignKey( - to=User, + to=get_user_model(), on_delete=models.SET_NULL, blank=True, null=True @@ -560,7 +560,7 @@ class JournalEntry(CustomFieldsMixin, CustomLinksMixin, TagsMixin, ExportTemplat fk_field='assigned_object_id' ) created_by = models.ForeignKey( - to=User, + to=get_user_model(), on_delete=models.SET_NULL, blank=True, null=True diff --git a/netbox/extras/tests/test_api.py b/netbox/extras/tests/test_api.py index b59481a36..73e07025d 100644 --- a/netbox/extras/tests/test_api.py +++ b/netbox/extras/tests/test_api.py @@ -1,6 +1,6 @@ import datetime -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.urls import reverse from django.utils.timezone import make_aware @@ -396,7 +396,7 @@ class JournalEntryTest(APIViewTestCases.APIViewTestCase): @classmethod def setUpTestData(cls): - user = User.objects.first() + user = get_user_model().objects.first() site = Site.objects.create(name='Site 1', slug='site-1') journal_entries = ( diff --git a/netbox/extras/tests/test_filtersets.py b/netbox/extras/tests/test_filtersets.py index e77afd20e..8c0173727 100644 --- a/netbox/extras/tests/test_filtersets.py +++ b/netbox/extras/tests/test_filtersets.py @@ -1,7 +1,7 @@ import uuid from datetime import datetime, timezone -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.test import TestCase @@ -278,11 +278,11 @@ class SavedFilterTestCase(TestCase, BaseFilterSetTests): content_types = ContentType.objects.filter(model__in=['site', 'rack', 'device']) users = ( - User(username='User 1'), - User(username='User 2'), - User(username='User 3'), + get_user_model()(username='User 1'), + get_user_model()(username='User 2'), + get_user_model()(username='User 3'), ) - User.objects.bulk_create(users) + get_user_model().objects.bulk_create(users) saved_filters = ( SavedFilter( @@ -332,7 +332,7 @@ class SavedFilterTestCase(TestCase, BaseFilterSetTests): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) def test_user(self): - users = User.objects.filter(username__startswith='User') + users = get_user_model().objects.filter(username__startswith='User') params = {'user': [users[0].username, users[1].username]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) params = {'user_id': [users[0].pk, users[1].pk]} @@ -493,11 +493,11 @@ class JournalEntryTestCase(TestCase, ChangeLoggedFilterSetTests): Rack.objects.bulk_create(racks) users = ( - User(username='Alice'), - User(username='Bob'), - User(username='Charlie'), + get_user_model()(username='Alice'), + get_user_model()(username='Bob'), + get_user_model()(username='Charlie'), ) - User.objects.bulk_create(users) + get_user_model().objects.bulk_create(users) journal_entries = ( JournalEntry( @@ -540,7 +540,7 @@ class JournalEntryTestCase(TestCase, ChangeLoggedFilterSetTests): JournalEntry.objects.bulk_create(journal_entries) def test_created_by(self): - users = User.objects.filter(username__in=['Alice', 'Bob']) + users = get_user_model().objects.filter(username__in=['Alice', 'Bob']) params = {'created_by': [users[0].username, users[1].username]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) params = {'created_by_id': [users[0].pk, users[1].pk]} @@ -865,11 +865,11 @@ class ObjectChangeTestCase(TestCase, BaseFilterSetTests): @classmethod def setUpTestData(cls): users = ( - User(username='user1'), - User(username='user2'), - User(username='user3'), + get_user_model()(username='user1'), + get_user_model()(username='user2'), + get_user_model()(username='user3'), ) - User.objects.bulk_create(users) + get_user_model().objects.bulk_create(users) site = Site.objects.create(name='Test Site 1', slug='test-site-1') ipaddress = IPAddress.objects.create(address='192.0.2.1/24') @@ -933,7 +933,7 @@ class ObjectChangeTestCase(TestCase, BaseFilterSetTests): ObjectChange.objects.bulk_create(object_changes) def test_user(self): - params = {'user_id': User.objects.filter(username__in=['user1', 'user2']).values_list('pk', flat=True)} + params = {'user_id': get_user_model().objects.filter(username__in=['user1', 'user2']).values_list('pk', flat=True)} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) params = {'user': ['user1', 'user2']} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4) diff --git a/netbox/extras/tests/test_views.py b/netbox/extras/tests/test_views.py index ef8e87489..1a3ce8923 100644 --- a/netbox/extras/tests/test_views.py +++ b/netbox/extras/tests/test_views.py @@ -1,7 +1,7 @@ import urllib.parse import uuid -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.urls import reverse @@ -115,11 +115,11 @@ class SavedFilterTestCase(ViewTestCases.PrimaryObjectViewTestCase): site_ct = ContentType.objects.get_for_model(Site) users = ( - User(username='User 1'), - User(username='User 2'), - User(username='User 3'), + get_user_model()(username='User 1'), + get_user_model()(username='User 2'), + get_user_model()(username='User 3'), ) - User.objects.bulk_create(users) + get_user_model().objects.bulk_create(users) saved_filters = ( SavedFilter( @@ -412,7 +412,7 @@ class ObjectChangeTestCase(TestCase): site.save() # Create three ObjectChanges - user = User.objects.create_user(username='testuser2') + user = get_user_model().objects.create_user(username='testuser2') for i in range(1, 4): oc = site.to_objectchange(action=ObjectChangeActionChoices.ACTION_UPDATE) oc.user = user @@ -423,7 +423,7 @@ class ObjectChangeTestCase(TestCase): url = reverse('extras:objectchange_list') params = { - "user": User.objects.first().pk, + "user": get_user_model().objects.first().pk, } response = self.client.get('{}?{}'.format(url, urllib.parse.urlencode(params))) @@ -452,7 +452,7 @@ class JournalEntryTestCase( site_ct = ContentType.objects.get_for_model(Site) site = Site.objects.create(name='Site 1', slug='site-1') - user = User.objects.create(username='User 1') + user = get_user_model().objects.create(username='User 1') JournalEntry.objects.bulk_create(( JournalEntry(assigned_object=site, created_by=user, comments='First entry'), diff --git a/netbox/netbox/tests/test_authentication.py b/netbox/netbox/tests/test_authentication.py index 4e46996b5..d14cc6973 100644 --- a/netbox/netbox/tests/test_authentication.py +++ b/netbox/netbox/tests/test_authentication.py @@ -1,7 +1,8 @@ import datetime from django.conf import settings -from django.contrib.auth.models import Group, User +from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from django.contrib.contenttypes.models import ContentType from django.test import Client from django.test.utils import override_settings @@ -87,7 +88,7 @@ class ExternalAuthenticationTestCase(TestCase): @classmethod def setUpTestData(cls): - cls.user = User.objects.create(username='remoteuser1') + cls.user = get_user_model().objects.create(username='remoteuser1') def setUp(self): self.client = Client() @@ -169,7 +170,7 @@ class ExternalAuthenticationTestCase(TestCase): response = self.client.get(reverse('home'), follow=True, **headers) self.assertEqual(response.status_code, 200) - self.user = User.objects.get(username='remoteuser1') + self.user = get_user_model().objects.get(username='remoteuser1') self.assertEqual(self.user.first_name, "John", msg='User first name was not updated') self.assertEqual(self.user.last_name, "Smith", msg='User last name was not updated') self.assertEqual(self.user.email, "johnsmith@example.com", msg='User email was not updated') @@ -195,7 +196,7 @@ class ExternalAuthenticationTestCase(TestCase): self.assertEqual(response.status_code, 200) # Local user should have been automatically created - new_user = User.objects.get(username='remoteuser2') + new_user = get_user_model().objects.get(username='remoteuser2') self.assertEqual(int(self.client.session.get( '_auth_user_id')), new_user.pk, msg='Authentication failed') @@ -230,7 +231,7 @@ class ExternalAuthenticationTestCase(TestCase): response = self.client.get(reverse('home'), follow=True, **headers) self.assertEqual(response.status_code, 200) - new_user = User.objects.get(username='remoteuser2') + new_user = get_user_model().objects.get(username='remoteuser2') self.assertEqual(int(self.client.session.get( '_auth_user_id')), new_user.pk, msg='Authentication failed') self.assertListEqual( @@ -262,7 +263,7 @@ class ExternalAuthenticationTestCase(TestCase): response = self.client.get(reverse('home'), follow=True, **headers) self.assertEqual(response.status_code, 200) - new_user = User.objects.get(username='remoteuser2') + new_user = get_user_model().objects.get(username='remoteuser2') self.assertEqual(int(self.client.session.get( '_auth_user_id')), new_user.pk, msg='Authentication failed') self.assertTrue(new_user.has_perms( @@ -302,7 +303,7 @@ class ExternalAuthenticationTestCase(TestCase): response = self.client.get(reverse('home'), follow=True, **headers) self.assertEqual(response.status_code, 200) - new_user = User.objects.get(username='remoteuser2') + new_user = get_user_model().objects.get(username='remoteuser2') self.assertEqual(int(self.client.session.get( '_auth_user_id')), new_user.pk, msg='Authentication failed') self.assertListEqual( @@ -343,7 +344,7 @@ class ExternalAuthenticationTestCase(TestCase): response = self.client.get(reverse("home"), follow=True, **headers) self.assertEqual(response.status_code, 200) - new_user = User.objects.get(username="remoteuser2") + new_user = get_user_model().objects.get(username="remoteuser2") self.assertEqual( int(self.client.session.get("_auth_user_id")), new_user.pk, @@ -389,7 +390,7 @@ class ExternalAuthenticationTestCase(TestCase): response = self.client.get(reverse('home'), follow=True, **headers) self.assertEqual(response.status_code, 200) - new_user = User.objects.get(username='remoteuser2') + new_user = get_user_model().objects.get(username='remoteuser2') self.assertEqual(int(self.client.session.get( '_auth_user_id')), new_user.pk, msg='Authentication failed') self.assertListEqual( @@ -428,7 +429,7 @@ class ObjectPermissionAPIViewTestCase(TestCase): """ Create a test user and token for API calls. """ - self.user = User.objects.create(username='testuser') + self.user = get_user_model().objects.create(username='testuser') self.token = Token.objects.create(user=self.user) self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)} diff --git a/netbox/users/api/nested_serializers.py b/netbox/users/api/nested_serializers.py index 3510184ae..5e15fa41a 100644 --- a/netbox/users/api/nested_serializers.py +++ b/netbox/users/api/nested_serializers.py @@ -1,4 +1,5 @@ -from django.contrib.auth.models import Group, User +from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from django.contrib.contenttypes.models import ContentType from drf_spectacular.utils import extend_schema_field from drf_spectacular.types import OpenApiTypes @@ -28,7 +29,7 @@ class NestedUserSerializer(WritableNestedSerializer): url = serializers.HyperlinkedIdentityField(view_name='users-api:user-detail') class Meta: - model = User + model = get_user_model() fields = ['id', 'url', 'display', 'username'] @extend_schema_field(OpenApiTypes.STR) diff --git a/netbox/users/api/serializers.py b/netbox/users/api/serializers.py index 1b975791f..1f4bf4ea0 100644 --- a/netbox/users/api/serializers.py +++ b/netbox/users/api/serializers.py @@ -1,5 +1,6 @@ from django.conf import settings -from django.contrib.auth.models import Group, User +from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from django.contrib.contenttypes.models import ContentType from drf_spectacular.utils import extend_schema_field from drf_spectacular.types import OpenApiTypes @@ -30,7 +31,7 @@ class UserSerializer(ValidatedModelSerializer): ) class Meta: - model = User + model = get_user_model() fields = ( 'id', 'url', 'display', 'username', 'password', 'first_name', 'last_name', 'email', 'is_staff', 'is_active', 'date_joined', 'groups', @@ -124,7 +125,7 @@ class ObjectPermissionSerializer(ValidatedModelSerializer): many=True ) users = SerializedPKRelatedField( - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), serializer=NestedUserSerializer, required=False, many=True diff --git a/netbox/users/api/views.py b/netbox/users/api/views.py index 04b3ae336..4a8e1b154 100644 --- a/netbox/users/api/views.py +++ b/netbox/users/api/views.py @@ -1,5 +1,6 @@ from django.contrib.auth import authenticate -from django.contrib.auth.models import Group, User +from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from django.db.models import Count from drf_spectacular.utils import extend_schema from drf_spectacular.types import OpenApiTypes @@ -32,7 +33,7 @@ class UsersRootView(APIRootView): # class UserViewSet(NetBoxModelViewSet): - queryset = RestrictedQuerySet(model=User).prefetch_related('groups').order_by('username') + queryset = RestrictedQuerySet(model=get_user_model()).prefetch_related('groups').order_by('username') serializer_class = serializers.UserSerializer filterset_class = filtersets.UserFilterSet diff --git a/netbox/users/filtersets.py b/netbox/users/filtersets.py index 4ae9df89a..44ad98cc2 100644 --- a/netbox/users/filtersets.py +++ b/netbox/users/filtersets.py @@ -1,5 +1,6 @@ import django_filters -from django.contrib.auth.models import Group, User +from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from django.db.models import Q from django.utils.translation import gettext as _ @@ -47,7 +48,7 @@ class UserFilterSet(BaseFilterSet): ) class Meta: - model = User + model = get_user_model() fields = ['id', 'username', 'first_name', 'last_name', 'email', 'is_staff', 'is_active'] def search(self, queryset, name, value): @@ -68,12 +69,12 @@ class TokenFilterSet(BaseFilterSet): ) user_id = django_filters.ModelMultipleChoiceFilter( field_name='user', - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), label=_('User'), ) user = django_filters.ModelMultipleChoiceFilter( field_name='user__username', - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), to_field_name='username', label=_('User (name)'), ) @@ -116,12 +117,12 @@ class ObjectPermissionFilterSet(BaseFilterSet): ) user_id = django_filters.ModelMultipleChoiceFilter( field_name='users', - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), label=_('User'), ) user = django_filters.ModelMultipleChoiceFilter( field_name='users__username', - queryset=User.objects.all(), + queryset=get_user_model().objects.all(), to_field_name='username', label=_('User (name)'), ) diff --git a/netbox/users/graphql/schema.py b/netbox/users/graphql/schema.py index 3b04d8418..f033a535a 100644 --- a/netbox/users/graphql/schema.py +++ b/netbox/users/graphql/schema.py @@ -1,6 +1,7 @@ import graphene -from django.contrib.auth.models import Group, User +from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from netbox.graphql.fields import ObjectField, ObjectListField from .types import * from utilities.graphql_optimizer import gql_query_optimizer @@ -17,4 +18,4 @@ class UsersQuery(graphene.ObjectType): user_list = ObjectListField(UserType) def resolve_user_list(root, info, **kwargs): - return gql_query_optimizer(User.objects.all(), info) + return gql_query_optimizer(get_user_model().objects.all(), info) diff --git a/netbox/users/graphql/types.py b/netbox/users/graphql/types.py index d948686c6..4254f1791 100644 --- a/netbox/users/graphql/types.py +++ b/netbox/users/graphql/types.py @@ -1,4 +1,5 @@ -from django.contrib.auth.models import Group, User +from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from graphene_django import DjangoObjectType from users import filtersets @@ -25,7 +26,7 @@ class GroupType(DjangoObjectType): class UserType(DjangoObjectType): class Meta: - model = User + model = get_user_model() fields = ( 'id', 'username', 'password', 'first_name', 'last_name', 'email', 'is_staff', 'is_active', 'date_joined', 'groups', @@ -34,4 +35,4 @@ class UserType(DjangoObjectType): @classmethod def get_queryset(cls, queryset, info): - return RestrictedQuerySet(model=User).restrict(info.context.user, 'view') + return RestrictedQuerySet(model=get_user_model()).restrict(info.context.user, 'view') diff --git a/netbox/users/tests/test_api.py b/netbox/users/tests/test_api.py index 281f656d2..7005db04e 100644 --- a/netbox/users/tests/test_api.py +++ b/netbox/users/tests/test_api.py @@ -1,4 +1,5 @@ -from django.contrib.auth.models import Group, User +from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from django.contrib.contenttypes.models import ContentType from django.urls import reverse @@ -18,7 +19,7 @@ class AppTest(APITestCase): class UserTest(APIViewTestCases.APIViewTestCase): - model = User + model = get_user_model() view_namespace = 'users' brief_fields = ['display', 'id', 'url', 'username'] validation_excluded_fields = ['password'] @@ -44,11 +45,11 @@ class UserTest(APIViewTestCases.APIViewTestCase): def setUpTestData(cls): users = ( - User(username='User_1', password='password1'), - User(username='User_2', password='password2'), - User(username='User_3', password='password3'), + get_user_model()(username='User_1', password='password1'), + get_user_model()(username='User_2', password='password2'), + get_user_model()(username='User_3', password='password3'), ) - User.objects.bulk_create(users) + get_user_model().objects.bulk_create(users) class GroupTest(APIViewTestCases.APIViewTestCase): @@ -130,7 +131,7 @@ class TokenTest( 'username': 'user1', 'password': 'abc123', } - user = User.objects.create_user(**data) + user = get_user_model().objects.create_user(**data) url = reverse('users-api:token_provision') response = self.client.post(url, data, format='json', **self.header) @@ -158,7 +159,7 @@ class TokenTest( Test provisioning a Token for a different User with & without the grant_token permission. """ self.add_permissions('users.add_token') - user2 = User.objects.create_user(username='testuser2') + user2 = get_user_model().objects.create_user(username='testuser2') data = { 'user': user2.id, } @@ -196,11 +197,11 @@ class ObjectPermissionTest( Group.objects.bulk_create(groups) users = ( - User(username='User 1', is_active=True), - User(username='User 2', is_active=True), - User(username='User 3', is_active=True), + get_user_model()(username='User 1', is_active=True), + get_user_model()(username='User 2', is_active=True), + get_user_model()(username='User 3', is_active=True), ) - User.objects.bulk_create(users) + get_user_model().objects.bulk_create(users) object_type = ContentType.objects.get(app_label='dcim', model='device') diff --git a/netbox/users/tests/test_filtersets.py b/netbox/users/tests/test_filtersets.py index 33ed7e7ba..aacba90e9 100644 --- a/netbox/users/tests/test_filtersets.py +++ b/netbox/users/tests/test_filtersets.py @@ -1,6 +1,7 @@ import datetime -from django.contrib.auth.models import Group, User +from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from django.contrib.contenttypes.models import ContentType from django.test import TestCase from django.utils.timezone import make_aware @@ -11,7 +12,7 @@ from utilities.testing import BaseFilterSetTests class UserTestCase(TestCase, BaseFilterSetTests): - queryset = User.objects.all() + queryset = get_user_model().objects.all() filterset = filtersets.UserFilterSet @classmethod @@ -25,39 +26,39 @@ class UserTestCase(TestCase, BaseFilterSetTests): Group.objects.bulk_create(groups) users = ( - User( + get_user_model()( username='User1', first_name='Hank', last_name='Hill', email='hank@stricklandpropane.com', is_staff=True ), - User( + get_user_model()( username='User2', first_name='Dale', last_name='Gribble', email='dale@dalesdeadbug.com' ), - User( + get_user_model()( username='User3', first_name='Bill', last_name='Dauterive', email='bill.dauterive@army.mil' ), - User( + get_user_model()( username='User4', first_name='Jeff', last_name='Boomhauer', email='boomhauer@dangolemail.com' ), - User( + get_user_model()( username='User5', first_name='Debbie', last_name='Grund', is_active=False ) ) - User.objects.bulk_create(users) + get_user_model().objects.bulk_create(users) users[0].groups.set([groups[0]]) users[1].groups.set([groups[1]]) @@ -129,11 +130,11 @@ class ObjectPermissionTestCase(TestCase, BaseFilterSetTests): Group.objects.bulk_create(groups) users = ( - User(username='User1'), - User(username='User2'), - User(username='User3'), + get_user_model()(username='User1'), + get_user_model()(username='User2'), + get_user_model()(username='User3'), ) - User.objects.bulk_create(users) + get_user_model().objects.bulk_create(users) object_types = ( ContentType.objects.get(app_label='dcim', model='site'), @@ -172,7 +173,7 @@ class ObjectPermissionTestCase(TestCase, BaseFilterSetTests): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) def test_user(self): - users = User.objects.filter(username__in=['User1', 'User2']) + users = get_user_model().objects.filter(username__in=['User1', 'User2']) params = {'user_id': [users[0].pk, users[1].pk]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) params = {'user': [users[0].username, users[1].username]} @@ -196,11 +197,11 @@ class TokenTestCase(TestCase, BaseFilterSetTests): def setUpTestData(cls): users = ( - User(username='User1'), - User(username='User2'), - User(username='User3'), + get_user_model()(username='User1'), + get_user_model()(username='User2'), + get_user_model()(username='User3'), ) - User.objects.bulk_create(users) + get_user_model().objects.bulk_create(users) future_date = make_aware(datetime.datetime(3000, 1, 1)) past_date = make_aware(datetime.datetime(2000, 1, 1)) @@ -212,7 +213,7 @@ class TokenTestCase(TestCase, BaseFilterSetTests): Token.objects.bulk_create(tokens) def test_user(self): - users = User.objects.order_by('id')[:2] + users = get_user_model().objects.order_by('id')[:2] params = {'user_id': [users[0].pk, users[1].pk]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) params = {'user': [users[0].username, users[1].username]} diff --git a/netbox/users/tests/test_models.py b/netbox/users/tests/test_models.py index 7a2337f33..27146c71b 100644 --- a/netbox/users/tests/test_models.py +++ b/netbox/users/tests/test_models.py @@ -1,4 +1,4 @@ -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.test import TestCase @@ -7,7 +7,7 @@ class UserConfigTest(TestCase): @classmethod def setUpTestData(cls): - user = User.objects.create_user(username='testuser') + user = get_user_model().objects.create_user(username='testuser') user.config.data = { 'a': True, 'b': { @@ -29,7 +29,7 @@ class UserConfigTest(TestCase): user.config.save() def test_get(self): - userconfig = User.objects.get(username='testuser').config + userconfig = get_user_model().objects.get(username='testuser').config # Retrieve root and nested values self.assertEqual(userconfig.get('a'), True) @@ -49,7 +49,7 @@ class UserConfigTest(TestCase): self.assertEqual(userconfig.get('b.foo.x.invalid', 'DEFAULT'), 'DEFAULT') def test_all(self): - userconfig = User.objects.get(username='testuser').config + userconfig = get_user_model().objects.get(username='testuser').config flattened_data = { 'a': True, 'b.foo': 101, @@ -63,7 +63,7 @@ class UserConfigTest(TestCase): self.assertEqual(userconfig.all(), flattened_data) def test_set(self): - userconfig = User.objects.get(username='testuser').config + userconfig = get_user_model().objects.get(username='testuser').config # Overwrite existing values userconfig.set('a', 'abc') @@ -92,7 +92,7 @@ class UserConfigTest(TestCase): userconfig.set('a.x', 1) def test_clear(self): - userconfig = User.objects.get(username='testuser').config + userconfig = get_user_model().objects.get(username='testuser').config # Clear existing values userconfig.clear('a') diff --git a/netbox/users/tests/test_preferences.py b/netbox/users/tests/test_preferences.py index f1e947d67..2776e6344 100644 --- a/netbox/users/tests/test_preferences.py +++ b/netbox/users/tests/test_preferences.py @@ -1,4 +1,4 @@ -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.test import override_settings from django.test.client import RequestFactory from django.urls import reverse @@ -39,7 +39,7 @@ class UserPreferencesTest(TestCase): @override_settings(DEFAULT_USER_PREFERENCES=DEFAULT_USER_PREFERENCES) def test_default_preferences(self): - user = User.objects.create(username='User 1') + user = get_user_model().objects.create(username='User 1') userconfig = user.config self.assertEqual(userconfig.data, DEFAULT_USER_PREFERENCES) diff --git a/netbox/utilities/testing/api.py b/netbox/utilities/testing/api.py index 7f24c86b8..816037b0c 100644 --- a/netbox/utilities/testing/api.py +++ b/netbox/utilities/testing/api.py @@ -2,7 +2,7 @@ import inspect import json from django.conf import settings -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.urls import reverse from django.test import override_settings @@ -45,7 +45,7 @@ class APITestCase(ModelTestCase): Create a user and token for API calls. """ # Create the test user and assign permissions - self.user = User.objects.create_user(username='testuser') + self.user = get_user_model().objects.create_user(username='testuser') self.add_permissions(*self.user_permissions) self.token = Token.objects.create(user=self.user) self.header = {'HTTP_AUTHORIZATION': f'Token {self.token.key}'} diff --git a/netbox/utilities/testing/base.py b/netbox/utilities/testing/base.py index 04ceca1e2..76a9fac06 100644 --- a/netbox/utilities/testing/base.py +++ b/netbox/utilities/testing/base.py @@ -1,6 +1,6 @@ import json -from django.contrib.auth.models import User +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.contrib.postgres.fields import ArrayField from django.core.exceptions import FieldDoesNotExist @@ -27,7 +27,7 @@ class TestCase(_TestCase): def setUp(self): # Create the test user and assign permissions - self.user = User.objects.create_user(username='testuser') + self.user = get_user_model().objects.create_user(username='testuser') self.add_permissions(*self.user_permissions) # Initialize the test client diff --git a/netbox/utilities/testing/utils.py b/netbox/utilities/testing/utils.py index 52ccd002d..87fc3319c 100644 --- a/netbox/utilities/testing/utils.py +++ b/netbox/utilities/testing/utils.py @@ -2,7 +2,8 @@ import logging import re from contextlib import contextmanager -from django.contrib.auth.models import Permission, User +from django.contrib.auth import get_user_model +from django.contrib.auth.models import Permission from django.utils.text import slugify from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Site @@ -63,7 +64,7 @@ def create_test_user(username='testuser', permissions=None): """ Create a User with the given permissions. """ - user = User.objects.create_user(username=username) + user = get_user_model().objects.create_user(username=username) if permissions is None: permissions = () for perm_name in permissions: