diff --git a/netbox/ipam/forms/bulk_import.py b/netbox/ipam/forms/bulk_import.py index fccf98bd8..d6267bec7 100644 --- a/netbox/ipam/forms/bulk_import.py +++ b/netbox/ipam/forms/bulk_import.py @@ -559,19 +559,17 @@ class ServiceTemplateImportForm(NetBoxModelImportForm): class ServiceImportForm(NetBoxModelImportForm): - device = CSVModelChoiceField( - label=_('Device'), + parent_object_type = CSVContentTypeField( + queryset=ContentType.objects.filter(SERVICE_ASSIGNMENT_MODELS), + required=True, + label=_('Parent type (app & model)') + ) + parent = CSVModelChoiceField( + label=_('Parent'), queryset=Device.objects.all(), required=False, to_field_name='name', - help_text=_('Required if not assigned to a VM') - ) - virtual_machine = CSVModelChoiceField( - label=_('Virtual machine'), - queryset=VirtualMachine.objects.all(), - required=False, - to_field_name='name', - help_text=_('Required if not assigned to a device') + help_text=_('Parent object name') ) protocol = CSVChoiceField( label=_('Protocol'), @@ -588,15 +586,43 @@ class ServiceImportForm(NetBoxModelImportForm): class Meta: model = Service fields = ( - 'device', 'virtual_machine', 'ipaddresses', 'name', 'protocol', 'ports', 'description', 'comments', 'tags', + 'parent_object_type', 'ipaddresses', 'name', 'protocol', 'ports', 'description', 'comments', + 'tags', 'parent_object_id', ) - def clean_ipaddresses(self): - parent = self.cleaned_data.get('device') or self.cleaned_data.get('virtual_machine') - for ip_address in self.cleaned_data['ipaddresses']: + def __init__(self, data=None, *args, **kwargs): + super().__init__(data, *args, **kwargs) + + # Limit parent queryset by assigned parent object type + if data: + match data.get('parent_object_type'): + case 'dcim.device': + self.fields['parent'].queryset = Device.objects.all() + case 'ipam.fhrpgroup': + self.fields['parent'].queryset = FHRPGroup.objects.all() + case 'virtualization.virtualmachine': + self.fields['parent'].queryset = VirtualMachine.objects.all() + + def save(self, *args, **kwargs): + if (parent := self.cleaned_data.get('parent')): + self.instance.parent = parent + + return super().save(*args, **kwargs) + + def clean(self): + super().clean() + + if (parent := self.cleaned_data.get('parent')): + self.cleaned_data['parent_object_id'] = parent.pk + elif not parent and (parent_id := self.cleaned_data.get('parent_object_id')): + ct = self.cleaned_data.get('parent_object_type') + parent = ct.model_class().objects.filter(id=parent_id).first() + self.cleaned_data['parent'] = parent + + for ip_address in self.cleaned_data.get('ipaddresses', []): if not ip_address.assigned_object or getattr(ip_address.assigned_object, 'parent_object') != parent: raise forms.ValidationError( _("{ip} is not assigned to this device/VM.").format(ip=ip_address) ) - return self.cleaned_data['ipaddresses'] + return self.cleaned_data diff --git a/netbox/ipam/tests/test_views.py b/netbox/ipam/tests/test_views.py index ad9412979..8ba011d1d 100644 --- a/netbox/ipam/tests/test_views.py +++ b/netbox/ipam/tests/test_views.py @@ -5,11 +5,14 @@ from django.test import override_settings from django.urls import reverse from netaddr import IPNetwork +from core.models import ObjectType from dcim.constants import InterfaceTypeChoices from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Site, Interface from ipam.choices import * from ipam.models import * +from netbox.choices import CSVDelimiterChoices, ImportFormatChoices from tenancy.models import Tenant +from users.models import ObjectPermission from utilities.testing import ViewTestCases, create_tags @@ -1093,10 +1096,10 @@ class ServiceTestCase(ViewTestCases.PrimaryObjectViewTestCase): } cls.csv_data = ( - "device,name,protocol,ports,ipaddresses,description", - "Device 1,Service 1,tcp,1,192.0.2.1/24,First service", - "Device 1,Service 2,tcp,2,192.0.2.2/24,Second service", - "Device 1,Service 3,udp,3,,Third service", + "parent_object_type,parent,name,protocol,ports,ipaddresses,description", + "dcim.device,Device 1,Service 1,tcp,1,192.0.2.1/24,First service", + "dcim.device,Device 1,Service 2,tcp,2,192.0.2.2/24,Second service", + "dcim.device,Device 1,Service 3,udp,3,,Third service", ) cls.csv_update_data = ( @@ -1112,6 +1115,66 @@ class ServiceTestCase(ViewTestCases.PrimaryObjectViewTestCase): 'description': 'New description', } + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*'], EXEMPT_EXCLUDE_MODELS=[]) + def test_unassigned_ip_addresses(self): + device = Device.objects.first() + addr = IPAddress.objects.create(address='192.0.2.4/24') + csv_data = ( + "parent_object_type,parent_object_id,name,protocol,ports,ipaddresses,description", + f"dcim.device,{device.pk},Service 11,tcp,10,{addr.address},Eleventh service", + ) + + initial_count = self._get_queryset().count() + data = { + 'data': '\n'.join(csv_data), + 'format': ImportFormatChoices.CSV, + 'csv_delimiter': CSVDelimiterChoices.AUTO, + } + + # Assign model-level permission + obj_perm = ObjectPermission.objects.create(name='Test permission', actions=['add']) + obj_perm.users.add(self.user) + obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) + + # Test POST with permission + response = self.client.post(self._get_url('bulk_import'), data) + + self.assertHttpStatus(response, 200) + form_errors = response.context['form'].errors + self.assertEqual(len(form_errors), 1) + self.assertIn(addr.address, form_errors['__all__'][0]) + self.assertEqual(self._get_queryset().count(), initial_count) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*'], EXEMPT_EXCLUDE_MODELS=[]) + def test_alternate_csv_import(self): + device = Device.objects.first() + interface = device.interfaces.first() + addr = IPAddress.objects.create(assigned_object=interface, address='192.0.2.3/24') + csv_data = ( + "parent_object_type,parent_object_id,name,protocol,ports,ipaddresses,description", + f"dcim.device,{device.pk},Service 11,tcp,10,{addr.address},Eleventh service", + ) + + initial_count = self._get_queryset().count() + data = { + 'data': '\n'.join(csv_data), + 'format': ImportFormatChoices.CSV, + 'csv_delimiter': CSVDelimiterChoices.AUTO, + } + + # Assign model-level permission + obj_perm = ObjectPermission.objects.create(name='Test permission', actions=['add']) + obj_perm.users.add(self.user) + obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) + + # Test POST with permission + response = self.client.post(self._get_url('bulk_import'), data) + + if response.status_code != 302: + self.assertEqual(response.context['form'].errors, {}) # debugging aid + self.assertHttpStatus(response, 302) + self.assertEqual(self._get_queryset().count(), initial_count + len(csv_data) - 1) + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) def test_create_from_template(self): self.add_permissions('ipam.add_service')