fix: enable cuda_use_flash_attention2 for PictureDescriptionVlmModel (#1496)

fix: enable cuda_use_flash_attention2 for PictureDescriptionVlmModel

Signed-off-by: Zach Cox <zach.s.cox@gmail.com>

@@ -57,7 +57,10 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
                 artifacts_path,
                 torch_dtype=torch.bfloat16,
                 _attn_implementation=(
-                    "flash_attention_2" if self.device.startswith("cuda") else "eager"
+                    "flash_attention_2"
+                    if self.device.startswith("cuda")
+                    and accelerator_options.cuda_use_flash_attention2
+                    else "eager"
                 ),
             ).to(self.device)
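
With this change, Flash Attention 2 is only requested when the pipeline runs on CUDA and the accelerator options opt in; otherwise the model loads with the "eager" attention implementation. Below is a minimal sketch of opting in from the pipeline configuration, assuming docling's AcceleratorOptions, PdfPipelineOptions, and the smolvlm_picture_description preset keep the names shown here; the input file name is hypothetical.

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    AcceleratorDevice,
    AcceleratorOptions,
    PdfPipelineOptions,
    smolvlm_picture_description,
)
from docling.document_converter import DocumentConverter, PdfFormatOption

pipeline_options = PdfPipelineOptions()
pipeline_options.do_picture_description = True
pipeline_options.picture_description_options = smolvlm_picture_description
# Opt in to Flash Attention 2; without cuda_use_flash_attention2=True the
# picture description model now falls back to the "eager" implementation
# even when running on CUDA.
pipeline_options.accelerator_options = AcceleratorOptions(
    device=AcceleratorDevice.CUDA,
    cuda_use_flash_attention2=True,
)

converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
result = converter.convert("report.pdf")  # hypothetical input file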