diff --git a/netbox/netbox/api/viewsets/__init__.py b/netbox/netbox/api/viewsets/__init__.py index 769bdcb26..348bc3550 100644 --- a/netbox/netbox/api/viewsets/__init__.py +++ b/netbox/netbox/api/viewsets/__init__.py @@ -1,17 +1,16 @@ import logging from functools import cached_property -from django.contrib.contenttypes.fields import GenericForeignKey -from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist, PermissionDenied +from django.core.exceptions import ObjectDoesNotExist, PermissionDenied from django.db import transaction from django.db.models import ProtectedError, RestrictedError -from django.db.models.fields.related import ManyToOneRel, RelatedField from django_pglocks import advisory_lock from netbox.constants import ADVISORY_LOCK_KEYS from rest_framework import mixins as drf_mixins from rest_framework.response import Response from rest_framework.viewsets import GenericViewSet +from utilities.api import get_prefetches_for_serializer from utilities.exceptions import AbortRequest from . import mixins @@ -47,18 +46,10 @@ class BaseViewSet(GenericViewSet): qs = super().get_queryset() # Dynamically resolve prefetches for included serializer fields and attach them to the queryset - serializer_class = self.get_serializer_class() - model = serializer_class.Meta.model - fields_to_include = self.requested_fields or serializer_class.Meta.fields - prefetch = [] - for field_name in fields_to_include: - try: - field = model._meta.get_field(field_name) - except FieldDoesNotExist: - continue - if isinstance(field, (RelatedField, ManyToOneRel, GenericForeignKey)): - # TODO: Use serializer field source if set, else use its name - prefetch.append(field_name) + prefetch = get_prefetches_for_serializer( + self.get_serializer_class(), + fields_to_include=self.requested_fields + ) if prefetch: qs = qs.prefetch_related(*prefetch) diff --git a/netbox/utilities/api.py b/netbox/utilities/api.py index b53edf53a..b2c1d7eb0 100644 --- a/netbox/utilities/api.py +++ b/netbox/utilities/api.py @@ -2,9 +2,13 @@ import platform import sys from django.conf import settings +from django.contrib.contenttypes.fields import GenericForeignKey +from django.core.exceptions import FieldDoesNotExist +from django.db.models.fields.related import ManyToOneRel, RelatedField from django.http import JsonResponse from django.urls import reverse from rest_framework import status +from rest_framework.serializers import Serializer from rest_framework.utils import formatting from netbox.api.exceptions import GraphQLTypeNotFound, SerializerNotFound @@ -12,6 +16,7 @@ from .utils import dynamic_import __all__ = ( 'get_graphql_type_for_model', + 'get_prefetches_for_serializer', 'get_serializer_for_model', 'get_view_name', 'is_api_request', @@ -89,6 +94,38 @@ def get_view_name(view, suffix=None): return name +def get_prefetches_for_serializer(serializer_class, fields_to_include=None): + """ + Compile and return a list of fields which should be prefetched on the queryset for a serializer. + """ + model = serializer_class.Meta.model + + # If specific fields are not specified, default to all + if not fields_to_include: + fields_to_include = serializer_class.Meta.fields + + prefetch_fields = [] + for field_name in fields_to_include: + + # If the serializer field does not map to a discrete model field, skip it. + try: + field = model._meta.get_field(field_name) + except FieldDoesNotExist: + continue + if isinstance(field, (RelatedField, ManyToOneRel, GenericForeignKey)): + # TODO: Use serializer field source if set, else use its name + prefetch_fields.append(field_name) + + # If this field is represented by a nested serializer, recurse to resolve prefetches + # for the related object. + if serializer_field := serializer_class._declared_fields.get(field_name): + if issubclass(type(serializer_field), Serializer): + for subfield in get_prefetches_for_serializer(type(serializer_field)): + prefetch_fields.append(f'{field_name}__{subfield}') + + return prefetch_fields + + def rest_api_server_error(request, *args, **kwargs): """ Handle exceptions and return a useful error message for REST API requests.