fix(api): Fix schema and field definitions for OpenAPI

Add `get_internal_type()` to custom field classes for Django compatibility,
annotate path parameters and operation IDs for background endpoints, and
provide serializer context on the RQ base viewset to clear schema warnings.

Fixes #20365
This commit is contained in:
Martin Hauser 2025-09-30 11:31:38 +02:00 committed by Jeremy Stretch
parent 10e76597a8
commit 9e75a2f955
6 changed files with 81 additions and 34 deletions

View File

@ -13,7 +13,7 @@ class BackgroundTaskSerializer(serializers.Serializer):
url = serializers.HyperlinkedIdentityField( url = serializers.HyperlinkedIdentityField(
view_name='core-api:rqtask-detail', view_name='core-api:rqtask-detail',
lookup_field='id', lookup_field='id',
lookup_url_kwarg='pk' lookup_url_kwarg='id'
) )
description = serializers.CharField() description = serializers.CharField()
origin = serializers.CharField() origin = serializers.CharField()

View File

@ -5,7 +5,7 @@ from django_rq.queues import get_redis_connection
from django_rq.settings import QUEUES_LIST from django_rq.settings import QUEUES_LIST
from django_rq.utils import get_statistics from django_rq.utils import get_statistics
from drf_spectacular.types import OpenApiTypes from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import viewsets from rest_framework import viewsets
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import PermissionDenied from rest_framework.exceptions import PermissionDenied
@ -24,6 +24,7 @@ from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired
from netbox.api.metadata import ContentTypeMetadata from netbox.api.metadata import ContentTypeMetadata
from netbox.api.pagination import LimitOffsetListPagination from netbox.api.pagination import LimitOffsetListPagination
from netbox.api.viewsets import NetBoxModelViewSet, NetBoxReadOnlyModelViewSet from netbox.api.viewsets import NetBoxModelViewSet, NetBoxReadOnlyModelViewSet
from . import serializers from . import serializers
@ -117,29 +118,49 @@ class BaseRQViewSet(viewsets.ViewSet):
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
""" """
Return the serializer instance that should be used for validating and Return the serializer instance that should be used for validating and
deserializing input, and for serializing output. deserializing input and for serializing output.
""" """
serializer_class = self.get_serializer_class() serializer_class = self.get_serializer_class()
kwargs['context'] = self.get_serializer_context() kwargs['context'] = self.get_serializer_context()
return serializer_class(*args, **kwargs) return serializer_class(*args, **kwargs)
def get_serializer_class(self):
"""
Return the class to use for the serializer.
"""
return self.serializer_class
def get_serializer_context(self):
"""
Extra context provided to the serializer class.
"""
return {
'request': self.request,
'format': self.format_kwarg,
'view': self,
}
class BackgroundQueueViewSet(BaseRQViewSet): class BackgroundQueueViewSet(BaseRQViewSet):
""" """
Retrieve a list of RQ Queues. Retrieve a list of RQ Queues.
Note: Queue names are not URL safe so not returning a detail view. Note: Queue names are not URL safe, so not returning a detail view.
""" """
serializer_class = serializers.BackgroundQueueSerializer serializer_class = serializers.BackgroundQueueSerializer
lookup_field = 'name' lookup_field = 'name'
lookup_value_regex = r'[\w.@+-]+' lookup_value_regex = r'[\w.@+-]+'
def get_view_name(self): def get_view_name(self):
return "Background Queues" return 'Background Queues'
def get_data(self): def get_data(self):
return get_statistics(run_maintenance_tasks=True)["queues"] return get_statistics(run_maintenance_tasks=True)['queues']
@extend_schema(responses={200: OpenApiTypes.OBJECT}) @extend_schema(
operation_id='core_background_queues_retrieve_by_name',
parameters=[OpenApiParameter(name='name', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)],
responses={200: OpenApiTypes.OBJECT},
)
def retrieve(self, request, name): def retrieve(self, request, name):
data = self.get_data() data = self.get_data()
if not data: if not data:
@ -161,12 +182,17 @@ class BackgroundWorkerViewSet(BaseRQViewSet):
lookup_field = 'name' lookup_field = 'name'
def get_view_name(self): def get_view_name(self):
return "Background Workers" return 'Background Workers'
def get_data(self): def get_data(self):
config = QUEUES_LIST[0] config = QUEUES_LIST[0]
return Worker.all(get_redis_connection(config['connection_config'])) return Worker.all(get_redis_connection(config['connection_config']))
@extend_schema(
operation_id='core_background_workers_retrieve_by_name',
parameters=[OpenApiParameter(name='name', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)],
responses={200: OpenApiTypes.OBJECT},
)
def retrieve(self, request, name): def retrieve(self, request, name):
# all the RQ queues should use the same connection # all the RQ queues should use the same connection
config = QUEUES_LIST[0] config = QUEUES_LIST[0]
@ -184,9 +210,10 @@ class BackgroundTaskViewSet(BaseRQViewSet):
Retrieve a list of RQ Tasks. Retrieve a list of RQ Tasks.
""" """
serializer_class = serializers.BackgroundTaskSerializer serializer_class = serializers.BackgroundTaskSerializer
lookup_field = 'id'
def get_view_name(self): def get_view_name(self):
return "Background Tasks" return 'Background Tasks'
def get_data(self): def get_data(self):
return get_rq_jobs() return get_rq_jobs()
@ -199,45 +226,53 @@ class BackgroundTaskViewSet(BaseRQViewSet):
return task return task
@extend_schema(responses={200: OpenApiTypes.OBJECT}) @extend_schema(
def retrieve(self, request, pk): operation_id='core_background_tasks_retrieve_by_id',
parameters=[OpenApiParameter(name='id', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)],
responses={200: OpenApiTypes.OBJECT},
)
def retrieve(self, request, id):
""" """
Retrieve the details of the specified RQ Task. Retrieve the details of the specified RQ Task.
""" """
task = self.get_task_from_id(pk) task = self.get_task_from_id(id)
serializer = self.serializer_class(task, context={'request': request}) serializer = self.serializer_class(task, context={'request': request})
return Response(serializer.data) return Response(serializer.data)
@action(methods=["POST"], detail=True) @extend_schema(parameters=[OpenApiParameter(name='id', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)])
def delete(self, request, pk): @action(methods=['POST'], detail=True)
def delete(self, request, id):
""" """
Delete the specified RQ Task. Delete the specified RQ Task.
""" """
delete_rq_job(pk) delete_rq_job(id)
return HttpResponse(status=200) return HttpResponse(status=200)
@action(methods=["POST"], detail=True) @extend_schema(parameters=[OpenApiParameter(name='id', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)])
def requeue(self, request, pk): @action(methods=['POST'], detail=True)
def requeue(self, request, id):
""" """
Requeues the specified RQ Task. Requeues the specified RQ Task.
""" """
requeue_rq_job(pk) requeue_rq_job(id)
return HttpResponse(status=200) return HttpResponse(status=200)
@action(methods=["POST"], detail=True) @extend_schema(parameters=[OpenApiParameter(name='id', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)])
def enqueue(self, request, pk): @action(methods=['POST'], detail=True)
def enqueue(self, request, id):
""" """
Enqueues the specified RQ Task. Enqueues the specified RQ Task.
""" """
enqueue_rq_job(pk) enqueue_rq_job(id)
return HttpResponse(status=200) return HttpResponse(status=200)
@action(methods=["POST"], detail=True) @extend_schema(parameters=[OpenApiParameter(name='id', type=OpenApiTypes.STR, location=OpenApiParameter.PATH)])
def stop(self, request, pk): @action(methods=['POST'], detail=True)
def stop(self, request, id):
""" """
Stops the specified RQ Task. Stops the specified RQ Task.
""" """
stopped_jobs = stop_rq_job(pk) stopped_jobs = stop_rq_job(id)
if len(stopped_jobs) == 1: if len(stopped_jobs) == 1:
return HttpResponse(status=200) return HttpResponse(status=200)
else: else:

