fix: enable cuda_use_flash_attention2 for PictureDescriptionVlmModel (#1496)
fix: enable cuda_use_flash_attention2 for PictureDescriptionVlmModel

Signed-off-by: Zach Cox <zach.s.cox@gmail.com>
This commit is contained in:
parent 976e92e289
commit cc453961a9
@@ -57,7 +57,10 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
                 artifacts_path,
                 torch_dtype=torch.bfloat16,
                 _attn_implementation=(
-                    "flash_attention_2" if self.device.startswith("cuda") else "eager"
+                    "flash_attention_2"
+                    if self.device.startswith("cuda")
+                    and accelerator_options.cuda_use_flash_attention2
+                    else "eager"
                 ),
             ).to(self.device)

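The net effect of the change: flash attention 2 is now opt-in rather than implied by a CUDA device. Below is a minimal, self-contained sketch of the selection logic; AcceleratorOptionsStub and pick_attn_implementation are illustrative stand-ins, and only the cuda_use_flash_attention2 field and the device check come from the diff above.

from dataclasses import dataclass


@dataclass
class AcceleratorOptionsStub:
    # Stand-in for the accelerator options object referenced in the diff;
    # flash attention 2 stays disabled unless explicitly requested.
    cuda_use_flash_attention2: bool = False


def pick_attn_implementation(device: str, opts: AcceleratorOptionsStub) -> str:
    # Mirrors the conditional passed as _attn_implementation: flash attention 2
    # requires both a CUDA device and the explicit opt-in flag.
    if device.startswith("cuda") and opts.cuda_use_flash_attention2:
        return "flash_attention_2"
    return "eager"


# Before this fix, a CUDA device alone selected flash_attention_2; now the
# accelerator option must also be enabled.
assert pick_attn_implementation("cuda:0", AcceleratorOptionsStub()) == "eager"
assert pick_attn_implementation(
    "cuda:0", AcceleratorOptionsStub(cuda_use_flash_attention2=True)
) == "flash_attention_2"
assert pick_attn_implementation(
    "cpu", AcceleratorOptionsStub(cuda_use_flash_attention2=True)
) == "eager"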