From afb5f42af33c34d54829821622336298a9d28981 Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Wed, 14 Feb 2024 10:06:57 -0500 Subject: [PATCH] Introduce get_annotations_for_serializer() and enable dynamic annotations --- netbox/netbox/api/viewsets/__init__.py | 13 +++++++------ netbox/utilities/api.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/netbox/netbox/api/viewsets/__init__.py b/netbox/netbox/api/viewsets/__init__.py index e9edf9311..2ff764bed 100644 --- a/netbox/netbox/api/viewsets/__init__.py +++ b/netbox/netbox/api/viewsets/__init__.py @@ -10,7 +10,7 @@ from rest_framework import mixins as drf_mixins from rest_framework.response import Response from rest_framework.viewsets import GenericViewSet -from utilities.api import get_prefetches_for_serializer +from utilities.api import get_annotations_for_serializer, get_prefetches_for_serializer from utilities.exceptions import AbortRequest from . import mixins @@ -44,15 +44,16 @@ class BaseViewSet(GenericViewSet): def get_queryset(self): qs = super().get_queryset() + serializer_class = self.get_serializer_class() # Dynamically resolve prefetches for included serializer fields and attach them to the queryset - prefetch = get_prefetches_for_serializer( - self.get_serializer_class(), - fields_to_include=self.requested_fields - ) - if prefetch: + if prefetch := get_prefetches_for_serializer(serializer_class, fields_to_include=self.requested_fields): qs = qs.prefetch_related(*prefetch) + # Dynamically resolve annotations for RelatedObjectCountFields on the serializer and attach them to the queryset + if annotations := get_annotations_for_serializer(serializer_class, fields_to_include=self.requested_fields): + qs = qs.annotate(**annotations) + return qs def get_serializer(self, *args, **kwargs): diff --git a/netbox/utilities/api.py b/netbox/utilities/api.py index 320d175c3..dc8429dbb 100644 --- a/netbox/utilities/api.py +++ b/netbox/utilities/api.py @@ -11,7 +11,9 @@ from rest_framework import status from rest_framework.serializers import Serializer from rest_framework.utils import formatting +from netbox.api.fields import RelatedObjectCountField from netbox.api.exceptions import GraphQLTypeNotFound, SerializerNotFound +from utilities.utils import count_related from .utils import dynamic_import __all__ = ( @@ -131,6 +133,23 @@ def get_prefetches_for_serializer(serializer_class, fields_to_include=None): return prefetch_fields +def get_annotations_for_serializer(serializer_class, fields_to_include=None): + """ + Return a mapping of field names to annotations to be applied to the queryset for a serializer. + """ + annotations = {} + + # If specific fields are not specified, default to all + if not fields_to_include: + fields_to_include = serializer_class.Meta.fields + + for field_name, field in serializer_class._declared_fields.items(): + if field_name in fields_to_include and type(field) is RelatedObjectCountField: + annotations[field_name] = count_related(field.model, field.related_field) + + return annotations + + def rest_api_server_error(request, *args, **kwargs): """ Handle exceptions and return a useful error message for REST API requests.