Refactor get_view_name()

This commit is contained in:
Jeremy Stretch 2024-03-21 09:39:17 -04:00
parent 78b4fa5196
commit 99144031b7

View File

@ -12,11 +12,12 @@ from django.urls import reverse
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import status from rest_framework import status
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from rest_framework.utils import formatting from rest_framework.views import get_view_name as drf_get_view_name
from extras.constants import HTTP_CONTENT_TYPE_JSON
from netbox.api.fields import RelatedObjectCountField from netbox.api.fields import RelatedObjectCountField
from netbox.api.exceptions import GraphQLTypeNotFound, SerializerNotFound from netbox.api.exceptions import GraphQLTypeNotFound, SerializerNotFound
from .utils import count_related, dict_to_filter_params, dynamic_import from .utils import count_related, dict_to_filter_params, dynamic_import, title
__all__ = ( __all__ = (
'get_annotations_for_serializer', 'get_annotations_for_serializer',
@ -32,7 +33,7 @@ __all__ = (
def get_serializer_for_model(model, prefix=''): def get_serializer_for_model(model, prefix=''):
""" """
Dynamically resolve and return the appropriate serializer for a model. Return the appropriate REST API serializer for the given model.
""" """
app_label, model_name = model._meta.label.split('.') app_label, model_name = model._meta.label.split('.')
serializer_name = f'{app_label}.api.serializers.{prefix}{model_name}Serializer' serializer_name = f'{app_label}.api.serializers.{prefix}{model_name}Serializer'
@ -48,15 +49,12 @@ def get_graphql_type_for_model(model):
""" """
Return the GraphQL type class for the given model. Return the GraphQL type class for the given model.
""" """
app_name, model_name = model._meta.label.split('.') app_label, model_name = model._meta.label.split('.')
# Object types for Django's auth models are in the users app class_name = f'{app_label}.graphql.types.{model_name}Type'
if app_name == 'auth':
app_name = 'users'
class_name = f'{app_name}.graphql.types.{model_name}Type'
try: try:
return dynamic_import(class_name) return dynamic_import(class_name)
except AttributeError: except AttributeError:
raise GraphQLTypeNotFound(f"Could not find GraphQL type for {app_name}.{model_name}") raise GraphQLTypeNotFound(f"Could not find GraphQL type for {app_label}.{model_name}")
def is_api_request(request): def is_api_request(request):
@ -64,30 +62,23 @@ def is_api_request(request):
Return True of the request is being made via the REST API. Return True of the request is being made via the REST API.
""" """
api_path = reverse('api-root') api_path = reverse('api-root')
return request.path_info.startswith(api_path) and request.content_type == HTTP_CONTENT_TYPE_JSON
return request.path_info.startswith(api_path) and request.content_type == 'application/json'
def get_view_name(view, suffix=None): def get_view_name(view):
""" """
Derive the view name from its associated model, if it has one. Fall back to DRF's built-in `get_view_name`. Derive the view name from its associated model, if it has one. Fall back to DRF's built-in `get_view_name()`.
This function is provided to DRF as its VIEW_NAME_FUNCTION.
""" """
if hasattr(view, 'queryset'): if hasattr(view, 'queryset'):
# Determine the model name from the queryset. # Derive the model name from the queryset.
name = view.queryset.model._meta.verbose_name name = title(view.queryset.model._meta.verbose_name)
name = ' '.join([w[0].upper() + w[1:] for w in name.split()]) # Capitalize each word if suffix := getattr(view, 'suffix', None):
name = f'{name} {suffix}'
return name
else: # Fall back to DRF's default behavior
# Replicate DRF's built-in behavior. return drf_get_view_name(view)
name = view.__class__.__name__
name = formatting.remove_trailing_string(name, 'View')
name = formatting.remove_trailing_string(name, 'ViewSet')
name = formatting.camelcase_to_spaces(name)
if suffix:
name += ' ' + suffix
return name
def get_prefetches_for_serializer(serializer_class, fields_to_include=None): def get_prefetches_for_serializer(serializer_class, fields_to_include=None):