diff --git a/docling/datamodel/pipeline_options_vlm_model.py b/docling/datamodel/pipeline_options_vlm_model.py
index c1ec28a..2289c3c 100644
--- a/docling/datamodel/pipeline_options_vlm_model.py
+++ b/docling/datamodel/pipeline_options_vlm_model.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from pydantic import AnyUrl, BaseModel
 from typing_extensions import deprecated
@@ -42,6 +42,7 @@ class InlineVlmOptions(BaseVlmOptions):
     transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
     response_format: ResponseFormat
+    torch_dtype: Optional[str] = None
 
     supported_devices: List[AcceleratorDevice] = [
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index de7f289..00fdfa5 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -99,6 +99,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
             self.vlm_model = model_cls.from_pretrained(
                 artifacts_path,
                 device_map=self.device,
+                torch_dtype=self.vlm_options.torch_dtype,
                 _attn_implementation=(
                     "flash_attention_2"
                     if self.device.startswith("cuda")