Compare commits

...

3 Commits

Author SHA1 Message Date
Jeremy Stretch
653a5ea058 Merge fa70430942 into 513b11450d 2025-11-27 11:12:10 +01:00
Martin Hauser
513b11450d Closes #20834: Add support for enabling/disabling Tokens (#20864)
Some checks failed
CI / build (20.x, 3.12) (push) Has been cancelled
CI / build (20.x, 3.13) (push) Has been cancelled
CodeQL / Analyze (${{ matrix.language }}) (none, actions) (push) Has been cancelled
CodeQL / Analyze (${{ matrix.language }}) (none, javascript-typescript) (push) Has been cancelled
CodeQL / Analyze (${{ matrix.language }}) (none, python) (push) Has been cancelled
* feat(users): Add support for enabling/disabling Tokens

Introduce an `enabled` flag on the `Token` model to allow temporarily
revoking API tokens without deleting them. Update forms, serializers,
and views to expose the new field.
Enforce the `enabled` flag in token authentication.
Add model, API, and authentication tests for the new behavior.

Fixes #20834

* Fix authentication test

---------

Co-authored-by: Jeremy Stretch <jstretch@netboxlabs.com>
2025-11-26 17:15:14 -05:00
Martin Hauser
b5edfa5d53 feat(extras): Inherit ConfigContext from ancestor platforms
Apply ConfigContext to objects whose platforms descend from any
assigned platform. This aligns platform behavior with regions, site
groups, locations, and roles.

Fixes #20639
2025-11-26 16:07:50 -05:00
17 changed files with 157 additions and 31 deletions

View File

@@ -46,6 +46,10 @@ class ConfigContextQuerySet(RestrictedQuerySet):
# Match against the directly assigned role as well as any parent roles. # Match against the directly assigned role as well as any parent roles.
device_roles = obj.role.get_ancestors(include_self=True) if obj.role else [] device_roles = obj.role.get_ancestors(include_self=True) if obj.role else []
# Match against the directly assigned platform as well as any parent platforms.
platform = getattr(obj, 'platform', None)
platforms = platform.get_ancestors(include_self=True) if platform else []
queryset = self.filter( queryset = self.filter(
Q(regions__in=regions) | Q(regions=None), Q(regions__in=regions) | Q(regions=None),
Q(site_groups__in=sitegroups) | Q(site_groups=None), Q(site_groups__in=sitegroups) | Q(site_groups=None),
@@ -53,7 +57,7 @@ class ConfigContextQuerySet(RestrictedQuerySet):
Q(locations__in=locations) | Q(locations=None), Q(locations__in=locations) | Q(locations=None),
Q(device_types=device_type) | Q(device_types=None), Q(device_types=device_type) | Q(device_types=None),
Q(roles__in=device_roles) | Q(roles=None), Q(roles__in=device_roles) | Q(roles=None),
Q(platforms=obj.platform) | Q(platforms=None), Q(platforms__in=platforms) | Q(platforms=None),
Q(cluster_types=cluster_type) | Q(cluster_types=None), Q(cluster_types=cluster_type) | Q(cluster_types=None),
Q(cluster_groups=cluster_group) | Q(cluster_groups=None), Q(cluster_groups=cluster_group) | Q(cluster_groups=None),
Q(clusters=cluster) | Q(clusters=None), Q(clusters=cluster) | Q(clusters=None),
@@ -103,7 +107,6 @@ class ConfigContextModelQuerySet(RestrictedQuerySet):
"content_type__model": self.model._meta.model_name "content_type__model": self.model._meta.model_name
} }
base_query = Q( base_query = Q(
Q(platforms=OuterRef('platform')) | Q(platforms=None),
Q(cluster_types=OuterRef('cluster__type')) | Q(cluster_types=None), Q(cluster_types=OuterRef('cluster__type')) | Q(cluster_types=None),
Q(cluster_groups=OuterRef('cluster__group')) | Q(cluster_groups=None), Q(cluster_groups=OuterRef('cluster__group')) | Q(cluster_groups=None),
Q(clusters=OuterRef('cluster')) | Q(clusters=None), Q(clusters=OuterRef('cluster')) | Q(clusters=None),
@@ -167,6 +170,15 @@ class ConfigContextModelQuerySet(RestrictedQuerySet):
) | Q(roles=None)), ) | Q(roles=None)),
Q.AND Q.AND
) )
base_query.add(
(Q(
platforms__tree_id=OuterRef('platform__tree_id'),
platforms__level__lte=OuterRef('platform__level'),
platforms__lft__lte=OuterRef('platform__lft'),
platforms__rght__gte=OuterRef('platform__rght'),
) | Q(platforms=None)),
Q.AND
)
return base_query return base_query

