fix: enable cuda_use_flash_attention2 for PictureDescriptionVlmModel (#1496)
fix: enable cuda_use_flash_attention2 for PictureDescriptionVlmModel

Signed-off-by: Zach Cox <zach.s.cox@gmail.com>
parent 976e92e289
commit cc453961a9
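For context, a hedged usage sketch of how this flag would be switched on from the converter side. The cuda_use_flash_attention2 field is the one touched by this commit; the surrounding class and module names (PdfPipelineOptions, AcceleratorOptions, AcceleratorDevice, DocumentConverter, PdfFormatOption) are assumed from docling's public API at the time of this change, not shown in this diff:

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    AcceleratorDevice,
    AcceleratorOptions,
    PdfPipelineOptions,
)
from docling.document_converter import DocumentConverter, PdfFormatOption

# Opt in to Flash Attention 2; the model code only honors this flag when
# the resolved device is CUDA.
pipeline_options = PdfPipelineOptions()
pipeline_options.do_picture_description = True
pipeline_options.accelerator_options = AcceleratorOptions(
    device=AcceleratorDevice.CUDA,
    cuda_use_flash_attention2=True,
)

converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
result = converter.convert("sample.pdf")  # placeholder input path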
@@ -57,7 +57,10 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
                 artifacts_path,
                 torch_dtype=torch.bfloat16,
                 _attn_implementation=(
-                    "flash_attention_2" if self.device.startswith("cuda") else "eager"
+                    "flash_attention_2"
+                    if self.device.startswith("cuda")
+                    and accelerator_options.cuda_use_flash_attention2
+                    else "eager"
                 ),
             ).to(self.device)
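After this change, the attention implementation is resolved from both the target device and the accelerator flag. A minimal, self-contained sketch of that selection logic around a generic transformers vision-language model load; the stub options class, the function name, and the example paths are illustrative assumptions, not docling's actual wiring:

from dataclasses import dataclass

import torch
from transformers import AutoModelForVision2Seq


@dataclass
class _AcceleratorOptionsStub:
    # Stand-in for docling's AcceleratorOptions; only the field used here.
    cuda_use_flash_attention2: bool = False


def load_vlm(artifacts_path: str, device: str, accelerator_options: _AcceleratorOptionsStub):
    # Flash Attention 2 is requested only when the target device is CUDA
    # *and* the user has opted in via cuda_use_flash_attention2; otherwise
    # the model falls back to the eager attention implementation.
    return AutoModelForVision2Seq.from_pretrained(
        artifacts_path,
        torch_dtype=torch.bfloat16,
        _attn_implementation=(
            "flash_attention_2"
            if device.startswith("cuda")
            and accelerator_options.cuda_use_flash_attention2
            else "eager"
        ),
    ).to(device)


# Example call (placeholder model path):
# model = load_vlm("path/to/vlm-artifacts", "cuda:0",
#                  _AcceleratorOptionsStub(cuda_use_flash_attention2=True))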