fix: allow custom torch_dtype in vlm models (#1735)
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
parent 49b10e7419
commit f7f31137f1
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Optional, Union

 from pydantic import AnyUrl, BaseModel
 from typing_extensions import deprecated
@@ -42,6 +42,7 @@ class InlineVlmOptions(BaseVlmOptions):
     transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
     response_format: ResponseFormat

+    torch_dtype: Optional[str] = None
     supported_devices: List[AcceleratorDevice] = [
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
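With this field in place, a pipeline configuration can request a specific precision for the model weights. Below is a minimal sketch, assuming the current import path and required fields of InlineVlmOptions; the repo_id and prompt values are illustrative placeholders, not part of this commit.

from docling.datamodel.pipeline_options_vlm_model import (
    InferenceFramework,
    InlineVlmOptions,
    ResponseFormat,
)

# Placeholder repo_id and prompt; torch_dtype is the new optional field.
# Leaving it as None keeps the previous default behavior.
vlm_options = InlineVlmOptions(
    repo_id="ibm-granite/granite-vision-3.1-2b-preview",
    prompt="Convert this page to docling.",
    response_format=ResponseFormat.DOCTAGS,
    inference_framework=InferenceFramework.TRANSFORMERS,
    torch_dtype="bfloat16",  # forwarded to from_pretrained (see hunk below)
)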
@@ -99,6 +99,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
             self.vlm_model = model_cls.from_pretrained(
                 artifacts_path,
                 device_map=self.device,
+                torch_dtype=self.vlm_options.torch_dtype,
                 _attn_implementation=(
                     "flash_attention_2"
                     if self.device.startswith("cuda")
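For reference, the value is forwarded to Transformers unchanged; from_pretrained accepts either a torch.dtype or its string name, so the sketch below (with a placeholder checkpoint, not taken from this commit) shows roughly what the model loading amounts to when torch_dtype is set.

from transformers import AutoModelForVision2Seq

# Placeholder checkpoint; "bfloat16" is passed through as-is, just like
# self.vlm_options.torch_dtype in the hunk above. None falls back to the
# checkpoint's default dtype.
model = AutoModelForVision2Seq.from_pretrained(
    "ibm-granite/granite-vision-3.1-2b-preview",
    device_map="cuda",
    torch_dtype="bfloat16",
)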