Restore advisory PostgreSQL locks

This commit is contained in:
Jeremy Stretch 2023-06-19 14:53:04 -04:00
parent 3cad6494ee
commit 0a9dcd31d9

View File

@ -2,6 +2,7 @@ from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
from django.db import transaction from django.db import transaction
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django_pglocks import advisory_lock from django_pglocks import advisory_lock
from drf_spectacular.utils import extend_schema
from netaddr import IPSet from netaddr import IPSet
from rest_framework import status from rest_framework import status
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
@ -215,6 +216,7 @@ class AvailableObjectsView(ObjectValidationMixin, APIView):
""" """
read_serializer_class = None read_serializer_class = None
write_serializer_class = None write_serializer_class = None
advisory_lock_key = None
def get_parent(self, request, pk): def get_parent(self, request, pk):
""" """
@ -262,15 +264,12 @@ class AvailableObjectsView(ObjectValidationMixin, APIView):
return Response(serializer.data) return Response(serializer.data)
# TODO: Fix OpenAPI schema # TODO: Fix OpenAPI schema
# @extend_schema(methods=["post"], responses={201: serializers.ASNSerializer(many=True)}) # @extend_schema(methods=["post"], responses={201: serializer(many=True)})
# TODO: Restore advisory lock
# @advisory_lock(ADVISORY_LOCK_KEYS['available-asns'])
def post(self, request, pk): def post(self, request, pk):
self.queryset = self.queryset.restrict(request.user, 'add') self.queryset = self.queryset.restrict(request.user, 'add')
parent = self.get_parent(request, pk) parent = self.get_parent(request, pk)
available_objects = self.get_available_objects(parent)
# Normalize to a list of objects # Normalize request data to a list of objects
requested_objects = request.data if isinstance(request.data, list) else [request.data] requested_objects = request.data if isinstance(request.data, list) else [request.data]
# Serialize and validate the request data # Serialize and validate the request data
@ -284,44 +283,49 @@ class AvailableObjectsView(ObjectValidationMixin, APIView):
status=status.HTTP_400_BAD_REQUEST status=status.HTTP_400_BAD_REQUEST
) )
# Determine if the requested number of objects is available with advisory_lock(ADVISORY_LOCK_KEYS[self.advisory_lock_key]):
if not self.check_sufficient_available(serializer.validated_data, available_objects): available_objects = self.get_available_objects(parent)
# TODO: Raise exception instead?
return Response(
{
"detail": f"Insufficient resources are available to satisfy the request"
},
status=status.HTTP_409_CONFLICT
)
# Prepare object data for deserialization # Determine if the requested number of objects is available
requested_objects = self.prep_object_data(serializer.validated_data, available_objects, parent) if not self.check_sufficient_available(serializer.validated_data, available_objects):
# TODO: Raise exception instead?
return Response(
{
"detail": f"Insufficient resources are available to satisfy the request"
},
status=status.HTTP_409_CONFLICT
)
# Initialize the serializer with a list or a single object depending on what was requested # Prepare object data for deserialization
serializer_class = get_serializer_for_model(self.queryset.model) requested_objects = self.prep_object_data(serializer.validated_data, available_objects, parent)
context = {'request': request}
if isinstance(request.data, list):
serializer = serializer_class(data=requested_objects, many=True, context=context)
else:
serializer = serializer_class(data=requested_objects[0], context=context)
# Create the new IP address(es) # Initialize the serializer with a list or a single object depending on what was requested
if serializer.is_valid(): serializer_class = get_serializer_for_model(self.queryset.model)
context = {'request': request}
if isinstance(request.data, list):
serializer = serializer_class(data=requested_objects, many=True, context=context)
else:
serializer = serializer_class(data=requested_objects[0], context=context)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
# Create the new IP address(es)
try: try:
with transaction.atomic(): with transaction.atomic():
created = serializer.save() created = serializer.save()
self._validate_objects(created) self._validate_objects(created)
except ObjectDoesNotExist: except ObjectDoesNotExist:
raise PermissionDenied() raise PermissionDenied()
return Response(serializer.data, status=status.HTTP_201_CREATED)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.data, status=status.HTTP_201_CREATED)
class AvailableASNsView(AvailableObjectsView): class AvailableASNsView(AvailableObjectsView):
queryset = ASN.objects.all() queryset = ASN.objects.all()
read_serializer_class = serializers.AvailableASNSerializer read_serializer_class = serializers.AvailableASNSerializer
write_serializer_class = serializers.AvailableASNSerializer write_serializer_class = serializers.AvailableASNSerializer
advisory_lock_key = 'available-asns'
def get_parent(self, request, pk): def get_parent(self, request, pk):
return get_object_or_404(ASNRange.objects.restrict(request.user), pk=pk) return get_object_or_404(ASNRange.objects.restrict(request.user), pk=pk)
@ -345,6 +349,7 @@ class AvailableASNsView(AvailableObjectsView):
return requested_objects return requested_objects
# TODO: Move me
def get_next_available(ipset, prefix_size): def get_next_available(ipset, prefix_size):
for available_prefix in ipset.iter_cidrs(): for available_prefix in ipset.iter_cidrs():
if prefix_size >= available_prefix.prefixlen: if prefix_size >= available_prefix.prefixlen:
@ -358,6 +363,7 @@ class AvailablePrefixesView(AvailableObjectsView):
queryset = Prefix.objects.all() queryset = Prefix.objects.all()
read_serializer_class = serializers.AvailablePrefixSerializer read_serializer_class = serializers.AvailablePrefixSerializer
write_serializer_class = serializers.PrefixLengthSerializer write_serializer_class = serializers.PrefixLengthSerializer
advisory_lock_key = 'available-prefixes'
def get_parent(self, request, pk): def get_parent(self, request, pk):
return get_object_or_404(Prefix.objects.restrict(request.user), pk=pk) return get_object_or_404(Prefix.objects.restrict(request.user), pk=pk)
@ -399,6 +405,7 @@ class AvailableIPAddressesView(AvailableObjectsView):
queryset = IPAddress.objects.all() queryset = IPAddress.objects.all()
read_serializer_class = serializers.AvailableIPSerializer read_serializer_class = serializers.AvailableIPSerializer
write_serializer_class = serializers.AvailableIPSerializer write_serializer_class = serializers.AvailableIPSerializer
advisory_lock_key = 'available-ips'
def get_available_objects(self, parent, limit=None): def get_available_objects(self, parent, limit=None):
# Calculate available IPs within the parent # Calculate available IPs within the parent
@ -442,6 +449,7 @@ class AvailableVLANsView(AvailableObjectsView):
queryset = VLAN.objects.all() queryset = VLAN.objects.all()
read_serializer_class = serializers.AvailableVLANSerializer read_serializer_class = serializers.AvailableVLANSerializer
write_serializer_class = serializers.CreateAvailableVLANSerializer write_serializer_class = serializers.CreateAvailableVLANSerializer
advisory_lock_key = 'available-vlans'
def get_parent(self, request, pk): def get_parent(self, request, pk):
return get_object_or_404(VLANGroup.objects.restrict(request.user), pk=pk) return get_object_or_404(VLANGroup.objects.restrict(request.user), pk=pk)