diff --git a/netbox/circuits/tests/test_views.py b/netbox/circuits/tests/test_views.py index afb16ddc7..54d001c8d 100644 --- a/netbox/circuits/tests/test_views.py +++ b/netbox/circuits/tests/test_views.py @@ -51,10 +51,10 @@ class ProviderTestCase(ViewTestCases.PrimaryObjectViewTestCase): ) cls.csv_update_data = ( - "name,comments", - "Provider 7,New comment7", - "Provider 8,New comment8", - "Provider 9,New comment9", + "id,name,comments", + f"{providers[0].pk},Provider 7,New comment7", + f"{providers[1].pk},Provider 8,New comment8", + f"{providers[2].pk},Provider 9,New comment9", ) cls.bulk_edit_data = { @@ -69,11 +69,13 @@ class CircuitTypeTestCase(ViewTestCases.OrganizationalObjectViewTestCase): @classmethod def setUpTestData(cls): - CircuitType.objects.bulk_create([ + circuit_types = ( CircuitType(name='Circuit Type 1', slug='circuit-type-1'), CircuitType(name='Circuit Type 2', slug='circuit-type-2'), CircuitType(name='Circuit Type 3', slug='circuit-type-3'), - ]) + ) + + CircuitType.objects.bulk_create(circuit_types) tags = create_tags('Alpha', 'Bravo', 'Charlie') @@ -92,10 +94,10 @@ class CircuitTypeTestCase(ViewTestCases.OrganizationalObjectViewTestCase): ) cls.csv_update_data = ( - "name,description", - "Circuit Type 7,New description7", - "Circuit Type 8,New description8", - "Circuit Type 9,New description9", + "id,name,description", + f"{circuit_types[0].pk},Circuit Type 7,New description7", + f"{circuit_types[1].pk},Circuit Type 8,New description8", + f"{circuit_types[2].pk},Circuit Type 9,New description9", ) cls.bulk_edit_data = { @@ -121,11 +123,13 @@ class CircuitTestCase(ViewTestCases.PrimaryObjectViewTestCase): ) CircuitType.objects.bulk_create(circuittypes) - Circuit.objects.bulk_create([ + circuits = ( Circuit(cid='Circuit 1', provider=providers[0], type=circuittypes[0]), Circuit(cid='Circuit 2', provider=providers[0], type=circuittypes[0]), Circuit(cid='Circuit 3', provider=providers[0], type=circuittypes[0]), - ]) + ) + + Circuit.objects.bulk_create(circuits) tags = create_tags('Alpha', 'Bravo', 'Charlie') @@ -151,10 +155,10 @@ class CircuitTestCase(ViewTestCases.PrimaryObjectViewTestCase): ) cls.csv_update_data = ( - f"cid,description,status", - f"Circuit 7,New description7,{CircuitStatusChoices.STATUS_DECOMMISSIONED}", - f"Circuit 8,New description8,{CircuitStatusChoices.STATUS_DECOMMISSIONED}", - f"Circuit 9,New description9,{CircuitStatusChoices.STATUS_DECOMMISSIONED}", + f"id,cid,description,status", + f"{circuits[0].pk},Circuit 7,New description7,{CircuitStatusChoices.STATUS_DECOMMISSIONED}", + f"{circuits[1].pk},Circuit 8,New description8,{CircuitStatusChoices.STATUS_DECOMMISSIONED}", + f"{circuits[2].pk},Circuit 9,New description9,{CircuitStatusChoices.STATUS_DECOMMISSIONED}", ) cls.bulk_edit_data = { @@ -180,11 +184,13 @@ class ProviderNetworkTestCase(ViewTestCases.PrimaryObjectViewTestCase): ) Provider.objects.bulk_create(providers) - ProviderNetwork.objects.bulk_create([ + provider_networks = ( ProviderNetwork(name='Provider Network 1', provider=providers[0]), ProviderNetwork(name='Provider Network 2', provider=providers[0]), ProviderNetwork(name='Provider Network 3', provider=providers[0]), - ]) + ) + + ProviderNetwork.objects.bulk_create(provider_networks) tags = create_tags('Alpha', 'Bravo', 'Charlie') @@ -204,10 +210,10 @@ class ProviderNetworkTestCase(ViewTestCases.PrimaryObjectViewTestCase): ) cls.csv_update_data = ( - "name,description", - "Provider Network 7,New description7", - "Provider Network 8,New description8", - "Provider Network 9,New description9", + "id,name,description", + f"{provider_networks[0].pk},Provider Network 7,New description7", + f"{provider_networks[1].pk},Provider Network 8,New description8", + f"{provider_networks[2].pk},Provider Network 9,New description9", ) cls.bulk_edit_data = { diff --git a/netbox/utilities/testing/views.py b/netbox/utilities/testing/views.py index 126a38087..2309ee1d2 100644 --- a/netbox/utilities/testing/views.py +++ b/netbox/utilities/testing/views.py @@ -550,16 +550,8 @@ class ViewTestCases: def _get_csv_data(self): return '\n'.join(self.csv_data) - def _get_update_csv_data(self, start): - # pre-pend id into data - csv_data = [] - for idx, line in enumerate(self.csv_update_data, start=start): - if idx == start: - csv_data.append("id," + line) - else: - csv_data.append(f"{idx-1}," + line) - - return csv_data, '\n'.join(csv_data) + def _get_update_csv_data(self): + return self.csv_update_data, '\n'.join(self.csv_update_data) def test_bulk_import_objects_without_permission(self): data = { @@ -601,7 +593,7 @@ class ViewTestCases: @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) def test_bulk_update_objects_with_permission(self): if not self.csv_update_data: - return + raise NotImplementedError("The test must define csv_update_data.") data = { 'csv': self._get_csv_data(), @@ -616,17 +608,11 @@ class ViewTestCases: obj_perm.users.add(self.user) obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) - # need to track ids so we know what new ids were added by csv_data so we can - # do the updates on the appropriate ids - prev_ids = list(self._get_queryset().values_list('id', flat=True).order_by('id')) self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200) count = self._get_queryset().count() - new_ids = list(self._get_queryset().values_list('id', flat=True).order_by('id')) - diff_ids = [x for x in new_ids if x not in prev_ids] - start_id = diff_ids[0] # Now try update the data - array, csv_data = self._get_update_csv_data(start_id) + array, csv_data = self._get_update_csv_data() data = { 'csv': csv_data, }