fix: vlm using artifacts path (#1057)
* fix usage of artifacts path

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add granite vision to the download utils

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

---------

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
c84b973959
commit
e197225739
@ -32,9 +32,19 @@ class _AvailableModels(str, Enum):
|
||||
CODE_FORMULA = "code_formula"
|
||||
PICTURE_CLASSIFIER = "picture_classifier"
|
||||
SMOLVLM = "smolvlm"
|
||||
GRANITE_VISION = "granite_vision"
|
||||
EASYOCR = "easyocr"
|
||||
|
||||
|
||||
_default_models = [
|
||||
_AvailableModels.LAYOUT,
|
||||
_AvailableModels.TABLEFORMER,
|
||||
_AvailableModels.CODE_FORMULA,
|
||||
_AvailableModels.PICTURE_CLASSIFIER,
|
||||
_AvailableModels.EASYOCR,
|
||||
]
|
||||
|
||||
|
||||
@app.command("download")
|
||||
def download(
|
||||
output_dir: Annotated[
|
||||
@ -73,7 +83,7 @@ def download(
|
||||
datefmt="[%X]",
|
||||
handlers=[RichHandler(show_level=False, show_time=False, markup=True)],
|
||||
)
|
||||
to_download = models or [m for m in _AvailableModels]
|
||||
to_download = models or _default_models
|
||||
output_dir = download_models(
|
||||
output_dir=output_dir,
|
||||
force=force,
|
||||
@ -83,6 +93,7 @@ def download(
|
||||
with_code_formula=_AvailableModels.CODE_FORMULA in to_download,
|
||||
with_picture_classifier=_AvailableModels.PICTURE_CLASSIFIER in to_download,
|
||||
with_smolvlm=_AvailableModels.SMOLVLM in to_download,
|
||||
with_granite_vision=_AvailableModels.GRANITE_VISION in to_download,
|
||||
with_easyocr=_AvailableModels.EASYOCR in to_download,
|
||||
)
|
||||
|
||||
|
@ -41,9 +41,9 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
|
||||
)
|
||||
|
||||
# Initialize processor and model
|
||||
self.processor = AutoProcessor.from_pretrained(self.options.repo_id)
|
||||
self.processor = AutoProcessor.from_pretrained(artifacts_path)
|
||||
self.model = AutoModelForVision2Seq.from_pretrained(
|
||||
self.options.repo_id,
|
||||
artifacts_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
_attn_implementation=(
|
||||
"flash_attention_2" if self.device.startswith("cuda") else "eager"
|
||||
|
@ -2,7 +2,10 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from docling.datamodel.pipeline_options import smolvlm_picture_description
|
||||
from docling.datamodel.pipeline_options import (
|
||||
granite_picture_description,
|
||||
smolvlm_picture_description,
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.code_formula_model import CodeFormulaModel
|
||||
from docling.models.document_picture_classifier import DocumentPictureClassifier
|
||||
@ -23,7 +26,8 @@ def download_models(
|
||||
with_tableformer: bool = True,
|
||||
with_code_formula: bool = True,
|
||||
with_picture_classifier: bool = True,
|
||||
with_smolvlm: bool = True,
|
||||
with_smolvlm: bool = False,
|
||||
with_granite_vision: bool = False,
|
||||
with_easyocr: bool = True,
|
||||
):
|
||||
if output_dir is None:
|
||||
@ -73,6 +77,15 @@ def download_models(
|
||||
progress=progress,
|
||||
)
|
||||
|
||||
if with_granite_vision:
|
||||
_log.info(f"Downloading Granite Vision model...")
|
||||
PictureDescriptionVlmModel.download_models(
|
||||
repo_id=granite_picture_description.repo_id,
|
||||
local_dir=output_dir / granite_picture_description.repo_cache_folder,
|
||||
force=force,
|
||||
progress=progress,
|
||||
)
|
||||
|
||||
if with_easyocr:
|
||||
_log.info(f"Downloading easyocr models...")
|
||||
EasyOcrModel.download_models(
|
||||
|
Loading…
Reference in New Issue
Block a user