
* feat: adding new vlm-models support Signed-off-by: Peter Staar <taa@zurich.ibm.com> * fixed the transformers Signed-off-by: Peter Staar <taa@zurich.ibm.com> * got microsoft/Phi-4-multimodal-instruct to work Signed-off-by: Peter Staar <taa@zurich.ibm.com> * working on VLMs Signed-off-by: Peter Staar <taa@zurich.ibm.com> * refactoring the VLM part Signed-off-by: Peter Staar <taa@zurich.ibm.com> * all working, now serious refactoring necessary Signed-off-by: Peter Staar <taa@zurich.ibm.com> * refactoring the download_model Signed-off-by: Peter Staar <taa@zurich.ibm.com> * added the formulate_prompt Signed-off-by: Peter Staar <taa@zurich.ibm.com> * pixtral 12b runs via MLX and native transformers Signed-off-by: Peter Staar <taa@zurich.ibm.com> * added the VlmPredictionToken Signed-off-by: Peter Staar <taa@zurich.ibm.com> * refactoring minimal_vlm_pipeline Signed-off-by: Peter Staar <taa@zurich.ibm.com> * fixed the MyPy Signed-off-by: Peter Staar <taa@zurich.ibm.com> * added pipeline_model_specializations file Signed-off-by: Peter Staar <taa@zurich.ibm.com> * need to get Phi4 working again ... 
Signed-off-by: Peter Staar <taa@zurich.ibm.com> * finalising last points for vlms support Signed-off-by: Peter Staar <taa@zurich.ibm.com> * fixed the pipeline for Phi4 Signed-off-by: Peter Staar <taa@zurich.ibm.com> * streamlining all code Signed-off-by: Peter Staar <taa@zurich.ibm.com> * reformatted the code Signed-off-by: Peter Staar <taa@zurich.ibm.com> * fixing the tests Signed-off-by: Peter Staar <taa@zurich.ibm.com> * added the html backend to the VLM pipeline Signed-off-by: Peter Staar <taa@zurich.ibm.com> * fixed the static load_from_doctags Signed-off-by: Peter Staar <taa@zurich.ibm.com> * restore stable imports Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * use AutoModelForVision2Seq for Pixtral and review example (including rename) Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * remove unused value Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * refactor instances of VLM models Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * skip compare example in CI Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * use lowercase and uppercase only Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * add new minimal_vlm example and refactor pipeline_options_vlm_model for cleaner import Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * rename pipeline_vlm_model_spec Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * move more argument to options and simplify model init Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * add supported_devices Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * remove not-needed function Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * exclude minimal_vlm Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * missing file Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * add message for transformers version Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * rename to specs Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * use module import and remove MLX from non-darwin Signed-off-by: Michele Dolfi 
<dol@zurich.ibm.com> * remove hf_vlm_model and add extra_generation_args Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * use single HF VLM model class Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * remove torch type Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * add docs for vision models Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> --------- Signed-off-by: Peter Staar <taa@zurich.ibm.com> Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> Co-authored-by: Michele Dolfi <dol@zurich.ibm.com>
327 lines
11 KiB
Python
327 lines
11 KiB
Python
import re
|
|
from collections import Counter
|
|
from collections.abc import Iterable
|
|
from pathlib import Path
|
|
from typing import List, Literal, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
from docling_core.types.doc import (
|
|
CodeItem,
|
|
DocItemLabel,
|
|
DoclingDocument,
|
|
NodeItem,
|
|
TextItem,
|
|
)
|
|
from docling_core.types.doc.labels import CodeLanguageLabel
|
|
from PIL import Image, ImageOps
|
|
from pydantic import BaseModel
|
|
|
|
from docling.datamodel.accelerator_options import AcceleratorOptions
|
|
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
|
|
from docling.models.base_model import BaseItemAndImageEnrichmentModel
|
|
from docling.models.utils.hf_model_download import download_hf_model
|
|
from docling.utils.accelerator_utils import decide_device
|
|
|
|
|
|
class CodeFormulaModelOptions(BaseModel):
    """Settings selecting which enrichments the CodeFormulaModel performs.

    Attributes
    ----------
    kind : Literal["code_formula"]
        Identifier of this options type; always the string ``"code_formula"``.
    do_code_enrichment : bool
        When True (the default), code items are sent to the model for
        enrichment.
    do_formula_enrichment : bool
        When True (the default), formula items are sent to the model for
        enrichment.
    """

    # Fixed identifier for this options type; only "code_formula" is accepted.
    kind: Literal["code_formula"] = "code_formula"
    # Toggle for enriching code items.
    do_code_enrichment: bool = True
    # Toggle for enriching formula items.
    do_formula_enrichment: bool = True
