
* Move expensive imports closer to usage Signed-off-by: William Easton <bill.easton@elastic.co> * DCO Remediation Commit for William Easton <bill.easton@elastic.co> I, William Easton <bill.easton@elastic.co>, hereby add my Signed-off-by to this commit: 8a7412ce5bb131a01bb6403067aeb948c9093b0b Signed-off-by: William Easton <bill.easton@elastic.co> * formatting fixes Signed-off-by: William Easton <bill.easton@elastic.co> * DCO Remediation Commit for William Easton <bill.easton@elastic.co> I, William Easton <bill.easton@elastic.co>, hereby add my Signed-off-by to this commit: 8a7412ce5bb131a01bb6403067aeb948c9093b0b I, William Easton <bill.easton@elastic.co>, hereby add my Signed-off-by to this commit: 963e34325071db5e844841f10c27b396a054a0a1 Signed-off-by: William Easton <bill.easton@elastic.co> * Fix baseocrmodel test issue Signed-off-by: William Easton <bill.easton@elastic.co> --------- Signed-off-by: William Easton <bill.easton@elastic.co>
84 lines
2.8 KiB
Python
84 lines
2.8 KiB
Python
import logging
|
|
from typing import List, Optional
|
|
|
|
from docling.datamodel.accelerator_options import AcceleratorDevice
|
|
|
|
_log = logging.getLogger(__name__)
|
|
|
|
|
|
def decide_device(
|
|
accelerator_device: str, supported_devices: Optional[List[AcceleratorDevice]] = None
|
|
) -> str:
|
|
r"""
|
|
Resolve the device based on the acceleration options and the available devices in the system.
|
|
|
|
Rules:
|
|
1. AUTO: Check for the best available device on the system.
|
|
2. User-defined: Check if the device actually exists, otherwise fall-back to CPU
|
|
"""
|
|
import torch
|
|
|
|
device = "cpu"
|
|
|
|
has_cuda = torch.backends.cuda.is_built() and torch.cuda.is_available()
|
|
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
|
|
|
|
if supported_devices is not None:
|
|
if has_cuda and AcceleratorDevice.CUDA not in supported_devices:
|
|
_log.info(
|
|
f"Removing CUDA from available devices because it is not in {supported_devices=}"
|
|
)
|
|
has_cuda = False
|
|
if has_mps and AcceleratorDevice.MPS not in supported_devices:
|
|
_log.info(
|
|
f"Removing MPS from available devices because it is not in {supported_devices=}"
|
|
)
|
|
has_mps = False
|
|
|
|
if accelerator_device == AcceleratorDevice.AUTO.value: # Handle 'auto'
|
|
if has_cuda:
|
|
device = "cuda:0"
|
|
elif has_mps:
|
|
device = "mps"
|
|
|
|
elif accelerator_device.startswith("cuda"):
|
|
if has_cuda:
|
|
# if cuda device index specified extract device id
|
|
parts = accelerator_device.split(":")
|
|
if len(parts) == 2 and parts[1].isdigit():
|
|
# select cuda device's id
|
|
cuda_index = int(parts[1])
|
|
if cuda_index < torch.cuda.device_count():
|
|
device = f"cuda:{cuda_index}"
|
|
else:
|
|
_log.warning(
|
|
"CUDA device 'cuda:%d' is not available. Fall back to 'CPU'.",
|
|
cuda_index,
|
|
)
|
|
elif len(parts) == 1: # just "cuda"
|
|
device = "cuda:0"
|
|
else:
|
|
_log.warning(
|
|
"Invalid CUDA device format '%s'. Fall back to 'CPU'",
|
|
accelerator_device,
|
|
)
|
|
else:
|
|
_log.warning("CUDA is not available in the system. Fall back to 'CPU'")
|
|
|
|
elif accelerator_device == AcceleratorDevice.MPS.value:
|
|
if has_mps:
|
|
device = "mps"
|
|
else:
|
|
_log.warning("MPS is not available in the system. Fall back to 'CPU'")
|
|
|
|
elif accelerator_device == AcceleratorDevice.CPU.value:
|
|
device = "cpu"
|
|
|
|
else:
|
|
_log.warning(
|
|
"Unknown device option '%s'. Fall back to 'CPU'", accelerator_device
|
|
)
|
|
|
|
_log.info("Accelerator device: '%s'", device)
|
|
return device
|