View File

@ -26,7 +26,7 @@ class eui64_unix_expanded_uppercase(eui64_unix_expanded):
# #
class MACAddressField(models.Field): class MACAddressField(models.Field):
description = "PostgreSQL MAC Address field" description = 'PostgreSQL MAC Address field'
def python_type(self): def python_type(self):
return EUI return EUI
@ -34,6 +34,9 @@ class MACAddressField(models.Field):
def from_db_value(self, value, expression, connection): def from_db_value(self, value, expression, connection):
return self.to_python(value) return self.to_python(value)
def get_internal_type(self):
return 'CharField'
def to_python(self, value): def to_python(self, value):
if value is None: if value is None:
return value return value
@ -54,7 +57,7 @@ class MACAddressField(models.Field):
class WWNField(models.Field): class WWNField(models.Field):
description = "World Wide Name field" description = 'World Wide Name field'
def python_type(self): def python_type(self):
return EUI return EUI
@ -62,6 +65,9 @@ class WWNField(models.Field):
def from_db_value(self, value, expression, connection): def from_db_value(self, value, expression, connection):
return self.to_python(value) return self.to_python(value)
def get_internal_type(self):
return 'CharField'
def to_python(self, value): def to_python(self, value):
if value is None: if value is None:
return value return value

View File

@ -26,6 +26,7 @@ class CustomFieldChoiceSetSerializer(ChangeLogMessageSerializer, ValidatedModelS
max_length=2 max_length=2
) )
) )
choices_count = serializers.IntegerField(read_only=True)
class Meta: class Meta:
model = CustomFieldChoiceSet model = CustomFieldChoiceSet

