diff --git a/netbox/dcim/tests/test_models.py b/netbox/dcim/tests/test_models.py index 148b9e35f..fb075989a 100644 --- a/netbox/dcim/tests/test_models.py +++ b/netbox/dcim/tests/test_models.py @@ -619,6 +619,49 @@ class DeviceTestCase(TestCase): with self.assertRaises(ValidationError): Device(name='device1', site=sites[0], device_type=device_type, role=device_role, cluster=clusters[1]).full_clean() + def test_module_bay_recursion(self): + site = Site.objects.create(name='Site 1', slug='site-1') + location = Location.objects.create(name='Location 1', slug='location-1', site=site) + rack = Rack.objects.create(name='Rack 1', site=site) + device_type = DeviceType.objects.first() + device_role = DeviceRole.objects.first() + device = Device.objects.create(name='Device 1', device_type=device_type, role=device_role, site=site, location=location, rack=rack) + + module_bays = ( + ModuleBay(device=device, name='Module Bay 1', label='A', description='First'), + ModuleBay(device=device, name='Module Bay 2', label='B', description='Second'), + ModuleBay(device=device, name='Module Bay 2', label='B', description='Second'), + ) + ModuleBay.objects.bulk_create(module_bays) + + manufacturer = Manufacturer.objects.create(name='Manufacturer 1', slug='manufacturer-1') + module_type = ModuleType.objects.create(manufacturer=manufacturer, model='Module Type 1') + modules = ( + Module(device=device, module_bay=module_bays[0], module_type=module_type), + Module(device=device, module_bay=module_bays[1], module_type=module_type), + Module(device=device, module_bay=module_bays[2], module_type=module_type), + ) + # M2 -> MB2 -> M1 -> MB1 -> M0 -> MB0 + Module.objects.bulk_create(modules) + module_bays[1].module = modules[0] + module_bays[1].clean() + module_bays[1].save() + module_bays[2].module = modules[1] + module_bays[2].clean() + module_bays[2].save() + + # Confirm error if ModuleBay recurses + with self.assertRaises(ValidationError): + module_bays[0].module = modules[2] + module_bays[0].clean() + module_bays[0].save() + + # Confirm error if Module recurses + with self.assertRaises(ValidationError): + modules[0].module_bay = module_bays[2] + modules[0].clean() + modules[0].save() + class CableTestCase(TestCase):