diff --git a/netbox/dcim/models/devices.py b/netbox/dcim/models/devices.py index f0c116b0e..abd29d4e3 100644 --- a/netbox/dcim/models/devices.py +++ b/netbox/dcim/models/devices.py @@ -8,7 +8,7 @@ from django.core.exceptions import ValidationError from django.core.files.storage import default_storage from django.core.validators import MaxValueValidator, MinValueValidator from django.db import models -from django.db.models import F, ProtectedError +from django.db.models import F, ProtectedError, prefetch_related_objects from django.db.models.functions import Lower from django.db.models.signals import post_save from django.urls import reverse @@ -28,6 +28,7 @@ from netbox.models import NestedGroupModel, OrganizationalModel, PrimaryModel from netbox.models.mixins import WeightMixin from netbox.models.features import ContactsMixin, ImageAttachmentsMixin from utilities.fields import ColorField, CounterCacheField +from utilities.prefetch import get_prefetchable_fields from utilities.tracking import TrackingModelMixin from .device_components import * from .mixins import RenderConfigMixin @@ -924,7 +925,10 @@ class Device( if cf_defaults := CustomField.objects.get_defaults_for_model(model): for component in components: component.custom_field_data = cf_defaults - model.objects.bulk_create(components) + components = model.objects.bulk_create(components) + # Prefetch related objects to minimize queries needed during post_save + prefetch_fields = get_prefetchable_fields(model) + prefetch_related_objects(components, *prefetch_fields) # Manually send the post_save signal for each of the newly created components for component in components: post_save.send( diff --git a/netbox/utilities/prefetch.py b/netbox/utilities/prefetch.py new file mode 100644 index 000000000..c73a3fd4f --- /dev/null +++ b/netbox/utilities/prefetch.py @@ -0,0 +1,34 @@ +from django.contrib.contenttypes.fields import GenericRelation +from django.db.models import ManyToManyField +from django.db.models.fields.related import ForeignObjectRel +from taggit.managers import TaggableManager + +__all__ = ( + 'get_prefetchable_fields', +) + + +def get_prefetchable_fields(model): + """ + Return a list containing the names of all fields on the given model which support prefetching. + """ + field_names = [] + + for field in model._meta.get_fields(): + # Forward relations (e.g. ManyToManyFields) + if isinstance(field, ManyToManyField): + field_names.append(field.name) + + # Reverse relations (e.g. reverse ForeignKeys, reverse M2M) + elif isinstance(field, ForeignObjectRel): + field_names.append(field.get_accessor_name()) + + # Generic relations + elif isinstance(field, GenericRelation): + field_names.append(field.name) + + # Tags + elif isinstance(field, TaggableManager): + field_names.append(field.name) + + return field_names diff --git a/netbox/utilities/tests/test_prefetch.py b/netbox/utilities/tests/test_prefetch.py new file mode 100644 index 000000000..9da35c12e --- /dev/null +++ b/netbox/utilities/tests/test_prefetch.py @@ -0,0 +1,17 @@ +from circuits.models import Circuit, Provider +from utilities.prefetch import get_prefetchable_fields +from utilities.testing.base import TestCase + + +class GetPrefetchableFieldsTest(TestCase): + """ + Verify the operation of get_prefetchable_fields() + """ + def test_get_prefetchable_fields(self): + field_names = get_prefetchable_fields(Provider) + self.assertIn('asns', field_names) # ManyToManyField + self.assertIn('circuits', field_names) # Reverse relation + self.assertIn('tags', field_names) # Tags + + field_names = get_prefetchable_fields(Circuit) + self.assertIn('group_assignments', field_names) # Generic relation