View File

@ -26,6 +26,9 @@ class BaseIPField(models.Field):
def from_db_value(self, value, expression, connection): def from_db_value(self, value, expression, connection):
return self.to_python(value) return self.to_python(value)
def get_internal_type(self):
return 'CharField'
def to_python(self, value): def to_python(self, value):
if not value: if not value:
return value return value
@ -57,7 +60,7 @@ class IPNetworkField(BaseIPField):
""" """
IP prefix (network and mask) IP prefix (network and mask)
""" """
description = "PostgreSQL CIDR field" description = 'PostgreSQL CIDR field'
default_validators = [validators.prefix_validator] default_validators = [validators.prefix_validator]
def db_type(self, connection): def db_type(self, connection):
@ -83,7 +86,7 @@ class IPAddressField(BaseIPField):
""" """
IP address (host address and mask) IP address (host address and mask)
""" """
description = "PostgreSQL INET field" description = 'PostgreSQL INET field'
def db_type(self, connection): def db_type(self, connection):
return 'inet' return 'inet'
@ -110,7 +113,7 @@ IPAddressField.register_lookup(lookups.Inet)
class ASNField(models.BigIntegerField): class ASNField(models.BigIntegerField):
description = "32-bit ASN field" description = '32-bit ASN field'
default_validators = [ default_validators = [
MinValueValidator(BGP_ASN_MIN), MinValueValidator(BGP_ASN_MIN),
MaxValueValidator(BGP_ASN_MAX), MaxValueValidator(BGP_ASN_MAX),

View File

@ -354,13 +354,13 @@ class PrefixFilterSet(NetBoxModelFilterSet, ScopedFilterSet, TenancyFilterSet, C
vlan_group_id = django_filters.ModelMultipleChoiceFilter( vlan_group_id = django_filters.ModelMultipleChoiceFilter(
field_name='vlan__group', field_name='vlan__group',
queryset=VLANGroup.objects.all(), queryset=VLANGroup.objects.all(),
to_field_name="id", to_field_name='id',
label=_('VLAN Group (ID)'), label=_('VLAN Group (ID)'),
) )
vlan_group = django_filters.ModelMultipleChoiceFilter( vlan_group = django_filters.ModelMultipleChoiceFilter(
field_name='vlan__group__slug', field_name='vlan__group__slug',
queryset=VLANGroup.objects.all(), queryset=VLANGroup.objects.all(),
to_field_name="slug", to_field_name='slug',
label=_('VLAN Group (slug)'), label=_('VLAN Group (slug)'),
) )
vlan_id = django_filters.ModelMultipleChoiceFilter( vlan_id = django_filters.ModelMultipleChoiceFilter(
@ -695,12 +695,12 @@ class IPAddressFilterSet(NetBoxModelFilterSet, TenancyFilterSet, ContactModelFil
return queryset.filter(q) return queryset.filter(q)
def parse_inet_addresses(self, value): def parse_inet_addresses(self, value):
''' """
Parse networks or IP addresses and cast to a format Parse networks or IP addresses and cast to a format
acceptable by the Postgres inet type. acceptable by the Postgres inet type.
Skips invalid values. Skips invalid values.
''' """
parsed = [] parsed = []
for addr in value: for addr in value:
if netaddr.valid_ipv4(addr) or netaddr.valid_ipv6(addr): if netaddr.valid_ipv4(addr) or netaddr.valid_ipv6(addr):
@ -718,7 +718,7 @@ class IPAddressFilterSet(NetBoxModelFilterSet, TenancyFilterSet, ContactModelFil
# as argument. If they are all invalid, # as argument. If they are all invalid,
# we return an empty queryset # we return an empty queryset
value = self.parse_inet_addresses(value) value = self.parse_inet_addresses(value)
if (len(value) == 0): if len(value) == 0:
return queryset.none() return queryset.none()
try: try:
@ -1079,6 +1079,7 @@ class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
def get_for_virtualmachine(self, queryset, name, value): def get_for_virtualmachine(self, queryset, name, value):
return queryset.get_for_virtualmachine(value) return queryset.get_for_virtualmachine(value)
@extend_schema_field(OpenApiTypes.INT)
def filter_interface_id(self, queryset, name, value): def filter_interface_id(self, queryset, name, value):
if value is None: if value is None:
return queryset.none() return queryset.none()
@ -1087,6 +1088,7 @@ class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
Q(interfaces_as_untagged=value) Q(interfaces_as_untagged=value)
).distinct() ).distinct()
@extend_schema_field(OpenApiTypes.INT)
def filter_vminterface_id(self, queryset, name, value): def filter_vminterface_id(self, queryset, name, value):
if value is None: if value is None:
return queryset.none() return queryset.none()