import logging

import torch

from docling.datamodel.pipeline_options import AcceleratorDevice

_log = logging.getLogger(__name__)


def decide_device(accelerator_device: AcceleratorDevice) -> str:
    r"""
    Pick the torch device string to run on for the requested accelerator.

    Rules:
        1. AUTO: prefer CUDA (first device, ``cuda:0``), then MPS, else CPU.
        2. CUDA / MPS: use the requested backend only when torch was built with
           it AND it is available at runtime; otherwise log a warning and use
           CPU.
        3. Any other enum member resolves to CPU.

    Args:
        accelerator_device: The accelerator requested via the pipeline options.

    Returns:
        One of ``"cuda:0"``, ``"mps"`` or ``"cpu"``.
    """
    gpu_ordinal = 0  # only the first CUDA device is ever selected

    # Probe both backends up front; a backend counts as usable only if it is
    # compiled into this torch build and reports itself available.
    cuda_usable = torch.backends.cuda.is_built() and torch.cuda.is_available()
    mps_usable = torch.backends.mps.is_built() and torch.backends.mps.is_available()

    chosen = "cpu"
    if accelerator_device == AcceleratorDevice.AUTO:
        # Preference order: CUDA, then Apple MPS, then stay on CPU.
        if cuda_usable:
            chosen = f"cuda:{gpu_ordinal}"
        elif mps_usable:
            chosen = "mps"
    elif accelerator_device == AcceleratorDevice.CUDA:
        if not cuda_usable:
            _log.warning("CUDA is not available in the system. Fall back to 'CPU'")
        else:
            chosen = f"cuda:{gpu_ordinal}"
    elif accelerator_device == AcceleratorDevice.MPS:
        if not mps_usable:
            _log.warning("MPS is not available in the system. Fall back to 'CPU'")
        else:
            chosen = "mps"

    _log.info("Accelerator device: '%s'", chosen)
    return chosen