diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index 16fb145..d317e7d 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -1,11 +1,26 @@ import logging import os +import re +import warnings from enum import Enum from pathlib import Path from typing import Annotated, Any, Dict, List, Literal, Optional, Union -from pydantic import AnyUrl, BaseModel, ConfigDict, Field, model_validator -from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic import ( + AnyUrl, + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, + validator, +) +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, +) +from typing_extensions import deprecated _log = logging.getLogger(__name__) @@ -25,7 +40,18 @@ class AcceleratorOptions(BaseSettings): ) num_threads: int = 4 - device: AcceleratorDevice = AcceleratorDevice.AUTO + device: Union[str, AcceleratorDevice] = "auto" + + @field_validator("device") + def validate_device(cls, value): + # "auto", "cpu", "cuda", "mps", or "cuda:N" + if value in {d.value for d in AcceleratorDevice} or re.match( + r"^cuda(:\d+)?$", value + ): + return value + raise ValueError( + "Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'." + ) @model_validator(mode="before") @classmethod @@ -41,7 +67,6 @@ class AcceleratorOptions(BaseSettings): """ if isinstance(data, dict): input_num_threads = data.get("num_threads") - # Check if to set the num_threads from the alternative envvar if input_num_threads is None: docling_num_threads = os.getenv("DOCLING_NUM_THREADS") diff --git a/docling/utils/accelerator_utils.py b/docling/utils/accelerator_utils.py index 59b0479..8c93025 100644 --- a/docling/utils/accelerator_utils.py +++ b/docling/utils/accelerator_utils.py @@ -7,36 +7,62 @@ from docling.datamodel.pipeline_options import AcceleratorDevice _log = logging.getLogger(__name__) -def decide_device(accelerator_device: AcceleratorDevice) -> str: +def decide_device(accelerator_device: str) -> str: r""" - Resolve the device based on the acceleration options and the available devices in the system + Resolve the device based on the acceleration options and the available devices in the system. + Rules: 1. AUTO: Check for the best available device on the system. 2. User-defined: Check if the device actually exists, otherwise fall-back to CPU """ - cuda_index = 0 device = "cpu" has_cuda = torch.backends.cuda.is_built() and torch.cuda.is_available() has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available() - if accelerator_device == AcceleratorDevice.AUTO: + if accelerator_device == AcceleratorDevice.AUTO.value: # Handle 'auto' if has_cuda: - device = f"cuda:{cuda_index}" + device = "cuda:0" elif has_mps: device = "mps" + elif accelerator_device.startswith("cuda"): + if has_cuda: + # if cuda device index specified extract device id + parts = accelerator_device.split(":") + if len(parts) == 2 and parts[1].isdigit(): + # select cuda device's id + cuda_index = int(parts[1]) + if cuda_index < torch.cuda.device_count(): + device = f"cuda:{cuda_index}" + else: + _log.warning( + "CUDA device 'cuda:%d' is not available. Fall back to 'CPU'.", + cuda_index, + ) + elif len(parts) == 1: # just "cuda" + device = "cuda:0" + else: + _log.warning( + "Invalid CUDA device format '%s'. Fall back to 'CPU'", + accelerator_device, + ) + else: + _log.warning("CUDA is not available in the system. Fall back to 'CPU'") + + elif accelerator_device == AcceleratorDevice.MPS.value: + if has_mps: + device = "mps" + else: + _log.warning("MPS is not available in the system. Fall back to 'CPU'") + + elif accelerator_device == AcceleratorDevice.CPU.value: + device = "cpu" + else: - if accelerator_device == AcceleratorDevice.CUDA: - if has_cuda: - device = f"cuda:{cuda_index}" - else: - _log.warning("CUDA is not available in the system. Fall back to 'CPU'") - elif accelerator_device == AcceleratorDevice.MPS: - if has_mps: - device = "mps" - else: - _log.warning("MPS is not available in the system. Fall back to 'CPU'") + _log.warning( + "Unknown device option '%s'. Fall back to 'CPU'", accelerator_device + ) _log.info("Accelerator device: '%s'", device) return device diff --git a/docs/examples/run_with_accelerator.py b/docs/examples/run_with_accelerator.py index e53ab2a..6e81e85 100644 --- a/docs/examples/run_with_accelerator.py +++ b/docs/examples/run_with_accelerator.py @@ -30,6 +30,9 @@ def main(): # num_threads=8, device=AcceleratorDevice.CUDA # ) + # easyocr doesnt support cuda:N allocation, defaults to cuda:0 + # accelerator_options = AcceleratorOptions(num_threads=8, device="cuda:1") + pipeline_options = PdfPipelineOptions() pipeline_options.accelerator_options = accelerator_options pipeline_options.do_ocr = True