cleaned up code

This commit is contained in:
Abhimanyu Saharan 2023-01-17 02:07:05 +05:30
parent 6346e4dd60
commit 496cfb4dca
2 changed files with 13 additions and 25 deletions

View File

@ -1,7 +1,6 @@
import logging import logging
import time import time
from django.conf import settings
from django.db.backends.utils import CursorWrapper as _CursorWrapper from django.db.backends.utils import CursorWrapper as _CursorWrapper
from netbox.exceptions import DatabaseWriteDenied from netbox.exceptions import DatabaseWriteDenied
@ -10,11 +9,8 @@ logger = logging.getLogger('netbox.db')
class ReadOnlyCursorWrapper: class ReadOnlyCursorWrapper:
""" """
A read-only wrapper around a database cursor.
This wrapper prevents write operations from being performed on the database. It is used to prevent changes to the This wrapper prevents write operations from being performed on the database. It is used to prevent changes to the
database during a read-only request. It is not intended to be used directly; rather, it is applied automatically by database during a read-only request.
the ReadOnlyMiddleware. See the documentation for that class for more information.
""" """
SQL_BLACKLIST = ( SQL_BLACKLIST = (
@ -35,20 +31,19 @@ class ReadOnlyCursorWrapper:
def __init__(self, cursor, db, *args, **kwargs): def __init__(self, cursor, db, *args, **kwargs):
self.cursor = cursor self.cursor = cursor
self.db = db self.db = db
self.read_only = settings.MAINTENANCE_MODE
def __check_sql(self, sql):
if self._write_sql(sql):
raise DatabaseWriteDenied
def execute(self, sql, params=()): def execute(self, sql, params=()):
# Check the SQL # Check the SQL
if self.read_only and self._write_sql(sql): self.__check_sql(sql)
raise DatabaseWriteDenied
return self.cursor.execute(sql, params) return self.cursor.execute(sql, params)
def executemany(self, sql, param_list): def executemany(self, sql, param_list):
# Check the SQL # Check the SQL
if self.read_only and self._write_sql(sql): self.__check_sql(sql)
raise DatabaseWriteDenied
return self.cursor.executemany(sql, param_list) return self.cursor.executemany(sql, param_list)
def __getattr__(self, item): def __getattr__(self, item):

View File

@ -229,20 +229,13 @@ class DatabaseReadOnlyMiddleware(MiddlewareMixin):
if not isinstance(exception, DatabaseWriteDenied): if not isinstance(exception, DatabaseWriteDenied):
return None return None
not_allowed_methods = ['POST', 'PUT', 'PATCH', 'DELETE']
error_message = 'The database is currently in read-only mode. Please try again later.' error_message = 'The database is currently in read-only mode. Please try again later.'
status_code = 503 status_code = 503
# If the request is an API request, return a 503 Service Unavailable response if is_api_request(request):
if is_api_request(request) and request.method in not_allowed_methods: return JsonResponse({'detail': error_message}, status=status_code)
return JsonResponse({'detail': error_message, }, status=status_code)
else: else:
# Handle exceptions # Display a message to the user
if request.method in not_allowed_methods: messages.error(request, error_message)
# Display a message to the user # Redirect back to the referring page
messages.error(request, error_message) return HttpResponseReload(request)
# Redirect back to the referring page
return HttpResponseReload(request)
else:
return HttpResponse(error_message, status=status_code)