feat: new vlm-models support (#1570)

* feat: adding new vlm-models support

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* fixed the transformers

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* got microsoft/Phi-4-multimodal-instruct to work

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* working on vlm's

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* refactoring the VLM part

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* all working, now serious refacgtoring necessary

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* refactoring the download_model

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* added the formulate_prompt

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* pixtral 12b runs via MLX and native transformers

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* added the VlmPredictionToken

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* refactoring minimal_vlm_pipeline

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* fixed the MyPy

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* added pipeline_model_specializations file

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* need to get Phi4 working again ...

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* finalising last points for vlms support

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* fixed the pipeline for Phi4

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* streamlining all code

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* reformatted the code

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* fixing the tests

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* added the html backend to the VLM pipeline

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* fixed the static load_from_doctags

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* restore stable imports

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* use AutoModelForVision2Seq for Pixtral and review example (including rename)

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* remove unused value

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* refactor instances of VLM models

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* skip compare example in CI

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* use lowercase and uppercase only

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add new minimal_vlm example and refactor pipeline_options_vlm_model for cleaner import

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* rename pipeline_vlm_model_spec

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* move more argument to options and simplify model init

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add supported_devices

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* remove not-needed function

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* exclude minimal_vlm

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* missing file

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add message for transformers version

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* rename to specs

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* use module import and remove MLX from non-darwin

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* remove hf_vlm_model and add extra_generation_args

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* use single HF VLM model class

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* remove torch type

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add docs for vision models

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

---------

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Co-authored-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Peter W. J. Staar
2025-06-02 17:01:06 +02:00
committed by GitHub
parent 08dcacc5cb
commit cfdf4cea25
46 changed files with 1968 additions and 1902 deletions

View File

@@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import ApiVlmOptions
from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
from docling.exceptions import OperationNotAllowed
from docling.models.base_model import BasePageModel
from docling.utils.api_image_request import api_image_request

View File

@@ -11,9 +11,10 @@ from PIL import Image, ImageDraw
from rtree import index
from scipy.ndimage import binary_dilation, find_objects, label
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorOptions, OcrOptions
from docling.datamodel.pipeline_options import OcrOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BaseModelWithOptions, BasePageModel

View File

@@ -16,9 +16,10 @@ from docling_core.types.doc.labels import CodeLanguageLabel
from PIL import Image, ImageOps
from pydantic import BaseModel
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.models.base_model import BaseItemAndImageEnrichmentModel
from docling.models.utils.hf_model_download import download_hf_model
from docling.utils.accelerator_utils import decide_device
@@ -117,20 +118,14 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
return download_hf_model(
repo_id="ds4sd/CodeFormula",
force_download=force,
local_dir=local_dir,
revision="v1.0.2",
local_dir=local_dir,
force=force,
progress=progress,
)
return Path(download_path)
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
"""
Determines if a given element in a document can be processed by the model.

View File

@@ -13,8 +13,9 @@ from docling_core.types.doc import (
from PIL import Image
from pydantic import BaseModel
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.models.base_model import BaseEnrichmentModel
from docling.models.utils.hf_model_download import download_hf_model
from docling.utils.accelerator_utils import decide_device
@@ -105,20 +106,14 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
def download_models(
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
return download_hf_model(
repo_id="ds4sd/DocumentFigureClassifier",
force_download=force,
local_dir=local_dir,
revision="v1.0.1",
local_dir=local_dir,
force=force,
progress=progress,
)
return Path(download_path)
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
"""
Determines if the given element can be processed by the classifier.

View File

@@ -9,11 +9,10 @@ import numpy
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, TextCell
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
EasyOcrOptions,
OcrOptions,
)

View File

