
* feat: adding new vlm-models support Signed-off-by: Peter Staar <taa@zurich.ibm.com> * fixed the transformers Signed-off-by: Peter Staar <taa@zurich.ibm.com> * got microsoft/Phi-4-multimodal-instruct to work Signed-off-by: Peter Staar <taa@zurich.ibm.com> * working on vlm's Signed-off-by: Peter Staar <taa@zurich.ibm.com> * refactoring the VLM part Signed-off-by: Peter Staar <taa@zurich.ibm.com> * all working, now serious refactoring necessary Signed-off-by: Peter Staar <taa@zurich.ibm.com> * refactoring the download_model Signed-off-by: Peter Staar <taa@zurich.ibm.com> * added the formulate_prompt Signed-off-by: Peter Staar <taa@zurich.ibm.com> * pixtral 12b runs via MLX and native transformers Signed-off-by: Peter Staar <taa@zurich.ibm.com> * added the VlmPredictionToken Signed-off-by: Peter Staar <taa@zurich.ibm.com> * refactoring minimal_vlm_pipeline Signed-off-by: Peter Staar <taa@zurich.ibm.com> * fixed the MyPy Signed-off-by: Peter Staar <taa@zurich.ibm.com> * added pipeline_model_specializations file Signed-off-by: Peter Staar <taa@zurich.ibm.com> * need to get Phi4 working again ... 
Signed-off-by: Peter Staar <taa@zurich.ibm.com> * finalising last points for vlms support Signed-off-by: Peter Staar <taa@zurich.ibm.com> * fixed the pipeline for Phi4 Signed-off-by: Peter Staar <taa@zurich.ibm.com> * streamlining all code Signed-off-by: Peter Staar <taa@zurich.ibm.com> * reformatted the code Signed-off-by: Peter Staar <taa@zurich.ibm.com> * fixing the tests Signed-off-by: Peter Staar <taa@zurich.ibm.com> * added the html backend to the VLM pipeline Signed-off-by: Peter Staar <taa@zurich.ibm.com> * fixed the static load_from_doctags Signed-off-by: Peter Staar <taa@zurich.ibm.com> * restore stable imports Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * use AutoModelForVision2Seq for Pixtral and review example (including rename) Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * remove unused value Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * refactor instances of VLM models Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * skip compare example in CI Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * use lowercase and uppercase only Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * add new minimal_vlm example and refactor pipeline_options_vlm_model for cleaner import Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * rename pipeline_vlm_model_spec Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * move more argument to options and simplify model init Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * add supported_devices Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * remove not-needed function Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * exclude minimal_vlm Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * missing file Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * add message for transformers version Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * rename to specs Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * use module import and remove MLX from non-darwin Signed-off-by: Michele Dolfi 
<dol@zurich.ibm.com> * remove hf_vlm_model and add extra_generation_args Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * use single HF VLM model class Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * remove torch type Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * add docs for vision models Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> --------- Signed-off-by: Peter Staar <taa@zurich.ibm.com> Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> Co-authored-by: Michele Dolfi <dol@zurich.ibm.com>
84 lines
2.8 KiB
Python
84 lines
2.8 KiB
Python
import logging
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
|
|
from docling.datamodel.accelerator_options import AcceleratorDevice
|
|
|
|
_log = logging.getLogger(__name__)
|
|
|
|
|
|
def decide_device(
    accelerator_device: str, supported_devices: Optional[List[AcceleratorDevice]] = None
) -> str:
    r"""
    Resolve the device based on the acceleration options and the available devices in the system.

    Rules:
    1. AUTO: Check for the best available device on the system (CUDA first, then MPS).
    2. User-defined: Check if the device actually exists, otherwise fall-back to CPU

    Args:
        accelerator_device: Requested device as a string: the AUTO, CPU or MPS
            value of ``AcceleratorDevice``, ``"cuda"``, or ``"cuda:<index>"``.
        supported_devices: Optional allow-list. When given, CUDA and/or MPS are
            treated as unavailable unless they appear in the list.

    Returns:
        One of ``"cpu"``, ``"mps"`` or ``"cuda:<index>"``. Any request that
        cannot be satisfied is logged and resolved to ``"cpu"``.
    """
    has_cuda = torch.backends.cuda.is_built() and torch.cuda.is_available()
    has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()

    # Narrow the detected hardware down to what the caller allows.
    if supported_devices is not None:
        if has_cuda and AcceleratorDevice.CUDA not in supported_devices:
            _log.info(
                "Removing CUDA from available devices because it is not in supported_devices=%r",
                supported_devices,
            )
            has_cuda = False
        if has_mps and AcceleratorDevice.MPS not in supported_devices:
            _log.info(
                "Removing MPS from available devices because it is not in supported_devices=%r",
                supported_devices,
            )
            has_mps = False

    # Every branch that cannot honour the request leaves this default in place.
    device = "cpu"

    if accelerator_device == AcceleratorDevice.AUTO.value:  # Handle 'auto'
        if has_cuda:
            device = "cuda:0"
        elif has_mps:
            device = "mps"

    elif accelerator_device.startswith("cuda"):
        if not has_cuda:
            # Also reached when CUDA was excluded through supported_devices.
            _log.warning("CUDA is not available in the system. Fall back to 'CPU'")
        else:
            prefix, sep, index = accelerator_device.partition(":")
            if prefix == "cuda" and not sep:  # just "cuda"
                device = "cuda:0"
            elif prefix == "cuda" and index.isascii() and index.isdigit():
                # ASCII digits only: str.isdigit() alone accepts characters
                # such as superscripts that int() refuses to parse.
                cuda_index = int(index)
                if cuda_index < torch.cuda.device_count():
                    device = f"cuda:{cuda_index}"
                else:
                    _log.warning(
                        "CUDA device 'cuda:%d' is not available. Fall back to 'CPU'.",
                        cuda_index,
                    )
            else:
                # Wrong prefix ("cudafoo"), empty/non-numeric index, or extra ':' parts.
                _log.warning(
                    "Invalid CUDA device format '%s'. Fall back to 'CPU'",
                    accelerator_device,
                )

    elif accelerator_device == AcceleratorDevice.MPS.value:
        if has_mps:
            device = "mps"
        else:
            _log.warning("MPS is not available in the system. Fall back to 'CPU'")

    elif accelerator_device == AcceleratorDevice.CPU.value:
        device = "cpu"

    else:
        _log.warning(
            "Unknown device option '%s'. Fall back to 'CPU'", accelerator_device
        )

    _log.info("Accelerator device: '%s'", device)
    return device
|