feat: Support cuda:n GPU device allocation (#694)
* Adding multi-gpu support, and cuda device allocation Signed-off-by: ahn <ahn@zurich.ibm.com> * Fixes pydantic exception with cuda:n Signed-off-by: ahn <ahn@zurich.ibm.com> * Pydantic field validator and comment restored. Signed-off-by: ahn <ahn@zurich.ibm.com> * chore: Accept AcceleratorDevice enum type Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Resetted some options to default, removed EasyOCR model wrap. Signed-off-by: ahn <ahn@zurich.ibm.com> * Fixed rebased issues Signed-off-by: ahn <ahn@zurich.ibm.com> * Revert accelerator test options Signed-off-by: ahn <ahn@zurich.ibm.com> --------- Signed-off-by: ahn <ahn@zurich.ibm.com> Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Co-authored-by: ahn <ahn@sonny.zuvela.ibm.com> Co-authored-by: ahn <ahn@zurich.ibm.com> Co-authored-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
@@ -1,11 +1,26 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import AnyUrl, BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic import (
|
||||
AnyUrl,
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_validator,
|
||||
model_validator,
|
||||
validator,
|
||||
)
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
)
|
||||
from typing_extensions import deprecated
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,7 +40,18 @@ class AcceleratorOptions(BaseSettings):
|
||||
)
|
||||
|
||||
num_threads: int = 4
|
||||
device: AcceleratorDevice = AcceleratorDevice.AUTO
|
||||
device: Union[str, AcceleratorDevice] = "auto"
|
||||
|
||||
@field_validator("device")
|
||||
def validate_device(cls, value):
|
||||
# "auto", "cpu", "cuda", "mps", or "cuda:N"
|
||||
if value in {d.value for d in AcceleratorDevice} or re.match(
|
||||
r"^cuda(:\d+)?$", value
|
||||
):
|
||||
return value
|
||||
raise ValueError(
|
||||
"Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'."
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -41,7 +67,6 @@ class AcceleratorOptions(BaseSettings):
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
input_num_threads = data.get("num_threads")
|
||||
|
||||
# Check if to set the num_threads from the alternative envvar
|
||||
if input_num_threads is None:
|
||||
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
|
||||
|
||||
Reference in New Issue
Block a user