View File

@@ -38,7 +38,7 @@ class TokenAuthentication(BaseAuthentication):
try: try:
auth_value = auth[1].decode() auth_value = auth[1].decode()
except UnicodeError: except UnicodeError:
raise exceptions.AuthenticationFailed("Invalid authorization header: Token contains invalid characters") raise exceptions.AuthenticationFailed('Invalid authorization header: Token contains invalid characters')
# Infer token version from presence or absence of prefix # Infer token version from presence or absence of prefix
version = 2 if auth_value.startswith(TOKEN_PREFIX) else 1 version = 2 if auth_value.startswith(TOKEN_PREFIX) else 1
@@ -75,17 +75,21 @@ class TokenAuthentication(BaseAuthentication):
client_ip = get_client_ip(request) client_ip = get_client_ip(request)
if client_ip is None: if client_ip is None:
raise exceptions.AuthenticationFailed( raise exceptions.AuthenticationFailed(
"Client IP address could not be determined for validation. Check that the HTTP server is " 'Client IP address could not be determined for validation. Check that the HTTP server is '
"correctly configured to pass the required header(s)." 'correctly configured to pass the required header(s).'
) )
if not token.validate_client_ip(client_ip): if not token.validate_client_ip(client_ip):
raise exceptions.AuthenticationFailed( raise exceptions.AuthenticationFailed(
f"Source IP {client_ip} is not permitted to authenticate using this token." f"Source IP {client_ip} is not permitted to authenticate using this token."
) )
# Enforce the Token is enabled
if not token.enabled:
raise exceptions.AuthenticationFailed('Token disabled')
# Enforce the Token's expiration time, if one has been set. # Enforce the Token's expiration time, if one has been set.
if token.is_expired: if token.is_expired:
raise exceptions.AuthenticationFailed("Token expired") raise exceptions.AuthenticationFailed('Token expired')
# Update last used, but only once per minute at most. This reduces write load on the database # Update last used, but only once per minute at most. This reduces write load on the database
if not token.last_used or (timezone.now() - token.last_used).total_seconds() > 60: if not token.last_used or (timezone.now() - token.last_used).total_seconds() > 60:

View File

@@ -66,6 +66,32 @@ class TokenAuthenticationTestCase(APITestCase):
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
self.assertEqual(response.data['detail'], "Invalid v2 token") self.assertEqual(response.data['detail'], "Invalid v2 token")
@override_settings(LOGIN_REQUIRED=True, EXEMPT_VIEW_PERMISSIONS=['*'])
def test_token_enabled(self):
url = reverse('dcim-api:site-list')
# Create v1 & v2 tokens
token1 = Token.objects.create(version=1, user=self.user, enabled=True)
token2 = Token.objects.create(version=2, user=self.user, enabled=True)
# Request with an enabled token should succeed
response = self.client.get(url, HTTP_AUTHORIZATION=f'Token {token1.token}')
self.assertEqual(response.status_code, 200)
response = self.client.get(url, HTTP_AUTHORIZATION=f'Bearer {TOKEN_PREFIX}{token2.key}.{token2.token}')
self.assertEqual(response.status_code, 200)
# Request with a disabled token should fail
token1.enabled = False
token1.save()
token2.enabled = False
token2.save()
response = self.client.get(url, HTTP_AUTHORIZATION=f'Token {token1.token}')
self.assertEqual(response.status_code, 403)
self.assertEqual(response.data['detail'], 'Token disabled')
response = self.client.get(url, HTTP_AUTHORIZATION=f'Bearer {TOKEN_PREFIX}{token2.key}.{token2.token}')
self.assertEqual(response.status_code, 403)
self.assertEqual(response.data['detail'], 'Token disabled')
@override_settings(LOGIN_REQUIRED=True, EXEMPT_VIEW_PERMISSIONS=['*']) @override_settings(LOGIN_REQUIRED=True, EXEMPT_VIEW_PERMISSIONS=['*'])
def test_token_expiration(self): def test_token_expiration(self):
url = reverse('dcim-api:site-list') url = reverse('dcim-api:site-list')

