diff --git a/netbox/virtualization/api/serializers_/virtualmachines.py b/netbox/virtualization/api/serializers_/virtualmachines.py index 05fa2e427..ed14b0a29 100644 --- a/netbox/virtualization/api/serializers_/virtualmachines.py +++ b/netbox/virtualization/api/serializers_/virtualmachines.py @@ -112,15 +112,32 @@ class VMInterfaceSerializer(NetBoxModelSerializer): brief_fields = ('id', 'url', 'display', 'virtual_machine', 'name', 'description') def validate(self, data): - # Validate many-to-many VLAN assignments - virtual_machine = self.instance.virtual_machine if self.instance else data.get('virtual_machine') - for vlan in data.get('tagged_vlans', []): - if vlan.site not in [virtual_machine.site, None]: - raise serializers.ValidationError({ - 'tagged_vlans': f"VLAN {vlan} must belong to the same site as the interface's parent virtual " - f"machine, or it must be global." - }) + virtual_machine = None + tagged_vlans = [] + + # #18887 + # There seem to be multiple code paths coming through here. Previously, we might either get + # the VirtualMachine instance from self.instance or from incoming data. However, #18887 + # illustrated that this is also being called when a custom field pointing to an object_type + # of VMInterface is on the right side of a custom-field assignment coming in from an API + # request. As such, we need to check a third way to access the VirtualMachine + # instance--where `data` is the VMInterface instance itself and we can get the associated + # VirtualMachine via attribute access. + if isinstance(data, dict): + virtual_machine = self.instance.virtual_machine if self.instance else data.get('virtual_machine') + tagged_vlans = data.get('tagged_vlans', []) + elif isinstance(data, VMInterface): + virtual_machine = data.virtual_machine + tagged_vlans = data.tagged_vlans.all() + + if virtual_machine: + for vlan in tagged_vlans: + if vlan.site not in [virtual_machine.site, None]: + raise serializers.ValidationError({ + 'tagged_vlans': f"VLAN {vlan} must belong to the same site as the interface's parent virtual " + f"machine, or it must be global." + }) return super().validate(data) diff --git a/netbox/virtualization/tests/test_api.py b/netbox/virtualization/tests/test_api.py index c57b57f2e..dfa8309a0 100644 --- a/netbox/virtualization/tests/test_api.py +++ b/netbox/virtualization/tests/test_api.py @@ -1,11 +1,15 @@ +from django.test import tag from django.urls import reverse +from netaddr import IPNetwork from rest_framework import status +from core.models import ObjectType from dcim.choices import InterfaceModeChoices from dcim.models import Site -from extras.models import ConfigTemplate +from extras.choices import CustomFieldTypeChoices +from extras.models import ConfigTemplate, CustomField from ipam.choices import VLANQinQRoleChoices -from ipam.models import VLAN, VRF +from ipam.models import Prefix, VLAN, VRF from utilities.testing import APITestCase, APIViewTestCases, create_test_device, create_test_virtualmachine from virtualization.choices import * from virtualization.models import * @@ -350,6 +354,39 @@ class VMInterfaceTest(APIViewTestCases.APIViewTestCase): }, ] + @tag('regression') + def test_set_vminterface_as_object_in_custom_field(self): + cf = CustomField.objects.create( + name='associated_interface', + type=CustomFieldTypeChoices.TYPE_OBJECT, + related_object_type=ObjectType.objects.get_for_model(VMInterface), + required=False + ) + cf.object_types.set([ObjectType.objects.get_for_model(Prefix)]) + cf.save() + + prefix = Prefix.objects.create(prefix=IPNetwork('10.0.0.0/12')) + vmi = VMInterface.objects.first() + + url = reverse('ipam-api:prefix-detail', kwargs={'pk': prefix.pk}) + data = { + 'custom_fields': { + 'associated_interface': vmi.id, + }, + } + + self.add_permissions('ipam.change_prefix') + + response = self.client.patch(url, data, format='json', **self.header) + self.assertEqual(response.status_code, 200) + + prefix_data = response.json() + self.assertEqual(prefix_data['custom_fields']['associated_interface']['id'], vmi.id) + + reloaded_prefix = Prefix.objects.get(pk=prefix.pk) + self.assertEqual(prefix.pk, reloaded_prefix.pk) + self.assertNotEqual(reloaded_prefix.cf['associated_interface'], None) + def test_bulk_delete_child_interfaces(self): interface1 = VMInterface.objects.get(name='Interface 1') virtual_machine = interface1.virtual_machine