diff --git a/docling/models/picture_description_vlm_model.py b/docling/models/picture_description_vlm_model.py index 374f575..679e80c 100644 --- a/docling/models/picture_description_vlm_model.py +++ b/docling/models/picture_description_vlm_model.py @@ -57,7 +57,10 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel): artifacts_path, torch_dtype=torch.bfloat16, _attn_implementation=( - "flash_attention_2" if self.device.startswith("cuda") else "eager" + "flash_attention_2" + if self.device.startswith("cuda") + and accelerator_options.cuda_use_flash_attention2 + else "eager" ), ).to(self.device)