@@ -1,182 +0,0 @@
import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Optional
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
HuggingFaceVlmOptions,
)
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceVlmModel(BasePageModel):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
vlm_options: HuggingFaceVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
if self.enabled:
import torch
from transformers import ( # type: ignore
AutoModelForVision2Seq,
AutoProcessor,
BitsAndBytesConfig,
)
device = decide_device(accelerator_options.device)
self.device = device
_log.debug(f"Available device for HuggingFace VLM: {device}")
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
# PARAMETERS:
if artifacts_path is None:
artifacts_path = self.download_models(self.vlm_options.repo_id)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
self.param_question = vlm_options.prompt # "Perform Layout Analysis."
self.param_quantization_config = BitsAndBytesConfig(
load_in_8bit=vlm_options.load_in_8bit, # True,
llm_int8_threshold=vlm_options.llm_int8_threshold, # 6.0
)
self.param_quantized = vlm_options.quantized # False
self.processor = AutoProcessor.from_pretrained(artifacts_path)
if not self.param_quantized:
self.vlm_model = AutoModelForVision2Seq.from_pretrained(
artifacts_path,
device_map=device,
torch_dtype=torch.bfloat16,
_attn_implementation=(
"flash_attention_2"
if self.device.startswith("cuda")
and accelerator_options.cuda_use_flash_attention2
else "eager"
),
) # .to(self.device)
else:
self.vlm_model = AutoModelForVision2Seq.from_pretrained(
artifacts_path,
device_map=device,
torch_dtype="auto",
quantization_config=self.param_quantization_config,
_attn_implementation=(
"flash_attention_2"
if self.device.startswith("cuda")
and accelerator_options.cuda_use_flash_attention2
else "eager"
),
) # .to(self.device)
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
# revision="v0.0.1",
)
return Path(download_path)
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
with TimeRecorder(conv_res, "vlm"):
assert page.size is not None
hi_res_image = page.get_image(scale=2.0) # 144dpi
# hi_res_image = page.get_image(scale=1.0) # 72dpi
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
# populate page_tags with predicted doc tags
page_tags = ""
if hi_res_image:
if hi_res_image.mode != "RGB":
hi_res_image = hi_res_image.convert("RGB")
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "This is a page from a document.",
},
{"type": "image"},
{"type": "text", "text": self.param_question},
],
}
]
prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=False
)
inputs = self.processor(
text=prompt, images=[hi_res_image], return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
start_time = time.time()
# Call model to generate:
generated_ids = self.vlm_model.generate(
**inputs, max_new_tokens=4096, use_cache=True
)
generation_time = time.time() - start_time
generated_texts = self.processor.batch_decode(
generated_ids[:, inputs["input_ids"].shape[1] :],
skip_special_tokens=False,
)[0]
num_tokens = len(generated_ids[0])
page_tags = generated_texts
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
# inference_time = time.time() - start_time
# tokens_per_second = num_tokens / generation_time
# print("")
# print(f"Page Inference Time: {inference_time:.2f} seconds")
# print(f"Total tokens on page: {num_tokens:.2f}")
# print(f"Tokens/sec: {tokens_per_second:.2f}")
# print("")
page.predictions.vlm_response = VlmPrediction(text=page_tags)
yield page

View File

@@ -10,11 +10,12 @@ from docling_core.types.doc import DocItemLabel
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
from PIL import Image
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import BoundingBox, Cluster, LayoutPrediction, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.models.utils.hf_model_download import download_hf_model
from docling.utils.accelerator_utils import decide_device
from docling.utils.layout_postprocessor import LayoutPostprocessor
from docling.utils.profiling import TimeRecorder
@@ -83,20 +84,14 @@ class LayoutModel(BasePageModel):
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
return download_hf_model(
repo_id="ds4sd/docling-models",
force_download=force,
revision="v2.2.0",
local_dir=local_dir,
revision="v2.1.0",
force=force,
progress=progress,
)
return Path(download_path)
def draw_clusters_and_cells_side_by_side(
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
):

View File

@@ -8,10 +8,10 @@ from typing import Optional, Type
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, TextCell
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
OcrMacOptions,
OcrOptions,
)

View File

@@ -5,8 +5,8 @@ from typing import Optional, Type, Union
from PIL import Image
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
PictureDescriptionApiOptions,
PictureDescriptionBaseOptions,
)

View File

@@ -13,8 +13,8 @@ from docling_core.types.doc.document import ( # TODO: move import to docling_co
)
from PIL import Image
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
PictureDescriptionBaseOptions,
)
from docling.models.base_model import (

View File

@@ -4,16 +4,21 @@ from typing import Optional, Type, Union
from PIL import Image
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
PictureDescriptionBaseOptions,
PictureDescriptionVlmOptions,
)
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
from docling.utils.accelerator_utils import decide_device
class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
class PictureDescriptionVlmModel(
PictureDescriptionBaseModel, HuggingFaceModelDownloadMixin
):
@classmethod
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
return PictureDescriptionVlmOptions
@@ -66,26 +71,6 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
self.provenance = f"{self.options.repo_id}"
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
)
return Path(download_path)
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
from transformers import GenerationConfig