|
|
|
|
|
|
class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
    """
    Model for processing and enriching documents with code and formula predictions.

    Attributes
    ----------
    enabled : bool
        True if the model is enabled, False otherwise.
    options : CodeFormulaModelOptions
        Configuration options for the CodeFormulaModel.
    code_formula_model : CodeFormulaPredictor
        The predictor model for code and formula processing. Only created when
        ``enabled`` is True.

    Methods
    -------
    __init__(self, enabled, artifacts_path, options, accelerator_options)
        Initializes the CodeFormulaModel with the given configuration options.
    download_models(local_dir=None, force=False, progress=False)
        Downloads the CodeFormula weights from the Hugging Face Hub.
    is_processable(self, doc, element)
        Determines if a given element in a document can be processed by the model.
    __call__(self, doc, element_batch)
        Processes the given batch of elements and enriches them with predictions.
    """

    _model_repo_folder = "ds4sd--CodeFormula"
    elements_batch_size = 5
    images_scale = 1.66  # = 120 dpi, aligned with training data resolution
    expansion_factor = 0.18

    # Padding (left, top, right, bottom) in pixels added around every element
    # crop before it is passed to the predictor.
    _image_padding = (20, 10, 20, 10)

    # Matches a leading "<_language_>" tag. The tag is either followed by exactly
    # one whitespace character and the code, or it is the whole string. Only a
    # single whitespace character is consumed, so indentation on the first code
    # line is preserved.
    _code_language_pattern = re.compile(r"<_([^_>]+)_>(?:\s(.*))?\Z", re.DOTALL)

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        options: CodeFormulaModelOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """
        Initializes the CodeFormulaModel with the given configuration.

        Parameters
        ----------
        enabled : bool
            True if the model is enabled, False otherwise.
        artifacts_path : Optional[Path]
            Directory containing the model artifacts. The weights are expected
            in the ``ds4sd--CodeFormula`` sub-folder. When None, the weights
            are downloaded via ``download_models()``.
        options : CodeFormulaModelOptions
            Configuration options for the model.
        accelerator_options : AcceleratorOptions
            Options specifying the device and number of threads for acceleration.
        """
        self.enabled = enabled
        self.options = options

        if self.enabled:
            device = decide_device(accelerator_options.device)

            # Imported lazily so the heavy dependency is only loaded when the
            # model is actually used.
            from docling_ibm_models.code_formula_model.code_formula_predictor import (
                CodeFormulaPredictor,
            )

            if artifacts_path is None:
                artifacts_path = self.download_models()
            else:
                artifacts_path = artifacts_path / self._model_repo_folder

            self.code_formula_model = CodeFormulaPredictor(
                artifacts_path=str(artifacts_path),
                device=device,
                num_threads=accelerator_options.num_threads,
            )

    @staticmethod
    def download_models(
        local_dir: Optional[Path] = None,
        force: bool = False,
        progress: bool = False,
    ) -> Path:
        """
        Download the pinned ``ds4sd/CodeFormula`` revision from the Hugging Face Hub.

        Parameters
        ----------
        local_dir : Optional[Path]
            Target directory, forwarded to ``download_hf_model``.
        force : bool
            Forwarded to ``download_hf_model``.
        progress : bool
            Forwarded to ``download_hf_model``.

        Returns
        -------
        Path
            The path returned by ``download_hf_model``.
        """
        return download_hf_model(
            repo_id="ds4sd/CodeFormula",
            revision="v1.0.2",
            local_dir=local_dir,
            force=force,
            progress=progress,
        )

    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        """
        Determines if a given element in a document can be processed by the model.

        Parameters
        ----------
        doc : DoclingDocument
            The document being processed.
        element : NodeItem
            The element within the document to check.

        Returns
        -------
        bool
            True when the model is enabled and the element is either a
            ``CodeItem`` (with code enrichment on) or a ``TextItem`` labelled
            FORMULA (with formula enrichment on).
        """
        return self.enabled and (
            (isinstance(element, CodeItem) and self.options.do_code_enrichment)
            or (
                isinstance(element, TextItem)
                and element.label == DocItemLabel.FORMULA
                and self.options.do_formula_enrichment
            )
        )

    def _extract_code_language(self, input_string: str) -> Tuple[str, Optional[str]]:
        """Extracts a programming language from the beginning of a string.

        This function checks if the input string starts with a pattern of the form
        ``<_some_language_>``. The tag must be followed by a single whitespace
        character or be the whole string. If it matches, the language is
        extracted and ``(remainder, language)`` is returned. Otherwise the
        original string and `None` are returned.

        Args:
            input_string (str): The input string, which may start with ``<_language_>``.

        Returns:
            Tuple[str, Optional[str]]:
                A tuple where:
                - The first element is either:
                    - The remainder of the string (everything after
                      ``<_language_>`` and the single separating whitespace
                      character, or ``""`` when the string is only the tag), if
                      a match is found; or
                    - The original string, if no match is found.
                - The second element is the extracted language if a match is found;
                  otherwise, `None`.
        """
        match = self._code_language_pattern.match(input_string)
        if match is None:
            return input_string, None

        language = str(match.group(1))  # the captured programming language
        # group(2) is None when the output consists of the tag only.
        remainder = match.group(2) or ""
        return str(remainder), language

    def _get_code_language_enum(self, value: Optional[str]) -> CodeLanguageLabel:
        """
        Converts a string to a corresponding `CodeLanguageLabel` enum member.

        If the provided string does not match any value in `CodeLanguageLabel`,
        it defaults to `CodeLanguageLabel.UNKNOWN`.

        Args:
            value (Optional[str]): The string representation of the code language or None.

        Returns:
            CodeLanguageLabel: The corresponding enum member if the value is valid,
            otherwise `CodeLanguageLabel.UNKNOWN`.
        """
        if not isinstance(value, str):
            return CodeLanguageLabel.UNKNOWN

        try:
            return CodeLanguageLabel(value)
        except ValueError:
            return CodeLanguageLabel.UNKNOWN

    def _get_most_frequent_edge_color(self, pil_img: Image.Image):
        """
        Compute the most frequent color along the outer edges of a PIL image.

        The top row, bottom row, left column and right column are counted. Ties
        are broken by first occurrence in that order.

        Parameters
        ----------
        pil_img : Image.Image
            A PIL Image in any mode (L, RGB, RGBA, F, etc.).

        Returns
        -------
        int, float or tuple
            The most common edge color as a native Python value:
            - a scalar for single-channel images (a float for floating-point
              images, an int otherwise);
            - a tuple of ints for multi-channel images (RGB/RGBA).
        """
        img_np = np.array(pil_img)

        if img_np.ndim == 2:
            # Single-channel image: shape (H, W).
            edges = np.concatenate(
                [img_np[0, :], img_np[-1, :], img_np[:, 0], img_np[:, -1]]
            )
            # tolist() yields native Python scalars (int, float or bool).
            most_common_value, _ = Counter(edges.tolist()).most_common(1)[0]
            if isinstance(most_common_value, float):
                # Keep the exact fill for float images instead of truncating it.
                return most_common_value
            return int(most_common_value)

        # Multi-channel image: shape (H, W, C).
        edges = np.concatenate(
            [img_np[0, :, :], img_np[-1, :, :], img_np[:, 0, :], img_np[:, -1, :]],
            axis=0,
        )
        # Convert each pixel to a tuple of native ints so PIL gets a plain fill.
        most_common_value, _ = Counter(map(tuple, edges.tolist())).most_common(1)[0]
        return most_common_value  # e.g. (R, G, B) or (R, G, B, A)

    def _pad_with_most_frequent_edge_color(
        self, img: Union[Image.Image, np.ndarray], padding: Tuple[int, int, int, int]
    ):
        """
        Pads an image (PIL or NumPy array) using the most frequent edge color.

        Parameters
        ----------
        img : Union[Image.Image, np.ndarray]
            The original image.
        padding : tuple
            Padding (left, top, right, bottom) in pixels.

        Returns
        -------
        Image.Image: A new PIL image with the specified padding.
        """
        if isinstance(img, np.ndarray):
            pil_img = Image.fromarray(img)
        else:
            pil_img = img

        most_freq_color = self._get_most_frequent_edge_color(pil_img)

        padded_img = ImageOps.expand(pil_img, border=padding, fill=most_freq_color)
        return padded_img

    def __call__(
        self,
        doc: DoclingDocument,
        element_batch: Iterable[ItemAndImageEnrichmentElement],
    ) -> Iterable[NodeItem]:
        """
        Processes the given batch of elements and enriches them with predictions.

        Each element's image is padded with its most frequent edge color and
        sent to the predictor together with the element's label. The
        prediction replaces the item's ``text``. For ``CodeItem`` instances, a
        leading ``<_language_>`` tag is removed and stored as ``code_language``.

        Parameters
        ----------
        doc : DoclingDocument
            The document being processed.
        element_batch : Iterable[ItemAndImageEnrichmentElement]
            A batch of elements to be processed.

        Returns
        -------
        Iterable[NodeItem]
            An iterable of enriched elements. When the model is disabled, the
            items are yielded unchanged.
        """
        if not self.enabled:
            for element in element_batch:
                yield element.item
            return

        labels: List[str] = []
        images: List[Union[Image.Image, np.ndarray]] = []
        elements: List[TextItem] = []
        for el in element_batch:
            # Narrowing for type checkers; CodeItem is also accepted here as a TextItem.
            assert isinstance(el.item, TextItem)
            elements.append(el.item)
            labels.append(el.item.label)
            images.append(
                self._pad_with_most_frequent_edge_color(el.image, self._image_padding)
            )

        # Nothing to enrich: do not invoke the predictor with an empty batch.
        if not elements:
            return

        outputs = self.code_formula_model.predict(images, labels)

        for item, output in zip(elements, outputs):
            if isinstance(item, CodeItem):
                output, code_language = self._extract_code_language(output)
                item.code_language = self._get_code_language_enum(code_language)
            item.text = output

            yield item