View File

@@ -42,6 +42,10 @@
<th scope="row">{% trans "Description" %}</th> <th scope="row">{% trans "Description" %}</th>
<td>{{ object.description|placeholder }}</td> <td>{{ object.description|placeholder }}</td>
</tr> </tr>
<tr>
<th scope="row">{% trans "Enabled" %}</th>
<td>{% checkmark object.enabled %}</td>
</tr>
<tr> <tr>
<th scope="row">{% trans "Write enabled" %}</th> <th scope="row">{% trans "Write enabled" %}</th>
<td>{% checkmark object.write_enabled %}</td> <td>{% checkmark object.write_enabled %}</td>

View File

@@ -32,10 +32,10 @@ class TokenSerializer(ValidatedModelSerializer):
model = Token model = Token
fields = ( fields = (
'id', 'url', 'display_url', 'display', 'version', 'key', 'user', 'description', 'created', 'expires', 'id', 'url', 'display_url', 'display', 'version', 'key', 'user', 'description', 'created', 'expires',
'last_used', 'write_enabled', 'pepper_id', 'allowed_ips', 'token', 'last_used', 'enabled', 'write_enabled', 'pepper_id', 'allowed_ips', 'token',
) )
read_only_fields = ('key',) read_only_fields = ('key',)
brief_fields = ('id', 'url', 'display', 'version', 'key', 'write_enabled', 'description') brief_fields = ('id', 'url', 'display', 'version', 'key', 'enabled', 'write_enabled', 'description')
def get_fields(self): def get_fields(self):
fields = super().get_fields() fields = super().get_fields()
@@ -79,7 +79,7 @@ class TokenProvisionSerializer(TokenSerializer):
model = Token model = Token
fields = ( fields = (
'id', 'url', 'display_url', 'display', 'version', 'user', 'key', 'created', 'expires', 'last_used', 'key', 'id', 'url', 'display_url', 'display', 'version', 'user', 'key', 'created', 'expires', 'last_used', 'key',
'write_enabled', 'description', 'allowed_ips', 'username', 'password', 'token', 'enabled', 'write_enabled', 'description', 'allowed_ips', 'username', 'password', 'token',
) )
def validate(self, data): def validate(self, data):

View File

@@ -167,7 +167,8 @@ class TokenFilterSet(BaseFilterSet):
class Meta: class Meta:
model = Token model = Token
fields = ( fields = (
'id', 'version', 'key', 'pepper_id', 'write_enabled', 'description', 'created', 'expires', 'last_used', 'id', 'version', 'key', 'pepper_id', 'enabled', 'write_enabled',
'description', 'created', 'expires', 'last_used',
) )
def search(self, queryset, name, value): def search(self, queryset, name, value):

View File

