diff --git a/netbox/account/views.py b/netbox/account/views.py index 4fb3de840..c8cd0fe66 100644 --- a/netbox/account/views.py +++ b/netbox/account/views.py @@ -21,7 +21,7 @@ from social_core.backends.utils import load_backends from account.models import UserToken from core.models import ObjectChange from core.tables import ObjectChangeTable -from extras.models import Bookmark, Notification, Subscription +from extras.models import Bookmark from extras.tables import BookmarkTable, NotificationTable, SubscriptionTable from netbox.authentication import get_auth_backend_display, get_saml_idps from netbox.config import get_config diff --git a/netbox/extras/filtersets.py b/netbox/extras/filtersets.py index 505c342b4..f34270f07 100644 --- a/netbox/extras/filtersets.py +++ b/netbox/extras/filtersets.py @@ -338,7 +338,7 @@ class BookmarkFilterSet(BaseFilterSet): fields = ('id', 'object_id') -class NotificationGroupFilterSet(BaseFilterSet): +class NotificationGroupFilterSet(ChangeLoggedModelFilterSet): q = django_filters.CharFilter( method='search', label=_('Search'), @@ -348,11 +348,23 @@ class NotificationGroupFilterSet(BaseFilterSet): queryset=User.objects.all(), label=_('User (ID)'), ) + user = django_filters.ModelMultipleChoiceFilter( + field_name='users__username', + queryset=User.objects.all(), + to_field_name='username', + label=_('User (name)'), + ) group_id = django_filters.ModelMultipleChoiceFilter( field_name='groups', queryset=Group.objects.all(), label=_('Group (ID)'), ) + group = django_filters.ModelMultipleChoiceFilter( + field_name='groups__name', + queryset=Group.objects.all(), + to_field_name='name', + label=_('Group (name)'), + ) class Meta: model = NotificationGroup diff --git a/netbox/extras/models/notifications.py b/netbox/extras/models/notifications.py index 3280d0fa9..d3fa77ec1 100644 --- a/netbox/extras/models/notifications.py +++ b/netbox/extras/models/notifications.py @@ -206,6 +206,11 @@ class Subscription(models.Model): verbose_name = _('subscription') verbose_name_plural = _('subscriptions') + def __str__(self): + if self.object: + return str(self.object) + return super().__str__() + def get_absolute_url(self): return reverse('account:subscriptions') diff --git a/netbox/extras/tests/test_filtersets.py b/netbox/extras/tests/test_filtersets.py index 5c737f7cf..bf34f96b8 100644 --- a/netbox/extras/tests/test_filtersets.py +++ b/netbox/extras/tests/test_filtersets.py @@ -1,7 +1,6 @@ import uuid from datetime import datetime, timezone -from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.test import TestCase @@ -15,13 +14,11 @@ from extras.choices import * from extras.filtersets import * from extras.models import * from tenancy.models import Tenant, TenantGroup +from users.models import Group, User from utilities.testing import BaseFilterSetTests, ChangeLoggedFilterSetTests, create_tags from virtualization.models import Cluster, ClusterGroup, ClusterType -User = get_user_model() - - class CustomFieldTestCase(TestCase, ChangeLoggedFilterSetTests): queryset = CustomField.objects.all() filterset = CustomFieldFilterSet @@ -1370,3 +1367,65 @@ class ChangeLoggedFilterSetTestCase(TestCase): params = {'modified_by_request': self.create_update_request_id} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) self.assertEqual(self.queryset.count(), 4) + + +class NotificationGroupTestCase(TestCase, BaseFilterSetTests): + queryset = NotificationGroup.objects.all() + filterset = NotificationGroupFilterSet + + @classmethod + def setUpTestData(cls): + users = ( + User(username='User 1'), + User(username='User 2'), + User(username='User 3'), + ) + User.objects.bulk_create(users) + + groups = ( + Group(name='Group 1'), + Group(name='Group 2'), + Group(name='Group 3'), + ) + Group.objects.bulk_create(groups) + + sites = ( + Site(name='Site 1', slug='site-1'), + Site(name='Site 2', slug='site-2'), + Site(name='Site 3', slug='site-3'), + ) + Site.objects.bulk_create(sites) + + tenants = ( + Tenant(name='Tenant 1', slug='tenant-1'), + Tenant(name='Tenant 2', slug='tenant-2'), + Tenant(name='Tenant 3', slug='tenant-3'), + ) + Tenant.objects.bulk_create(tenants) + + notification_groups = ( + NotificationGroup(name='Notification Group 1'), + NotificationGroup(name='Notification Group 2'), + NotificationGroup(name='Notification Group 3'), + ) + NotificationGroup.objects.bulk_create(notification_groups) + notification_groups[0].users.add(users[0]) + notification_groups[1].users.add(users[1]) + notification_groups[2].users.add(users[2]) + notification_groups[0].groups.add(groups[0]) + notification_groups[1].groups.add(groups[1]) + notification_groups[2].groups.add(groups[2]) + + def test_user(self): + users = User.objects.filter(username__startswith='User') + params = {'user': [users[0].username, users[1].username]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + params = {'user_id': [users[0].pk, users[1].pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + def test_group(self): + groups = Group.objects.all() + params = {'group': [groups[0].name, groups[1].name]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + params = {'group_id': [groups[0].pk, groups[1].pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) diff --git a/netbox/extras/tests/test_views.py b/netbox/extras/tests/test_views.py index cbede195b..552c0f57a 100644 --- a/netbox/extras/tests/test_views.py +++ b/netbox/extras/tests/test_views.py @@ -1,4 +1,3 @@ -from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.urls import reverse @@ -6,10 +5,9 @@ from core.models import ObjectType from dcim.models import DeviceType, Manufacturer, Site from extras.choices import * from extras.models import * +from users.models import Group, User from utilities.testing import ViewTestCases, TestCase -User = get_user_model() - class CustomFieldTestCase(ViewTestCases.PrimaryObjectViewTestCase): model = CustomField @@ -620,3 +618,166 @@ class CustomLinkTest(TestCase): response = self.client.get(site.get_absolute_url(), follow=True) self.assertEqual(response.status_code, 200) self.assertIn(f'FOO {site.name} BAR', str(response.content)) + + +class SubscriptionTestCase( + ViewTestCases.CreateObjectViewTestCase, + ViewTestCases.DeleteObjectViewTestCase, + ViewTestCases.ListObjectsViewTestCase, + ViewTestCases.BulkDeleteObjectsViewTestCase +): + model = Subscription + + @classmethod + def setUpTestData(cls): + site_ct = ContentType.objects.get_for_model(Site) + sites = ( + Site(name='Site 1', slug='site-1'), + Site(name='Site 2', slug='site-2'), + Site(name='Site 3', slug='site-3'), + Site(name='Site 4', slug='site-4'), + ) + Site.objects.bulk_create(sites) + + cls.form_data = { + 'object_type': site_ct.pk, + 'object_id': sites[3].pk, + } + + def setUp(self): + super().setUp() + + sites = Site.objects.all() + user = self.user + + subscriptions = ( + Subscription(object=sites[0], user=user), + Subscription(object=sites[1], user=user), + Subscription(object=sites[2], user=user), + ) + Subscription.objects.bulk_create(subscriptions) + + def _get_url(self, action, instance=None): + if action == 'list': + return reverse('account:subscriptions') + return super()._get_url(action, instance) + + def test_list_objects_anonymous(self): + self.client.logout() + url = reverse('account:subscriptions') + login_url = reverse('login') + self.assertRedirects(self.client.get(url), f'{login_url}?next={url}') + + def test_list_objects_with_permission(self): + return + + def test_list_objects_with_constrained_permission(self): + return + + +class NotificationGroupTestCase(ViewTestCases.PrimaryObjectViewTestCase): + model = NotificationGroup + + @classmethod + def setUpTestData(cls): + users = ( + User(username='User 1'), + User(username='User 2'), + User(username='User 3'), + ) + User.objects.bulk_create(users) + groups = ( + Group(name='Group 1'), + Group(name='Group 2'), + Group(name='Group 3'), + ) + Group.objects.bulk_create(groups) + + notification_groups = ( + NotificationGroup(name='Notification Group 1'), + NotificationGroup(name='Notification Group 2'), + NotificationGroup(name='Notification Group 3'), + ) + NotificationGroup.objects.bulk_create(notification_groups) + for i, notification_group in enumerate(notification_groups): + notification_group.users.add(users[i]) + notification_group.groups.add(groups[i]) + + cls.form_data = { + 'name': 'Notification Group X', + 'description': 'Blah', + 'users': [users[0].pk, users[1].pk], + 'groups': [groups[0].pk, groups[1].pk], + } + + cls.csv_data = ( + 'name,description,users,groups', + 'Notification Group 4,Foo,"User 1,User 2","Group 1,Group 2"', + 'Notification Group 5,Bar,"User 1,User 2","Group 1,Group 2"', + 'Notification Group 6,Baz,"User 1,User 2","Group 1,Group 2"', + ) + + cls.csv_update_data = ( + "id,name", + f"{notification_groups[0].pk},Notification Group 7", + f"{notification_groups[1].pk},Notification Group 8", + f"{notification_groups[2].pk},Notification Group 9", + ) + + cls.bulk_edit_data = { + 'description': 'New description', + } + + +class NotificationTestCase( + ViewTestCases.DeleteObjectViewTestCase, + ViewTestCases.ListObjectsViewTestCase, + ViewTestCases.BulkDeleteObjectsViewTestCase +): + model = Notification + + @classmethod + def setUpTestData(cls): + site_ct = ContentType.objects.get_for_model(Site) + sites = ( + Site(name='Site 1', slug='site-1'), + Site(name='Site 2', slug='site-2'), + Site(name='Site 3', slug='site-3'), + Site(name='Site 4', slug='site-4'), + ) + Site.objects.bulk_create(sites) + + cls.form_data = { + 'object_type': site_ct.pk, + 'object_id': sites[3].pk, + } + + def setUp(self): + super().setUp() + + sites = Site.objects.all() + user = self.user + + notifications = ( + Notification(object=sites[0], user=user), + Notification(object=sites[1], user=user), + Notification(object=sites[2], user=user), + ) + Notification.objects.bulk_create(notifications) + + def _get_url(self, action, instance=None): + if action == 'list': + return reverse('account:notifications') + return super()._get_url(action, instance) + + def test_list_objects_anonymous(self): + self.client.logout() + url = reverse('account:notifications') + login_url = reverse('login') + self.assertRedirects(self.client.get(url), f'{login_url}?next={url}') + + def test_list_objects_with_permission(self): + return + + def test_list_objects_with_constrained_permission(self): + return