feat: Support cuda:n GPU device allocation (#694)
* Adding multi-GPU support and CUDA device allocation
  Signed-off-by: ahn <ahn@zurich.ibm.com>
* Fixes pydantic exception with cuda:n
  Signed-off-by: ahn <ahn@zurich.ibm.com>
* Pydantic field validator and comment restored.
  Signed-off-by: ahn <ahn@zurich.ibm.com>
* chore: Accept AcceleratorDevice enum type
  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
* Reset some options to default, removed EasyOCR model wrap.
  Signed-off-by: ahn <ahn@zurich.ibm.com>
* Fixed rebase issues
  Signed-off-by: ahn <ahn@zurich.ibm.com>
* Revert accelerator test options
  Signed-off-by: ahn <ahn@zurich.ibm.com>
---------
Signed-off-by: ahn <ahn@zurich.ibm.com>
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Co-authored-by: ahn <ahn@sonny.zuvela.ibm.com>
Co-authored-by: ahn <ahn@zurich.ibm.com>
Co-authored-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
428b656793
commit
77eb77bdc2
@ -1,11 +1,26 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import warnings
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import AnyUrl, BaseModel, ConfigDict, Field, model_validator
|
from pydantic import (
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
AnyUrl,
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
validator,
|
||||||
|
)
|
||||||
|
from pydantic_settings import (
|
||||||
|
BaseSettings,
|
||||||
|
PydanticBaseSettingsSource,
|
||||||
|
SettingsConfigDict,
|
||||||
|
)
|
||||||
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -25,7 +40,18 @@ class AcceleratorOptions(BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
num_threads: int = 4
|
num_threads: int = 4
|
||||||
device: AcceleratorDevice = AcceleratorDevice.AUTO
|
device: Union[str, AcceleratorDevice] = "auto"
|
||||||
|
|
||||||
|
@field_validator("device")
|
||||||
|
def validate_device(cls, value):
|
||||||
|
# "auto", "cpu", "cuda", "mps", or "cuda:N"
|
||||||
|
if value in {d.value for d in AcceleratorDevice} or re.match(
|
||||||
|
r"^cuda(:\d+)?$", value
|
||||||
|
):
|
||||||
|
return value
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'."
|
||||||
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -41,7 +67,6 @@ class AcceleratorOptions(BaseSettings):
|
|||||||
"""
|
"""
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
input_num_threads = data.get("num_threads")
|
input_num_threads = data.get("num_threads")
|
||||||
|
|
||||||
# Check if to set the num_threads from the alternative envvar
|
# Check if to set the num_threads from the alternative envvar
|
||||||
if input_num_threads is None:
|
if input_num_threads is None:
|
||||||
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
|
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
|
||||||
|
@ -7,36 +7,62 @@ from docling.datamodel.pipeline_options import AcceleratorDevice
|
|||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def decide_device(accelerator_device: AcceleratorDevice) -> str:
|
def decide_device(accelerator_device: str) -> str:
|
||||||
r"""
|
r"""
|
||||||
Resolve the device based on the acceleration options and the available devices in the system
|
Resolve the device based on the acceleration options and the available devices in the system.
|
||||||
|
|
||||||
Rules:
|
Rules:
|
||||||
1. AUTO: Check for the best available device on the system.
|
1. AUTO: Check for the best available device on the system.
|
||||||
2. User-defined: Check if the device actually exists, otherwise fall-back to CPU
|
2. User-defined: Check if the device actually exists, otherwise fall-back to CPU
|
||||||
"""
|
"""
|
||||||
cuda_index = 0
|
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|
||||||
has_cuda = torch.backends.cuda.is_built() and torch.cuda.is_available()
|
has_cuda = torch.backends.cuda.is_built() and torch.cuda.is_available()
|
||||||
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
|
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
|
||||||
|
|
||||||
if accelerator_device == AcceleratorDevice.AUTO:
|
if accelerator_device == AcceleratorDevice.AUTO.value: # Handle 'auto'
|
||||||
if has_cuda:
|
if has_cuda:
|
||||||
device = f"cuda:{cuda_index}"
|
device = "cuda:0"
|
||||||
elif has_mps:
|
elif has_mps:
|
||||||
device = "mps"
|
device = "mps"
|
||||||
|
|
||||||
|
elif accelerator_device.startswith("cuda"):
|
||||||
|
if has_cuda:
|
||||||
|
# if cuda device index specified extract device id
|
||||||
|
parts = accelerator_device.split(":")
|
||||||
|
if len(parts) == 2 and parts[1].isdigit():
|
||||||
|
# select cuda device's id
|
||||||
|
cuda_index = int(parts[1])
|
||||||
|
if cuda_index < torch.cuda.device_count():
|
||||||
|
device = f"cuda:{cuda_index}"
|
||||||
|
else:
|
||||||
|
_log.warning(
|
||||||
|
"CUDA device 'cuda:%d' is not available. Fall back to 'CPU'.",
|
||||||
|
cuda_index,
|
||||||
|
)
|
||||||
|
elif len(parts) == 1: # just "cuda"
|
||||||
|
device = "cuda:0"
|
||||||
|
else:
|
||||||
|
_log.warning(
|
||||||
|
"Invalid CUDA device format '%s'. Fall back to 'CPU'",
|
||||||
|
accelerator_device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_log.warning("CUDA is not available in the system. Fall back to 'CPU'")
|
||||||
|
|
||||||
|
elif accelerator_device == AcceleratorDevice.MPS.value:
|
||||||
|
if has_mps:
|
||||||
|
device = "mps"
|
||||||
|
else:
|
||||||
|
_log.warning("MPS is not available in the system. Fall back to 'CPU'")
|
||||||
|
|
||||||
|
elif accelerator_device == AcceleratorDevice.CPU.value:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if accelerator_device == AcceleratorDevice.CUDA:
|
_log.warning(
|
||||||
if has_cuda:
|
"Unknown device option '%s'. Fall back to 'CPU'", accelerator_device
|
||||||
device = f"cuda:{cuda_index}"
|
)
|
||||||
else:
|
|
||||||
_log.warning("CUDA is not available in the system. Fall back to 'CPU'")
|
|
||||||
elif accelerator_device == AcceleratorDevice.MPS:
|
|
||||||
if has_mps:
|
|
||||||
device = "mps"
|
|
||||||
else:
|
|
||||||
_log.warning("MPS is not available in the system. Fall back to 'CPU'")
|
|
||||||
|
|
||||||
_log.info("Accelerator device: '%s'", device)
|
_log.info("Accelerator device: '%s'", device)
|
||||||
return device
|
return device
|
||||||
|
@ -30,6 +30,9 @@ def main():
|
|||||||
# num_threads=8, device=AcceleratorDevice.CUDA
|
# num_threads=8, device=AcceleratorDevice.CUDA
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
# EasyOCR doesn't support cuda:N allocation; it defaults to cuda:0
|
||||||
|
# accelerator_options = AcceleratorOptions(num_threads=8, device="cuda:1")
|
||||||
|
|
||||||
pipeline_options = PdfPipelineOptions()
|
pipeline_options = PdfPipelineOptions()
|
||||||
pipeline_options.accelerator_options = accelerator_options
|
pipeline_options.accelerator_options = accelerator_options
|
||||||
pipeline_options.do_ocr = True
|
pipeline_options.do_ocr = True
|
||||||
|
Loading…
Reference in New Issue
Block a user