import json from contextlib import contextmanager from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType from django.contrib.postgres.fields import ArrayField, RangeField from django.core.exceptions import FieldDoesNotExist from django.db import transaction from django.db.models import ManyToManyField, ManyToManyRel, JSONField from django.forms.models import model_to_dict from django.test import Client, TestCase as _TestCase from netaddr import IPNetwork from taggit.managers import TaggableManager from core.models import ObjectType from users.models import ObjectPermission, User from utilities.data import ranges_to_string from utilities.object_types import object_type_identifier from utilities.permissions import resolve_permission_type from .utils import DUMMY_CF_DATA, extract_form_failures __all__ = ( 'ModelTestCase', 'TestCase', ) class TestCase(_TestCase): user_permissions = () def setUp(self): # Create the test user and assign permissions self.user = User.objects.create_user(username='testuser') self.add_permissions(*self.user_permissions) # Initialize the test client self.client = Client() self.client.force_login(self.user) @contextmanager def cleanupSubTest(self, **params): """ Context manager that wraps subTest with automatic cleanup. All database changes within the context will be rolled back. """ sid = transaction.savepoint() try: with self.subTest(**params): yield finally: transaction.savepoint_rollback(sid) # # Permissions management # def add_permissions(self, *names): """ Assign a set of permissions to the test user. Accepts permission names in the form ._. """ for name in names: object_type, action = resolve_permission_type(name) obj_perm = ObjectPermission(name=name, actions=[action]) obj_perm.save() obj_perm.users.add(self.user) obj_perm.object_types.add(object_type) def remove_permissions(self, *names): """ Remove a set of permissions from the test user. Accepts permission names in the form ._. """ for name in names: object_type, action = resolve_permission_type(name) ObjectPermission.objects.filter( actions__contains=[action], object_types=object_type, users=self.user ).delete() # # Custom assertions # def assertHttpStatus(self, response, expected_status): """ TestCase method. Provide more detail in the event of an unexpected HTTP response. """ err_message = None # Construct an error message only if we know the test is going to fail if response.status_code != expected_status: if hasattr(response, 'data'): # REST API response; pass the response data through directly err = response.data else: # Attempt to extract form validation errors from the response HTML form_errors = extract_form_failures(response.content) err = form_errors or response.content or 'No data' err_message = f"Expected HTTP status {expected_status}; received {response.status_code}: {err}" self.assertEqual(response.status_code, expected_status, err_message) class ModelTestCase(TestCase): """ Parent class for TestCases which deal with models. """ model = None def _get_queryset(self): """ Return a base queryset suitable for use in test methods. """ return self.model.objects.all() def prepare_instance(self, instance): """ Test cases can override this method to perform any necessary manipulation of an instance prior to its evaluation against test data. For example, it can be used to decrypt a Secret's plaintext attribute. """ return instance def model_to_dict(self, instance, fields, api=False): """ Return a dictionary representation of an instance. """ # Prepare the instance and call Django's model_to_dict() to extract all fields model_dict = model_to_dict(self.prepare_instance(instance), fields=fields) # Map any additional (non-field) instance attributes that were specified for attr in fields: if hasattr(instance, attr) and attr not in model_dict: model_dict[attr] = getattr(instance, attr) for key, value in list(model_dict.items()): try: field = instance._meta.get_field(key) except FieldDoesNotExist: # Attribute is not a model field continue # Handle ManyToManyFields if value and type(field) in (ManyToManyField, ManyToManyRel, TaggableManager): # Resolve reverse M2M relationships if isinstance(field, ManyToManyRel): value = getattr(instance, field.related_name).all() if field.related_model in (ContentType, ObjectType) and api: model_dict[key] = sorted([object_type_identifier(ot) for ot in value]) else: model_dict[key] = sorted([obj.pk for obj in value]) # Handle GenericForeignKeys elif value and type(field) is GenericForeignKey: model_dict[key] = value.pk elif api: # Replace ContentType numeric IDs with . if type(getattr(instance, key)) in (ContentType, ObjectType): object_type = ObjectType.objects.get(pk=value) model_dict[key] = object_type_identifier(object_type) # Convert IPNetwork instances to strings elif type(value) is IPNetwork: model_dict[key] = str(value) else: field = instance._meta.get_field(key) # Convert ArrayFields to CSV strings if type(field) is ArrayField: if getattr(field.base_field, 'choices', None): # Values for fields with pre-defined choices can be returned as lists model_dict[key] = value elif type(field.base_field) is ArrayField: # Handle nested arrays (e.g. choice sets) model_dict[key] = '\n'.join([f'{k},{v}' for k, v in value]) elif issubclass(type(field.base_field), RangeField): # Handle arrays of numeric ranges (e.g. VLANGroup VLAN ID ranges) model_dict[key] = ranges_to_string(value) else: model_dict[key] = ','.join([str(v) for v in value]) # JSON if type(field) is JSONField and value is not None: model_dict[key] = json.dumps(value) return model_dict # # Custom assertions # def assertInstanceEqual(self, instance, data, exclude=None, api=False): """ Compare a model instance to a dictionary, checking that its attribute values match those specified in the dictionary. :param instance: Python object instance :param data: Dictionary of test data used to define the instance :param exclude: List of fields to exclude from comparison (e.g. passwords, which get hashed) :param api: Set to True is the data is a JSON representation of the instance """ if exclude is None: exclude = [] fields = [k for k in data.keys() if k not in exclude] model_dict = self.model_to_dict(instance, fields=fields, api=api) # Omit any dictionary keys which are not instance attributes or have been excluded model_data = { k: v for k, v in data.items() if hasattr(instance, k) and k not in exclude } self.assertDictEqual(model_dict, model_data) # Validate any custom field data, if present if getattr(instance, 'custom_field_data', None): self.assertDictEqual(instance.custom_field_data, DUMMY_CF_DATA)