View File

@@ -7,11 +7,10 @@ import numpy
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, TextCell
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
OcrOptions,
RapidOcrOptions,
)

View File

@@ -13,16 +13,16 @@ from docling_core.types.doc.page import (
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
from PIL import ImageDraw
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import Page, Table, TableStructurePrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
TableFormerMode,
TableStructureOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.models.utils.hf_model_download import download_hf_model
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
@@ -90,20 +90,14 @@ class TableStructureModel(BasePageModel):
def download_models(
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
return download_hf_model(
repo_id="ds4sd/docling-models",
force_download=force,
local_dir=local_dir,
revision="v2.2.0",
local_dir=local_dir,
force=force,
progress=progress,
)
return Path(download_path)
def draw_table_and_cells(
self,
conv_res: ConversionResult,

View File

@@ -13,10 +13,10 @@ import pandas as pd
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import TextCell
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
OcrOptions,
TesseractCliOcrOptions,
)

View File

@@ -7,10 +7,10 @@ from typing import Iterable, Optional, Type
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import TextCell
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
OcrOptions,
TesseractOcrOptions,
)

View File

View File

@@ -0,0 +1,40 @@
import logging
from pathlib import Path
from typing import Optional
_log = logging.getLogger(__name__)
def download_hf_model(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
revision: Optional[str] = None,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
revision=revision,
)
return Path(download_path)
class HuggingFaceModelDownloadMixin:
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
return download_hf_model(
repo_id=repo_id, local_dir=local_dir, force=force, progress=progress
)

View File

@@ -0,0 +1,194 @@
import importlib.metadata
import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Optional
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
)
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import (
InlineVlmOptions,
TransformersModelType,
)
from docling.models.base_model import BasePageModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
vlm_options: InlineVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
if self.enabled:
import torch
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoProcessor,
BitsAndBytesConfig,
GenerationConfig,
)
transformers_version = importlib.metadata.version("transformers")
if (
self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct"
and transformers_version >= "4.52.0"
):
raise NotImplementedError(
f"Phi 4 only works with transformers<4.52.0 but you have {transformers_version=}. Please downgrage running pip install -U 'transformers<4.52.0'."
)
self.device = decide_device(
accelerator_options.device,
supported_devices=vlm_options.supported_devices,
)
_log.debug(f"Available device for VLM: {self.device}")
self.use_cache = vlm_options.use_kv_cache
self.max_new_tokens = vlm_options.max_new_tokens
self.temperature = vlm_options.temperature
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
if artifacts_path is None:
artifacts_path = self.download_models(self.vlm_options.repo_id)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
self.param_quantization_config: Optional[BitsAndBytesConfig] = None
if vlm_options.quantized:
self.param_quantization_config = BitsAndBytesConfig(
load_in_8bit=vlm_options.load_in_8bit,
llm_int8_threshold=vlm_options.llm_int8_threshold,
)
model_cls: Any = AutoModel
if (
self.vlm_options.transformers_model_type
== TransformersModelType.AUTOMODEL_CAUSALLM
):
model_cls = AutoModelForCausalLM
elif (
self.vlm_options.transformers_model_type
== TransformersModelType.AUTOMODEL_VISION2SEQ
):
model_cls = AutoModelForVision2Seq
self.processor = AutoProcessor.from_pretrained(
artifacts_path,
trust_remote_code=vlm_options.trust_remote_code,
)
self.vlm_model = model_cls.from_pretrained(
artifacts_path,
device_map=self.device,
_attn_implementation=(
"flash_attention_2"
if self.device.startswith("cuda")
and accelerator_options.cuda_use_flash_attention2
else "eager"
),
trust_remote_code=vlm_options.trust_remote_code,
)
# Load generation config
self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
with TimeRecorder(conv_res, "vlm"):
assert page.size is not None
hi_res_image = page.get_image(scale=self.vlm_options.scale)
# Define prompt structure
prompt = self.formulate_prompt()
inputs = self.processor(
text=prompt, images=[hi_res_image], return_tensors="pt"
).to(self.device)
start_time = time.time()
# Call model to generate:
generated_ids = self.vlm_model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
use_cache=self.use_cache,
temperature=self.temperature,
generation_config=self.generation_config,
**self.vlm_options.extra_generation_config,
)
generation_time = time.time() - start_time
generated_texts = self.processor.batch_decode(
generated_ids[:, inputs["input_ids"].shape[1] :],
skip_special_tokens=False,
)[0]
num_tokens = len(generated_ids[0])
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
page.predictions.vlm_response = VlmPrediction(
text=generated_texts,
generation_time=generation_time,
)
yield page
def formulate_prompt(self) -> str:
"""Formulate a prompt for the VLM."""
if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
_log.debug("Using specialized prompt for Phi-4")
# more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
user_prompt = "<|user|>"
assistant_prompt = "<|assistant|>"
prompt_suffix = "<|end|>"
prompt = f"{user_prompt}<|image_1|>{self.vlm_options.prompt}{prompt_suffix}{assistant_prompt}"
_log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
return prompt
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "This is a page from a document.",
},
{"type": "image"},
{"type": "text", "text": self.vlm_options.prompt},
],
}
]
prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=False
)
return prompt

