fix: allow custom torch_dtype in vlm models (#1735)

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Michele Dolfi, 2025-06-10 03:52:15 -05:00 (committed by GitHub)
parent 49b10e7419
commit f7f31137f1
2 changed files with 3 additions and 1 deletion


@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from pydantic import AnyUrl, BaseModel
 from typing_extensions import deprecated
@@ -42,6 +42,7 @@ class InlineVlmOptions(BaseVlmOptions):
     transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
     response_format: ResponseFormat
+    torch_dtype: Optional[str] = None
     supported_devices: List[AcceleratorDevice] = [
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
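
The new torch_dtype field defaults to None, so existing option presets keep their current behavior; when set, the string is forwarded verbatim to the model loader. Below is a minimal sketch of how a caller might set it, assuming the docling options API: the import path, repo_id, prompt, and the other field values are illustrative assumptions, not part of this diff.

# Hedged sketch: configuring an inline VLM with an explicit dtype.
# repo_id and prompt are placeholders; only torch_dtype is the field added in this change.
from docling.datamodel.pipeline_options_vlm_model import (
    InferenceFramework,
    InlineVlmOptions,
    ResponseFormat,
)

vlm_options = InlineVlmOptions(
    repo_id="org/some-vlm-model",  # hypothetical model repository
    prompt="Convert this page to markdown.",  # hypothetical prompt
    inference_framework=InferenceFramework.TRANSFORMERS,
    response_format=ResponseFormat.MARKDOWN,
    torch_dtype="bfloat16",  # new optional field introduced by this change
)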


@@ -99,6 +99,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
         self.vlm_model = model_cls.from_pretrained(
             artifacts_path,
             device_map=self.device,
+            torch_dtype=self.vlm_options.torch_dtype,
             _attn_implementation=(
                 "flash_attention_2"
                 if self.device.startswith("cuda")
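
Downstream, transformers' from_pretrained accepts torch_dtype either as a torch.dtype or as a string such as "float16", "bfloat16", or "auto", and passing None keeps the default dtype handling, which is why forwarding self.vlm_options.torch_dtype directly is safe for configurations that never set the new field. A standalone sketch of the downstream call, with a placeholder model id:

# Illustration of the forwarded value only; not code from this repository.
from transformers import AutoModelForVision2Seq

model = AutoModelForVision2Seq.from_pretrained(
    "org/some-vlm-model",  # hypothetical model repository
    torch_dtype="bfloat16",  # string dtype names and "auto" are accepted; None keeps the default
)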