@@ -99,6 +99,11 @@ class TokenBulkEditForm(BulkEditForm):
queryset=Token.objects.all(), queryset=Token.objects.all(),
widget=forms.MultipleHiddenInput widget=forms.MultipleHiddenInput
) )
enabled = forms.NullBooleanField(
required=False,
widget=BulkEditNullBooleanSelect,
label=_('Enabled')
)
write_enabled = forms.NullBooleanField( write_enabled = forms.NullBooleanField(
required=False, required=False,
widget=BulkEditNullBooleanSelect, widget=BulkEditNullBooleanSelect,
@@ -122,7 +127,7 @@ class TokenBulkEditForm(BulkEditForm):
model = Token model = Token
fieldsets = ( fieldsets = (
FieldSet('write_enabled', 'description', 'expires', 'allowed_ips'), FieldSet('enabled', 'write_enabled', 'description', 'expires', 'allowed_ips'),
) )
nullable_fields = ( nullable_fields = (
'expires', 'description', 'allowed_ips', 'expires', 'description', 'allowed_ips',

View File

@@ -52,7 +52,7 @@ class TokenImportForm(CSVModelForm):
class Meta: class Meta:
model = Token model = Token
fields = ('user', 'version', 'token', 'write_enabled', 'expires', 'description',) fields = ('user', 'version', 'token', 'enabled', 'write_enabled', 'expires', 'description',)
class OwnerGroupImportForm(CSVModelForm): class OwnerGroupImportForm(CSVModelForm):

View File

@@ -114,7 +114,7 @@ class TokenFilterForm(SavedFiltersMixin, FilterForm):
model = Token model = Token
fieldsets = ( fieldsets = (
FieldSet('q', 'filter_id',), FieldSet('q', 'filter_id',),
FieldSet('version', 'user_id', 'write_enabled', 'expires', 'last_used', name=_('Token')), FieldSet('version', 'user_id', 'enabled', 'write_enabled', 'expires', 'last_used', name=_('Token')),
) )
version = forms.ChoiceField( version = forms.ChoiceField(
choices=add_blank_choice(TokenVersionChoices), choices=add_blank_choice(TokenVersionChoices),
@@ -125,6 +125,13 @@ class TokenFilterForm(SavedFiltersMixin, FilterForm):
required=False, required=False,
label=_('User') label=_('User')
) )
enabled = forms.NullBooleanField(
required=False,
widget=forms.Select(
choices=BOOLEAN_WITH_BLANK_CHOICES
),
label=_('Enabled'),
)
write_enabled = forms.NullBooleanField( write_enabled = forms.NullBooleanField(
required=False, required=False,
widget=forms.Select( widget=forms.Select(

View File

@@ -140,7 +140,7 @@ class UserTokenForm(forms.ModelForm):
class Meta: class Meta:
model = Token model = Token
fields = [ fields = [
'version', 'token', 'write_enabled', 'expires', 'description', 'allowed_ips', 'version', 'token', 'enabled', 'write_enabled', 'expires', 'description', 'allowed_ips',
] ]
widgets = { widgets = {
'expires': DateTimePicker(), 'expires': DateTimePicker(),
@@ -177,7 +177,7 @@ class TokenForm(UserTokenForm):
class Meta(UserTokenForm.Meta): class Meta(UserTokenForm.Meta):
fields = [ fields = [
'version', 'token', 'user', 'write_enabled', 'expires', 'description', 'allowed_ips', 'version', 'token', 'user', 'enabled', 'write_enabled', 'expires', 'description', 'allowed_ips',
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):

View File

@@ -9,6 +9,13 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
# Add a new field to enable/disable tokens
migrations.AddField(
model_name='token',
name='enabled',
field=models.BooleanField(default=True),
),
# Rename the original key field to "plaintext" # Rename the original key field to "plaintext"
migrations.RenameField( migrations.RenameField(
model_name='token', model_name='token',
@@ -35,7 +42,7 @@ class Migration(migrations.Migration):
), ),
), ),
# Add version field to distinguish v1 and v2 tokens # Add a version field to distinguish v1 and v2 tokens
migrations.AddField( migrations.AddField(
model_name='token', model_name='token',
name='version', name='version',

View File

@@ -61,6 +61,11 @@ class Token(models.Model):
blank=True, blank=True,
null=True null=True
) )
enabled = models.BooleanField(
verbose_name=_('enabled'),
default=True,
help_text=_('Disable to temporarily revoke this token without deleting it.'),
)
write_enabled = models.BooleanField( write_enabled = models.BooleanField(
verbose_name=_('write enabled'), verbose_name=_('write enabled'),
default=True, default=True,
@@ -180,6 +185,22 @@ class Token(models.Model):
self.key = self.key or self.generate_key() self.key = self.key or self.generate_key()
self.update_digest() self.update_digest()
@property
def is_expired(self):
"""
Check whether the token has expired.
"""
if self.expires is None or timezone.now() < self.expires:
return False
return True
@property
def is_active(self):
"""
Check whether the token is active (enabled and not expired).
"""
return self.enabled and not self.is_expired
def clean(self): def clean(self):
super().clean() super().clean()
@@ -236,12 +257,6 @@ class Token(models.Model):
hashlib.sha256 hashlib.sha256
).hexdigest() ).hexdigest()
@property
def is_expired(self):
if self.expires is None or timezone.now() < self.expires:
return False
return True
def validate(self, token): def validate(self, token):
""" """
Validate the given plaintext against the token. Validate the given plaintext against the token.

View File

@@ -25,6 +25,9 @@ class TokenTable(NetBoxTable):
verbose_name=_('token'), verbose_name=_('token'),
template_code=TOKEN, template_code=TOKEN,
) )
enabled = columns.BooleanColumn(
verbose_name=_('Enabled')
)
write_enabled = columns.BooleanColumn( write_enabled = columns.BooleanColumn(
verbose_name=_('Write Enabled') verbose_name=_('Write Enabled')
) )
@@ -49,10 +52,10 @@ class TokenTable(NetBoxTable):
class Meta(NetBoxTable.Meta): class Meta(NetBoxTable.Meta):
model = Token model = Token
fields = ( fields = (
'pk', 'id', 'token', 'version', 'pepper_id', 'user', 'description', 'write_enabled', 'created', 'expires', 'pk', 'id', 'token', 'version', 'pepper_id', 'user', 'description', 'enabled', 'write_enabled', 'created',
'last_used', 'allowed_ips', 'expires', 'last_used', 'allowed_ips',
) )
default_columns = ('token', 'version', 'user', 'write_enabled', 'description', 'allowed_ips') default_columns = ('token', 'version', 'user', 'enabled', 'write_enabled', 'description', 'allowed_ips')
class UserTable(NetBoxTable): class UserTable(NetBoxTable):

View File

@@ -195,10 +195,10 @@ class TokenTest(
APIViewTestCases.ListObjectsViewTestCase, APIViewTestCases.ListObjectsViewTestCase,
APIViewTestCases.CreateObjectViewTestCase, APIViewTestCases.CreateObjectViewTestCase,
APIViewTestCases.UpdateObjectViewTestCase, APIViewTestCases.UpdateObjectViewTestCase,
APIViewTestCases.DeleteObjectViewTestCase APIViewTestCases.DeleteObjectViewTestCase,
): ):
model = Token model = Token
brief_fields = ['description', 'display', 'id', 'key', 'url', 'version', 'write_enabled'] brief_fields = ['description', 'display', 'enabled', 'id', 'key', 'url', 'version', 'write_enabled']
bulk_update_data = { bulk_update_data = {
'description': 'New description', 'description': 'New description',
} }
@@ -229,12 +229,16 @@ class TokenTest(
cls.create_data = [ cls.create_data = [
{ {
'user': users[0].pk, 'user': users[0].pk,
'enabled': True,
}, },
{ {
'user': users[1].pk, 'user': users[1].pk,
'enabled': False,
}, },
{ {
'user': users[2].pk, 'user': users[2].pk,
'enabled': True,
'write_enabled': False,
}, },
] ]
@@ -267,6 +271,8 @@ class TokenTest(
self.assertEqual(response.data['expires'], data['expires']) self.assertEqual(response.data['expires'], data['expires'])
token = Token.objects.get(user=user) token = Token.objects.get(user=user)
self.assertEqual(token.key, response.data['key']) self.assertEqual(token.key, response.data['key'])
self.assertEqual(token.enabled, response.data['enabled'])
self.assertEqual(token.write_enabled, response.data['write_enabled'])
def test_provision_token_invalid(self): def test_provision_token_invalid(self):
""" """

View File

@@ -285,6 +285,7 @@ class TokenTestCase(TestCase, BaseFilterSetTests):
version=1, version=1,
user=users[0], user=users[0],
expires=future_date, expires=future_date,
enabled=True,
write_enabled=True, write_enabled=True,
description='foobar1', description='foobar1',
), ),
@@ -292,12 +293,14 @@ class TokenTestCase(TestCase, BaseFilterSetTests):
version=2, version=2,
user=users[1], user=users[1],
expires=future_date, expires=future_date,
enabled=False,
write_enabled=True, write_enabled=True,
description='foobar2', description='foobar2',
), ),
Token( Token(
version=2, version=2,
user=users[2], user=users[2],
enabled=True,
expires=past_date, expires=past_date,
write_enabled=False, write_enabled=False,
), ),
@@ -339,6 +342,12 @@ class TokenTestCase(TestCase, BaseFilterSetTests):
params = {'expires__lte': '2021-01-01T00:00:00'} params = {'expires__lte': '2021-01-01T00:00:00'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
def test_enabled(self):
params = {'enabled': True}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'enabled': False}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
def test_write_enabled(self): def test_write_enabled(self):
params = {'write_enabled': True} params = {'write_enabled': True}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)

View File

@@ -20,6 +20,32 @@ class TokenTest(TestCase):
""" """
cls.user = create_test_user('User 1') cls.user = create_test_user('User 1')
def test_is_active(self):
"""
Test the is_active property.
"""
# Token with enabled status and no expiration date
token = Token(user=self.user, enabled=True, expires=None)
self.assertTrue(token.is_active)
# Token with disabled status
token.enabled = False
self.assertFalse(token.is_active)
# Token with enabled status and future expiration
future_date = timezone.now() + timedelta(days=1)
token = Token(user=self.user, enabled=True, expires=future_date)
self.assertTrue(token.is_active)
# Token with past expiration
token.expires = timezone.now() - timedelta(days=1)
self.assertFalse(token.is_active)
# Token with disabled status and past expiration
past_date = timezone.now() - timedelta(days=1)
token = Token(user=self.user, enabled=False, expires=past_date)
self.assertFalse(token.is_active)
def test_is_expired(self): def test_is_expired(self):
""" """
Test the is_expired property. Test the is_expired property.

View File

@@ -236,13 +236,14 @@ class TokenTestCase(
'token': '4F9DAouzURLbicyoG55htImgqQ0b4UZHP5LUYgl5', 'token': '4F9DAouzURLbicyoG55htImgqQ0b4UZHP5LUYgl5',
'user': users[0].pk, 'user': users[0].pk,
'description': 'Test token', 'description': 'Test token',
'enabled': True,
} }
cls.csv_data = ( cls.csv_data = (
"token,user,description", "token,user,description,enabled,write_enabled",
f"zjebxBPzICiPbWz0Wtx0fTL7bCKXKGTYhNzkgC2S,{users[0].pk},Test token", f"zjebxBPzICiPbWz0Wtx0fTL7bCKXKGTYhNzkgC2S,{users[0].pk},Test token,true,true",
f"9Z5kGtQWba60Vm226dPDfEAV6BhlTr7H5hAXAfbF,{users[1].pk},Test token", f"9Z5kGtQWba60Vm226dPDfEAV6BhlTr7H5hAXAfbF,{users[1].pk},Test token,true,false",
f"njpMnNT6r0k0MDccoUhTYYlvP9BvV3qLzYN2p6Uu,{users[1].pk},Test token", f"njpMnNT6r0k0MDccoUhTYYlvP9BvV3qLzYN2p6Uu,{users[1].pk},Test token,false,true",
) )
cls.csv_update_data = ( cls.csv_update_data = (