fix: allow custom torch_dtype in vlm models (#1735)
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
@@ -99,6 +99,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
         self.vlm_model = model_cls.from_pretrained(
             artifacts_path,
             device_map=self.device,
+            torch_dtype=self.vlm_options.torch_dtype,
             _attn_implementation=(
                 "flash_attention_2"
                 if self.device.startswith("cuda")
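For context on what the one-line change enables: the user-supplied torch_dtype is forwarded verbatim into Hugging Face's from_pretrained(), instead of the load dtype being fixed by the model. A minimal sketch of the resulting pattern, assuming a stand-in options object (VlmOptions below is illustrative, not docling's actual options class, and the model id is only a placeholder):

# Sketch of the pattern the diff wires up: a caller-chosen torch_dtype
# flows straight into transformers' from_pretrained().
from dataclasses import dataclass
from typing import Optional

import torch
from transformers import AutoModelForVision2Seq


@dataclass
class VlmOptions:  # hypothetical stand-in for docling's vlm_options
    repo_id: str = "HuggingFaceTB/SmolVLM-256M-Instruct"  # placeholder model id
    torch_dtype: Optional[torch.dtype] = torch.bfloat16  # the newly exposed knob


opts = VlmOptions()
model = AutoModelForVision2Seq.from_pretrained(
    opts.repo_id,
    device_map="cpu",
    torch_dtype=opts.torch_dtype,  # previously not configurable; now passed through
)

from_pretrained() also accepts the dtype as a string such as "bfloat16"; loading in bfloat16 rather than float32 halves the weight memory on hardware that supports it.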