Finish up add FHRP Group assignment to Service

- fixes up ServiceSerializer to support write operations
- fixes up GraphQL components: ServiceType and ServiceFilter
- fixes broken tests
- cleans up lint issues
This commit is contained in:
Jason Novinger 2025-04-07 08:23:27 -05:00
parent 60e8268882
commit 0564ee9cfc
11 changed files with 46 additions and 86 deletions

View File

@ -613,7 +613,7 @@ class Device(
to='ipam.Service', to='ipam.Service',
content_type_field='parent_object_type', content_type_field='parent_object_type',
object_id_field='parent_object_id', object_id_field='parent_object_id',
related_query_name='devices', related_query_name='device',
) )
# Counter fields # Counter fields

View File

@ -1,13 +1,13 @@
from django.contrib.contenttypes.models import ContentType
from drf_spectacular.utils import extend_schema_field from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers from rest_framework import serializers
from dcim.models import Device
from ipam.choices import * from ipam.choices import *
from ipam.models import IPAddress, FHRPGroup, Service, ServiceTemplate from ipam.constants import SERVICE_ASSIGNMENT_MODELS
from netbox.api.fields import ChoiceField, SerializedPKRelatedField from ipam.models import IPAddress, Service, ServiceTemplate
from netbox.api.fields import ChoiceField, ContentTypeField, SerializedPKRelatedField
from netbox.api.serializers import NetBoxModelSerializer from netbox.api.serializers import NetBoxModelSerializer
from utilities.api import get_serializer_for_model from utilities.api import get_serializer_for_model
from virtualization.models import VirtualMachine
from .ip import IPAddressSerializer from .ip import IPAddressSerializer
__all__ = ( __all__ = (
@ -29,9 +29,6 @@ class ServiceTemplateSerializer(NetBoxModelSerializer):
class ServiceSerializer(NetBoxModelSerializer): class ServiceSerializer(NetBoxModelSerializer):
device = serializers.SerializerMethodField(read_only=True)
virtual_machine = serializers.SerializerMethodField(read_only=True)
fhrp_group = serializers.SerializerMethodField(read_only=True)
protocol = ChoiceField(choices=ServiceProtocolChoices, required=False) protocol = ChoiceField(choices=ServiceProtocolChoices, required=False)
ipaddresses = SerializedPKRelatedField( ipaddresses = SerializedPKRelatedField(
queryset=IPAddress.objects.all(), queryset=IPAddress.objects.all(),
@ -40,11 +37,15 @@ class ServiceSerializer(NetBoxModelSerializer):
required=False, required=False,
many=True many=True
) )
parent_object_type = ContentTypeField(
queryset=ContentType.objects.filter(SERVICE_ASSIGNMENT_MODELS)
)
parent = serializers.SerializerMethodField(read_only=True)
class Meta: class Meta:
model = Service model = Service
fields = [ fields = [
'id', 'url', 'display_url', 'display', 'device', 'virtual_machine', 'fhrp_group', 'name', 'id', 'url', 'display_url', 'display', 'parent_object_type', 'parent_object_id', 'parent', 'name',
'protocol', 'ports', 'ipaddresses', 'description', 'comments', 'tags', 'custom_fields', 'protocol', 'ports', 'ipaddresses', 'description', 'comments', 'tags', 'custom_fields',
'created', 'last_updated', 'created', 'last_updated',
] ]
@ -57,21 +58,3 @@ class ServiceSerializer(NetBoxModelSerializer):
serializer = get_serializer_for_model(obj.parent) serializer = get_serializer_for_model(obj.parent)
context = {'request': self.context['request']} context = {'request': self.context['request']}
return serializer(obj.parent, nested=True, context=context).data return serializer(obj.parent, nested=True, context=context).data
@extend_schema_field(serializers.JSONField(allow_null=True))
def get_device(self, obj):
if isinstance(obj.parent, Device):
return self.get_parent(obj)
return None
@extend_schema_field(serializers.JSONField(allow_null=True))
def get_virtual_machine(self, obj):
if isinstance(obj.parent, VirtualMachine):
return self.get_parent(obj)
return None
@extend_schema_field(serializers.JSONField(allow_null=True))
def get_fhrp_group(self, obj):
if isinstance(obj.parent, FHRPGroup):
return self.get_parent(obj)
return None

View File

@ -1171,12 +1171,12 @@ class ServiceFilterSet(NetBoxModelFilterSet):
field_name='pk', field_name='pk',
label=_('Virtual machine (ID)'), label=_('Virtual machine (ID)'),
) )
fhrp_group = MultiValueCharFilter( fhrpgroup = MultiValueCharFilter(
method='filter_fhrp_group', method='filter_fhrp_group',
field_name='name', field_name='name',
label=_('FHRP Group (name)'), label=_('FHRP Group (name)'),
) )
fhrp_group_id = MultiValueNumberFilter( fhrpgroup_id = MultiValueNumberFilter(
method='filter_fhrp_group', method='filter_fhrp_group',
field_name='pk', field_name='pk',
label=_('FHRP Group (ID)'), label=_('FHRP Group (ID)'),
@ -1199,7 +1199,7 @@ class ServiceFilterSet(NetBoxModelFilterSet):
class Meta: class Meta:
model = Service model = Service
fields = ('id', 'name', 'protocol', 'description') fields = ('id', 'name', 'protocol', 'description', 'parent_object_type', 'parent_object_id')
def search(self, queryset, name, value): def search(self, queryset, name, value):
if not value.strip(): if not value.strip():

View File

@ -19,8 +19,7 @@ from tenancy.graphql.filter_mixins import ContactFilterMixin, TenancyFilterMixin
if TYPE_CHECKING: if TYPE_CHECKING:
from netbox.graphql.filter_lookups import IntegerArrayLookup, IntegerLookup from netbox.graphql.filter_lookups import IntegerArrayLookup, IntegerLookup
from core.graphql.filters import ContentTypeFilter from core.graphql.filters import ContentTypeFilter
from dcim.graphql.filters import DeviceFilter, SiteFilter from dcim.graphql.filters import SiteFilter
from virtualization.graphql.filters import VirtualMachineFilter
from vpn.graphql.filters import L2VPNFilter from vpn.graphql.filters import L2VPNFilter
from .enums import * from .enums import *
@ -216,16 +215,14 @@ class RouteTargetFilter(TenancyFilterMixin, PrimaryModelFilterMixin):
@strawberry_django.filter(models.Service, lookups=True) @strawberry_django.filter(models.Service, lookups=True)
class ServiceFilter(ContactFilterMixin, ServiceBaseFilterMixin, PrimaryModelFilterMixin): class ServiceFilter(ContactFilterMixin, ServiceBaseFilterMixin, PrimaryModelFilterMixin):
device: Annotated['DeviceFilter', strawberry.lazy('dcim.graphql.filters')] | None = strawberry_django.filter_field()
device_id: ID | None = strawberry_django.filter_field()
virtual_machine: Annotated['VirtualMachineFilter', strawberry.lazy('virtualization.graphql.filters')] | None = (
strawberry_django.filter_field()
)
virtual_machine_id: ID | None = strawberry_django.filter_field()
name: FilterLookup[str] | None = strawberry_django.filter_field() name: FilterLookup[str] | None = strawberry_django.filter_field()
ipaddresses: Annotated['IPAddressFilter', strawberry.lazy('ipam.graphql.filters')] | None = ( ipaddresses: Annotated['IPAddressFilter', strawberry.lazy('ipam.graphql.filters')] | None = (
strawberry_django.filter_field() strawberry_django.filter_field()
) )
parent_object_type: Annotated['ContentTypeFilter', strawberry.lazy('core.graphql.filters')] | None = (
strawberry_django.filter_field()
)
parent_object_id: ID | None = strawberry_django.filter_field()
@strawberry_django.filter(models.ServiceTemplate, lookups=True) @strawberry_django.filter(models.ServiceTemplate, lookups=True)

View File

@ -5,12 +5,10 @@ import strawberry_django
from circuits.graphql.types import ProviderType from circuits.graphql.types import ProviderType
from dcim.graphql.types import SiteType from dcim.graphql.types import SiteType
from dcim.models import Device
from extras.graphql.mixins import ContactsMixin from extras.graphql.mixins import ContactsMixin
from ipam import models from ipam import models
from netbox.graphql.scalars import BigInt from netbox.graphql.scalars import BigInt
from netbox.graphql.types import BaseObjectType, NetBoxObjectType, OrganizationalObjectType from netbox.graphql.types import BaseObjectType, NetBoxObjectType, OrganizationalObjectType
from virtualization.models import VirtualMachine
from .filters import * from .filters import *
from .mixins import IPAddressesMixin from .mixins import IPAddressesMixin
@ -243,41 +241,14 @@ class RouteTargetType(NetBoxObjectType):
@strawberry_django.type( @strawberry_django.type(
models.Service, models.Service,
fields='__all__', exclude=('parent_object_type', 'parent_object_id'),
filters=ServiceFilter, filters=ServiceFilter,
pagination=True pagination=True
) )
class ServiceType(NetBoxObjectType, ContactsMixin): class ServiceType(NetBoxObjectType, ContactsMixin):
ports: List[int] ports: List[int]
# device: Annotated["DeviceType", strawberry.lazy('dcim.graphql.types')] | None
# virtual_machine: Annotated["VirtualMachineType", strawberry.lazy('virtualization.graphql.types')] | None
# fhrp_group: Annotated["FHRPGroupType", strawberry.lazy('ipam.graphql.types')] | None
ipaddresses: List[Annotated["IPAddressType", strawberry.lazy('ipam.graphql.types')]] ipaddresses: List[Annotated["IPAddressType", strawberry.lazy('ipam.graphql.types')]]
@strawberry_django.field
def device(self) -> Annotated[Union[
Annotated["DeviceType", strawberry.lazy('dcim.graphql.types')],
], strawberry.union("ServiceAssignmentType")] | None:
if isinstance(self.parent, Device):
return self.parent
return None
@strawberry_django.field
def virtual_machine(self) -> Annotated[Union[
Annotated["VirtualMachineType", strawberry.lazy('virtualization.graphql.types')],
], strawberry.union("ServiceAssignmentType")] | None:
if isinstance(self.parent, VirtualMachine):
return self.parent
return None
@strawberry_django.field
def fhrp_group(self) -> Annotated[Union[
Annotated["FHRPGroupType", strawberry.lazy('ipam.graphql.types')],
], strawberry.union("ServiceAssignmentType")] | None:
if isinstance(self.parent, models.FHRPGroup):
return self.parent
return None
@strawberry_django.field @strawberry_django.field
def parent(self) -> Annotated[Union[ def parent(self) -> Annotated[Union[
Annotated["DeviceType", strawberry.lazy('dcim.graphql.types')], Annotated["DeviceType", strawberry.lazy('dcim.graphql.types')],

View File

@ -123,7 +123,7 @@ class ServiceIndex(SearchIndex):
('description', 500), ('description', 500),
('comments', 5000), ('comments', 5000),
) )
display_attrs = ('device', 'virtual_machine', 'description') display_attrs = ('parent', 'description')
@register_search @register_search

View File

@ -1198,27 +1198,30 @@ class ServiceTest(APIViewTestCases.APIViewTestCase):
Device.objects.bulk_create(devices) Device.objects.bulk_create(devices)
services = ( services = (
Service(device=devices[0], name='Service 1', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[1]), Service(parent=devices[0], name='Service 1', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[1]),
Service(device=devices[0], name='Service 2', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[2]), Service(parent=devices[0], name='Service 2', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[2]),
Service(device=devices[0], name='Service 3', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[3]), Service(parent=devices[0], name='Service 3', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[3]),
) )
Service.objects.bulk_create(services) Service.objects.bulk_create(services)
cls.create_data = [ cls.create_data = [
{ {
'device': devices[1].pk, 'parent_object_id': devices[1].pk,
'parent_object_type': 'dcim.device',
'name': 'Service 4', 'name': 'Service 4',
'protocol': ServiceProtocolChoices.PROTOCOL_TCP, 'protocol': ServiceProtocolChoices.PROTOCOL_TCP,
'ports': [4], 'ports': [4],
}, },
{ {
'device': devices[1].pk, 'parent_object_id': devices[1].pk,
'parent_object_type': 'dcim.device',
'name': 'Service 5', 'name': 'Service 5',
'protocol': ServiceProtocolChoices.PROTOCOL_TCP, 'protocol': ServiceProtocolChoices.PROTOCOL_TCP,
'ports': [5], 'ports': [5],
}, },
{ {
'device': devices[1].pk, 'parent_object_id': devices[1].pk,
'parent_object_type': 'dcim.device',
'name': 'Service 6', 'name': 'Service 6',
'protocol': ServiceProtocolChoices.PROTOCOL_TCP, 'protocol': ServiceProtocolChoices.PROTOCOL_TCP,
'ports': [6], 'ports': [6],

View File

@ -2332,34 +2332,39 @@ class ServiceTestCase(TestCase, ChangeLoggedFilterSetTests):
services = ( services = (
Service( Service(
device=devices[0], parent=devices[0],
name='Service 1', name='Service 1',
protocol=ServiceProtocolChoices.PROTOCOL_TCP, protocol=ServiceProtocolChoices.PROTOCOL_TCP,
ports=[1001], ports=[1001],
description='foobar1', description='foobar1',
), ),
Service( Service(
device=devices[1], parent=devices[1],
name='Service 2', name='Service 2',
protocol=ServiceProtocolChoices.PROTOCOL_TCP, protocol=ServiceProtocolChoices.PROTOCOL_TCP,
ports=[1002], ports=[1002],
description='foobar2', description='foobar2',
), ),
Service(device=devices[2], name='Service 3', protocol=ServiceProtocolChoices.PROTOCOL_UDP, ports=[1003]),
Service( Service(
virtual_machine=virtual_machines[0], parent=devices[2],
name='Service 3',
protocol=ServiceProtocolChoices.PROTOCOL_UDP,
ports=[1003]
),
Service(
parent=virtual_machines[0],
name='Service 4', name='Service 4',
protocol=ServiceProtocolChoices.PROTOCOL_TCP, protocol=ServiceProtocolChoices.PROTOCOL_TCP,
ports=[2001], ports=[2001],
), ),
Service( Service(
virtual_machine=virtual_machines[1], parent=virtual_machines[1],
name='Service 5', name='Service 5',
protocol=ServiceProtocolChoices.PROTOCOL_TCP, protocol=ServiceProtocolChoices.PROTOCOL_TCP,
ports=[2002], ports=[2002],
), ),
Service( Service(
virtual_machine=virtual_machines[2], parent=virtual_machines[2],
name='Service 6', name='Service 6',
protocol=ServiceProtocolChoices.PROTOCOL_UDP, protocol=ServiceProtocolChoices.PROTOCOL_UDP,
ports=[2003], ports=[2003],

View File

@ -1053,6 +1053,8 @@ class ServiceTemplateTestCase(ViewTestCases.PrimaryObjectViewTestCase):
class ServiceTestCase(ViewTestCases.PrimaryObjectViewTestCase): class ServiceTestCase(ViewTestCases.PrimaryObjectViewTestCase):
model = Service model = Service
# TODO, related to #9816, cannot validate GFK
validation_excluded_fields = ('device',)
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -1081,7 +1083,6 @@ class ServiceTestCase(ViewTestCases.PrimaryObjectViewTestCase):
cls.form_data = { cls.form_data = {
'device': device.pk, 'device': device.pk,
'virtual_machine': None,
'name': 'Service X', 'name': 'Service X',
'protocol': ServiceProtocolChoices.PROTOCOL_TCP, 'protocol': ServiceProtocolChoices.PROTOCOL_TCP,
'ports': '104,105', 'ports': '104,105',
@ -1125,7 +1126,7 @@ class ServiceTestCase(ViewTestCases.PrimaryObjectViewTestCase):
request = { request = {
'path': self._get_url('add'), 'path': self._get_url('add'),
'data': { 'data': {
'parent': device.pk, 'device': device.pk,
'service_template': service_template.pk, 'service_template': service_template.pk,
}, },
} }

View File

@ -1445,7 +1445,7 @@ class ServiceBulkImportView(generic.BulkImportView):
@register_model_view(Service, 'bulk_edit', path='edit', detail=False) @register_model_view(Service, 'bulk_edit', path='edit', detail=False)
class ServiceBulkEditView(generic.BulkEditView): class ServiceBulkEditView(generic.BulkEditView):
queryset = Service.objects.prefetch_related('device', 'virtual_machine') queryset = Service.objects.prefetch_related('parent')
filterset = filtersets.ServiceFilterSet filterset = filtersets.ServiceFilterSet
table = tables.ServiceTable table = tables.ServiceTable
form = forms.ServiceBulkEditForm form = forms.ServiceBulkEditForm
@ -1453,6 +1453,6 @@ class ServiceBulkEditView(generic.BulkEditView):
@register_model_view(Service, 'bulk_delete', path='delete', detail=False) @register_model_view(Service, 'bulk_delete', path='delete', detail=False)
class ServiceBulkDeleteView(generic.BulkDeleteView): class ServiceBulkDeleteView(generic.BulkDeleteView):
queryset = Service.objects.prefetch_related('device', 'virtual_machine') queryset = Service.objects.prefetch_related('parent')
filterset = filtersets.ServiceFilterSet filterset = filtersets.ServiceFilterSet
table = tables.ServiceTable table = tables.ServiceTable

View File

@ -130,7 +130,7 @@ class VirtualMachine(ContactsMixin, ImageAttachmentsMixin, RenderConfigMixin, Co
to='ipam.Service', to='ipam.Service',
content_type_field='parent_object_type', content_type_field='parent_object_type',
object_id_field='parent_object_id', object_id_field='parent_object_id',
related_query_name='virtualmachines', related_query_name='virtual_machine',
) )
# Counter fields # Counter fields