View File

@@ -4,29 +4,34 @@ from collections.abc import Iterable
from pathlib import Path
from typing import Optional
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
HuggingFaceVlmOptions,
)
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.base_model import BasePageModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceMlxModel(BasePageModel):
class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
vlm_options: HuggingFaceVlmOptions,
vlm_options: InlineVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
self.max_tokens = vlm_options.max_new_tokens
self.temperature = vlm_options.temperature
if self.enabled:
try:
@@ -39,42 +44,24 @@ class HuggingFaceMlxModel(BasePageModel):
)
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
self.apply_chat_template = apply_chat_template
self.stream_generate = stream_generate
# PARAMETERS:
if artifacts_path is None:
artifacts_path = self.download_models(self.vlm_options.repo_id)
artifacts_path = self.download_models(
self.vlm_options.repo_id,
)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
self.param_question = vlm_options.prompt # "Perform Layout Analysis."
self.param_question = vlm_options.prompt
## Load the model
self.vlm_model, self.processor = load(artifacts_path)
self.config = load_config(artifacts_path)
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
# revision="v0.0.1",
)
return Path(download_path)
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
@@ -83,12 +70,10 @@ class HuggingFaceMlxModel(BasePageModel):
if not page._backend.is_valid():
yield page
else:
with TimeRecorder(conv_res, "vlm"):
with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
assert page.size is not None
hi_res_image = page.get_image(scale=2.0) # 144dpi
# hi_res_image = page.get_image(scale=1.0) # 72dpi
hi_res_image = page.get_image(scale=self.vlm_options.scale)
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
@@ -104,16 +89,45 @@ class HuggingFaceMlxModel(BasePageModel):
)
start_time = time.time()
_log.debug("start generating ...")
# Call model to generate:
tokens: list[VlmPredictionToken] = []
output = ""
for token in self.stream_generate(
self.vlm_model,
self.processor,
prompt,
[hi_res_image],
max_tokens=4096,
max_tokens=self.max_tokens,
verbose=False,
temp=self.temperature,
):
if len(token.logprobs.shape) == 1:
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[token.token],
)
)
elif (
len(token.logprobs.shape) == 2
and token.logprobs.shape[0] == 1
):
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[0, token.token],
)
)
else:
_log.warning(
f"incompatible shape for logprobs: {token.logprobs.shape}"
)
output += token.text
if "</doctag>" in token.text:
break
@@ -121,15 +135,13 @@ class HuggingFaceMlxModel(BasePageModel):
generation_time = time.time() - start_time
page_tags = output
_log.debug(f"Generation time {generation_time:.2f} seconds.")
# inference_time = time.time() - start_time
# tokens_per_second = num_tokens / generation_time
# print("")
# print(f"Page Inference Time: {inference_time:.2f} seconds")
# print(f"Total tokens on page: {num_tokens:.2f}")
# print(f"Tokens/sec: {tokens_per_second:.2f}")
# print("")
page.predictions.vlm_response = VlmPrediction(text=page_tags)
_log.debug(
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
)
page.predictions.vlm_response = VlmPrediction(
text=page_tags,
generation_time=generation_time,
generated_tokens=tokens,
)
yield page