feat: Support cuda:n GPU device allocation (#694)

* Adding multi-gpu support, and cuda device allocation

Signed-off-by: ahn <ahn@zurich.ibm.com>

* Fixes pydantic exception with cuda:n
Signed-off-by: ahn <ahn@zurich.ibm.com>

* Pydantic field validator and comment restored.

Signed-off-by: ahn <ahn@zurich.ibm.com>

* chore: Accept AcceleratorDevice enum type

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Resetted some options to default, removed EasyOCR model wrap.
Signed-off-by: ahn <ahn@zurich.ibm.com>

* Fixed rebased issues
Signed-off-by: ahn <ahn@zurich.ibm.com>

* Revert accelerator test options
Signed-off-by: ahn <ahn@zurich.ibm.com>

---------

Signed-off-by: ahn <ahn@zurich.ibm.com>
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Co-authored-by: ahn <ahn@sonny.zuvela.ibm.com>
Co-authored-by: ahn <ahn@zurich.ibm.com>
Co-authored-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Ahmed Nassar
2025-02-17 11:31:13 +01:00
committed by GitHub
parent 428b656793
commit 77eb77bdc2
3 changed files with 73 additions and 19 deletions

View File

@@ -1,11 +1,26 @@
import logging
import os
import re
import warnings
from enum import Enum
from pathlib import Path
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from pydantic import AnyUrl, BaseModel, ConfigDict, Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import (
AnyUrl,
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
validator,
)
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
)
from typing_extensions import deprecated
_log = logging.getLogger(__name__)
@@ -25,7 +40,18 @@ class AcceleratorOptions(BaseSettings):
)
num_threads: int = 4
device: AcceleratorDevice = AcceleratorDevice.AUTO
device: Union[str, AcceleratorDevice] = "auto"
@field_validator("device")
def validate_device(cls, value):
# "auto", "cpu", "cuda", "mps", or "cuda:N"
if value in {d.value for d in AcceleratorDevice} or re.match(
r"^cuda(:\d+)?$", value
):
return value
raise ValueError(
"Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'."
)
@model_validator(mode="before")
@classmethod
@@ -41,7 +67,6 @@ class AcceleratorOptions(BaseSettings):
"""
if isinstance(data, dict):
input_num_threads = data.get("num_threads")
# Check if to set the num_threads from the alternative envvar
if input_num_threads is None:
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")