Introduce get_annotations_for_serializer() and enable dynamic annotations

This commit is contained in:
Jeremy Stretch 2024-02-14 10:06:57 -05:00
parent 0480760ad8
commit afb5f42af3
2 changed files with 26 additions and 6 deletions

View File

@ -10,7 +10,7 @@ from rest_framework import mixins as drf_mixins
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet
from utilities.api import get_prefetches_for_serializer
from utilities.api import get_annotations_for_serializer, get_prefetches_for_serializer
from utilities.exceptions import AbortRequest
from . import mixins
@ -44,15 +44,16 @@ class BaseViewSet(GenericViewSet):
def get_queryset(self):
qs = super().get_queryset()
serializer_class = self.get_serializer_class()
# Dynamically resolve prefetches for included serializer fields and attach them to the queryset
prefetch = get_prefetches_for_serializer(
self.get_serializer_class(),
fields_to_include=self.requested_fields
)
if prefetch:
if prefetch := get_prefetches_for_serializer(serializer_class, fields_to_include=self.requested_fields):
qs = qs.prefetch_related(*prefetch)
# Dynamically resolve annotations for RelatedObjectCountFields on the serializer and attach them to the queryset
if annotations := get_annotations_for_serializer(serializer_class, fields_to_include=self.requested_fields):
qs = qs.annotate(**annotations)
return qs
def get_serializer(self, *args, **kwargs):

View File

@ -11,7 +11,9 @@ from rest_framework import status
from rest_framework.serializers import Serializer
from rest_framework.utils import formatting
from netbox.api.fields import RelatedObjectCountField
from netbox.api.exceptions import GraphQLTypeNotFound, SerializerNotFound
from utilities.utils import count_related
from .utils import dynamic_import
__all__ = (
@ -131,6 +133,23 @@ def get_prefetches_for_serializer(serializer_class, fields_to_include=None):
return prefetch_fields
def get_annotations_for_serializer(serializer_class, fields_to_include=None):
"""
Return a mapping of field names to annotations to be applied to the queryset for a serializer.
"""
annotations = {}
# If specific fields are not specified, default to all
if not fields_to_include:
fields_to_include = serializer_class.Meta.fields
for field_name, field in serializer_class._declared_fields.items():
if field_name in fields_to_include and type(field) is RelatedObjectCountField:
annotations[field_name] = count_related(field.model, field.related_field)
return annotations
def rest_api_server_error(request, *args, **kwargs):
"""
Handle exceptions and return a useful error message for REST API requests.