Fixes up ServiceImportForm to work with new parent field

This commit is contained in:
Jason Novinger 2025-04-08 12:53:21 -05:00
parent 0fdadd0637
commit f53e3103ff
2 changed files with 108 additions and 19 deletions

View File

@ -559,19 +559,17 @@ class ServiceTemplateImportForm(NetBoxModelImportForm):
class ServiceImportForm(NetBoxModelImportForm):
device = CSVModelChoiceField(
label=_('Device'),
parent_object_type = CSVContentTypeField(
queryset=ContentType.objects.filter(SERVICE_ASSIGNMENT_MODELS),
required=True,
label=_('Parent type (app & model)')
)
parent = CSVModelChoiceField(
label=_('Parent'),
queryset=Device.objects.all(),
required=False,
to_field_name='name',
help_text=_('Required if not assigned to a VM')
)
virtual_machine = CSVModelChoiceField(
label=_('Virtual machine'),
queryset=VirtualMachine.objects.all(),
required=False,
to_field_name='name',
help_text=_('Required if not assigned to a device')
help_text=_('Parent object name')
)
protocol = CSVChoiceField(
label=_('Protocol'),
@ -588,15 +586,43 @@ class ServiceImportForm(NetBoxModelImportForm):
class Meta:
model = Service
fields = (
'device', 'virtual_machine', 'ipaddresses', 'name', 'protocol', 'ports', 'description', 'comments', 'tags',
'parent_object_type', 'ipaddresses', 'name', 'protocol', 'ports', 'description', 'comments',
'tags', 'parent_object_id',
)
def clean_ipaddresses(self):
parent = self.cleaned_data.get('device') or self.cleaned_data.get('virtual_machine')
for ip_address in self.cleaned_data['ipaddresses']:
def __init__(self, data=None, *args, **kwargs):
super().__init__(data, *args, **kwargs)
# Limit parent queryset by assigned parent object type
if data:
match data.get('parent_object_type'):
case 'dcim.device':
self.fields['parent'].queryset = Device.objects.all()
case 'ipam.fhrpgroup':
self.fields['parent'].queryset = FHRPGroup.objects.all()
case 'virtualization.virtualmachine':
self.fields['parent'].queryset = VirtualMachine.objects.all()
def save(self, *args, **kwargs):
if (parent := self.cleaned_data.get('parent')):
self.instance.parent = parent
return super().save(*args, **kwargs)
def clean(self):
super().clean()
if (parent := self.cleaned_data.get('parent')):
self.cleaned_data['parent_object_id'] = parent.pk
elif not parent and (parent_id := self.cleaned_data.get('parent_object_id')):
ct = self.cleaned_data.get('parent_object_type')
parent = ct.model_class().objects.filter(id=parent_id).first()
self.cleaned_data['parent'] = parent
for ip_address in self.cleaned_data.get('ipaddresses', []):
if not ip_address.assigned_object or getattr(ip_address.assigned_object, 'parent_object') != parent:
raise forms.ValidationError(
_("{ip} is not assigned to this device/VM.").format(ip=ip_address)
)
return self.cleaned_data['ipaddresses']
return self.cleaned_data

View File

@ -5,11 +5,14 @@ from django.test import override_settings
from django.urls import reverse
from netaddr import IPNetwork
from core.models import ObjectType
from dcim.constants import InterfaceTypeChoices
from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Site, Interface
from ipam.choices import *
from ipam.models import *
from netbox.choices import CSVDelimiterChoices, ImportFormatChoices
from tenancy.models import Tenant
from users.models import ObjectPermission
from utilities.testing import ViewTestCases, create_tags
@ -1093,10 +1096,10 @@ class ServiceTestCase(ViewTestCases.PrimaryObjectViewTestCase):
}
cls.csv_data = (
"device,name,protocol,ports,ipaddresses,description",
"Device 1,Service 1,tcp,1,192.0.2.1/24,First service",
"Device 1,Service 2,tcp,2,192.0.2.2/24,Second service",
"Device 1,Service 3,udp,3,,Third service",
"parent_object_type,parent,name,protocol,ports,ipaddresses,description",
"dcim.device,Device 1,Service 1,tcp,1,192.0.2.1/24,First service",
"dcim.device,Device 1,Service 2,tcp,2,192.0.2.2/24,Second service",
"dcim.device,Device 1,Service 3,udp,3,,Third service",
)
cls.csv_update_data = (
@ -1112,6 +1115,66 @@ class ServiceTestCase(ViewTestCases.PrimaryObjectViewTestCase):
'description': 'New description',
}
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'], EXEMPT_EXCLUDE_MODELS=[])
def test_unassigned_ip_addresses(self):
device = Device.objects.first()
addr = IPAddress.objects.create(address='192.0.2.4/24')
csv_data = (
"parent_object_type,parent_object_id,name,protocol,ports,ipaddresses,description",
f"dcim.device,{device.pk},Service 11,tcp,10,{addr.address},Eleventh service",
)
initial_count = self._get_queryset().count()
data = {
'data': '\n'.join(csv_data),
'format': ImportFormatChoices.CSV,
'csv_delimiter': CSVDelimiterChoices.AUTO,
}
# Assign model-level permission
obj_perm = ObjectPermission.objects.create(name='Test permission', actions=['add'])
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
# Test POST with permission
response = self.client.post(self._get_url('bulk_import'), data)
self.assertHttpStatus(response, 200)
form_errors = response.context['form'].errors
self.assertEqual(len(form_errors), 1)
self.assertIn(addr.address, form_errors['__all__'][0])
self.assertEqual(self._get_queryset().count(), initial_count)
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'], EXEMPT_EXCLUDE_MODELS=[])
def test_alternate_csv_import(self):
device = Device.objects.first()
interface = device.interfaces.first()
addr = IPAddress.objects.create(assigned_object=interface, address='192.0.2.3/24')
csv_data = (
"parent_object_type,parent_object_id,name,protocol,ports,ipaddresses,description",
f"dcim.device,{device.pk},Service 11,tcp,10,{addr.address},Eleventh service",
)
initial_count = self._get_queryset().count()
data = {
'data': '\n'.join(csv_data),
'format': ImportFormatChoices.CSV,
'csv_delimiter': CSVDelimiterChoices.AUTO,
}
# Assign model-level permission
obj_perm = ObjectPermission.objects.create(name='Test permission', actions=['add'])
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
# Test POST with permission
response = self.client.post(self._get_url('bulk_import'), data)
if response.status_code != 302:
self.assertEqual(response.context['form'].errors, {}) # debugging aid
self.assertHttpStatus(response, 302)
self.assertEqual(self._get_queryset().count(), initial_count + len(csv_data) - 1)
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_create_from_template(self):
self.add_permissions('ipam.add_service')