fix: allow custom torch_dtype in vlm models (#1735)

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Michele Dolfi 2025-06-10 03:52:15 -05:00 committed by GitHub
parent 49b10e7419
commit f7f31137f1
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 1 deletion

File 1 of 2:

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Optional, Union
 from pydantic import AnyUrl, BaseModel
 from typing_extensions import deprecated
@@ -42,6 +42,7 @@ class InlineVlmOptions(BaseVlmOptions):
     transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
     response_format: ResponseFormat
+    torch_dtype: Optional[str] = None
     supported_devices: List[AcceleratorDevice] = [
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
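
With this change, callers can pin the dtype used when the VLM weights are loaded instead of always falling back to the library default. Below is a minimal usage sketch, not taken from this commit: the import path, the repo_id, the prompt, and the response format value are assumptions for illustration; only the fields shown in the hunk above come from this diff.

    # Hypothetical sketch: pin the load dtype for an inline VLM.
    # Import path, repo_id, prompt, and the DOCTAGS value are assumptions,
    # not part of this diff.
    from docling.datamodel.pipeline_options_vlm_model import (
        InlineVlmOptions,
        ResponseFormat,
    )

    vlm_options = InlineVlmOptions(
        repo_id="ds4sd/SmolDocling-256M-preview",  # example checkpoint
        prompt="Convert this page to docling.",    # example prompt
        response_format=ResponseFormat.DOCTAGS,
        torch_dtype="bfloat16",  # new field; None keeps the previous behaviour
    )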

File 2 of 2:

@@ -99,6 +99,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
         self.vlm_model = model_cls.from_pretrained(
             artifacts_path,
             device_map=self.device,
+            torch_dtype=self.vlm_options.torch_dtype,
             _attn_implementation=(
                 "flash_attention_2"
                 if self.device.startswith("cuda")
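
For reference, the configured value now flows straight into the transformers loading call. The rough standalone sketch below shows the equivalent effect, assuming a CUDA device and an example checkpoint name; it is not the verbatim docling code.

    # Sketch of what the forwarded option controls at load time.
    # transformers' from_pretrained() accepts a dtype string such as "bfloat16"
    # (or a torch.dtype object); passing None keeps the library default (float32).
    from transformers import AutoModelForVision2Seq

    model = AutoModelForVision2Seq.from_pretrained(
        "ds4sd/SmolDocling-256M-preview",  # example checkpoint, not from this diff
        device_map="cuda",
        torch_dtype="bfloat16",
    )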