
* fix artifacts path
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
* add docling-models utility
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
* missing formatting
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
* rename utility to docling-tools
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
* rename download methods and deprecation warnings
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
* propagate artifacts path usage for ocr models
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
* move function to utils
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
* remove unused file
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
* update docs
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
* simplify downloading specific model(s)
  Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
* minor refactor
  Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>

---------

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
Co-authored-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
251 lines
8.4 KiB
Python
251 lines
8.4 KiB
Python
import re
|
|
from pathlib import Path
|
|
from typing import Iterable, List, Literal, Optional, Tuple, Union
|
|
|
|
from docling_core.types.doc import (
|
|
CodeItem,
|
|
DocItemLabel,
|
|
DoclingDocument,
|
|
NodeItem,
|
|
TextItem,
|
|
)
|
|
from docling_core.types.doc.labels import CodeLanguageLabel
|
|
from PIL import Image
|
|
from pydantic import BaseModel
|
|
|
|
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
|
|
from docling.datamodel.pipeline_options import AcceleratorOptions
|
|
from docling.models.base_model import BaseItemAndImageEnrichmentModel
|
|
from docling.utils.accelerator_utils import decide_device
|
|
|
|
|
|
class CodeFormulaModelOptions(BaseModel):
    """
    Settings that select which enrichments the CodeFormulaModel performs.

    Attributes
    ----------
    kind : Literal["code_formula"]
        Fixed identifier of this options type; always ``"code_formula"``.
    do_code_enrichment : bool
        Whether code items are sent to the model (defaults to True).
    do_formula_enrichment : bool
        Whether formula items are sent to the model (defaults to True).
    """

    # Constant tag naming this options type.
    kind: Literal["code_formula"] = "code_formula"
    # Switch for enriching code items.
    do_code_enrichment: bool = True
    # Switch for enriching formula items.
    do_formula_enrichment: bool = True
|
|
|
|
|
|
class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
    """
    Model for processing and enriching documents with code and formula predictions.

    Attributes
    ----------
    enabled : bool
        True if the model is enabled, False otherwise.
    options : CodeFormulaModelOptions
        Configuration options for the CodeFormulaModel.
    code_formula_model : CodeFormulaPredictor
        The predictor model for code and formula processing. It is only
        created when ``enabled`` is True.

    Methods
    -------
    __init__(self, enabled, artifacts_path, options, accelerator_options)
        Initializes the CodeFormulaModel with the given configuration options.
    download_models(local_dir=None, force=False, progress=False)
        Downloads the model weights from the Hugging Face Hub.
    is_processable(self, doc, element)
        Determines if a given element in a document can be processed by the model.
    __call__(self, doc, element_batch)
        Processes the given batch of elements and enriches them with predictions.
    """

    _model_repo_folder = "CodeFormula"
    # Number of elements per predictor call (presumably read by the
    # enrichment pipeline / base class — confirm).
    elements_batch_size = 5
    images_scale = 1.66  # = 120 dpi, aligned with training data resolution
    # Presumably the relative margin added around an element's crop before
    # it is passed in as an image — confirm against the base class.
    expansion_factor = 0.03

    # Matches a leading ``<_language_>`` tag at the start of a code prediction.
    # Group 1 is the language and group 2 is the remaining text. DOTALL lets
    # the remainder span multiple lines. The pattern is compiled once instead
    # of on every call.
    _CODE_LANGUAGE_PATTERN = re.compile(r"^<_([^>]+)_>\s*(.*)", flags=re.DOTALL)

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        options: CodeFormulaModelOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """
        Initializes the CodeFormulaModel with the given configuration.

        Parameters
        ----------
        enabled : bool
            True if the model is enabled, False otherwise. When False, no
            predictor is loaded.
        artifacts_path : Optional[Path]
            Directory containing the model artifacts. The weights are looked
            up in its ``CodeFormula`` sub-folder. If None, the weights are
            fetched with ``download_models()``.
        options : CodeFormulaModelOptions
            Configuration options for the model.
        accelerator_options : AcceleratorOptions
            Options specifying the device and number of threads for acceleration.
        """
        self.enabled = enabled
        self.options = options

        if self.enabled:
            device = decide_device(accelerator_options.device)

            # Imported lazily so docling_ibm_models is only required when the
            # model is actually enabled.
            from docling_ibm_models.code_formula_model.code_formula_predictor import (
                CodeFormulaPredictor,
            )

            if artifacts_path is None:
                artifacts_path = self.download_models()
            else:
                artifacts_path = artifacts_path / self._model_repo_folder

            self.code_formula_model = CodeFormulaPredictor(
                artifacts_path=artifacts_path,
                device=device,
                num_threads=accelerator_options.num_threads,
            )

    @staticmethod
    def download_models(
        local_dir: Optional[Path] = None,
        force: bool = False,
        progress: bool = False,
    ) -> Path:
        """
        Download the ``ds4sd/CodeFormula`` weights (revision ``v1.0.0``)
        from the Hugging Face Hub.

        Parameters
        ----------
        local_dir : Optional[Path]
            Target directory, forwarded to ``snapshot_download``. When None,
            huggingface_hub's default cache location is used.
        force : bool
            Re-download the files even if they are already present.
        progress : bool
            Show download progress bars.

        Returns
        -------
        Path
            Local path of the downloaded snapshot.
        """
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        if not progress:
            # NOTE: this toggles huggingface_hub progress bars off for the
            # whole process; they are not re-enabled afterwards.
            disable_progress_bars()
        download_path = snapshot_download(
            repo_id="ds4sd/CodeFormula",
            force_download=force,
            local_dir=local_dir,
            revision="v1.0.0",
        )

        return Path(download_path)

    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        """
        Determines if a given element in a document can be processed by the model.

        An element qualifies when the model is enabled and either it is a
        CodeItem with code enrichment turned on, or it is a TextItem labelled
        FORMULA with formula enrichment turned on.

        Parameters
        ----------
        doc : DoclingDocument
            The document being processed.
        element : NodeItem
            The element within the document to check.

        Returns
        -------
        bool
            True if the element can be processed, False otherwise.
        """
        return self.enabled and (
            (isinstance(element, CodeItem) and self.options.do_code_enrichment)
            or (
                isinstance(element, TextItem)
                and element.label == DocItemLabel.FORMULA
                and self.options.do_formula_enrichment
            )
        )

    def _extract_code_language(self, input_string: str) -> Tuple[str, Optional[str]]:
        """Extracts a programming language from the beginning of a string.

        This function checks if the input string starts with a pattern of the form
        ``<_some_language_>``. If it does, it extracts the language string and returns
        a tuple of (remainder, language). Otherwise, it returns the original string
        and `None`.

        Args:
            input_string (str): The input string, which may start with ``<_language_>``.

        Returns:
            Tuple[str, Optional[str]]:
                A tuple where:
                - The first element is either:
                    - The remainder of the string (everything after ``<_language_>``
                      and any whitespace following it), if a match is found; or
                    - The original string, if no match is found.
                - The second element is the extracted language if a match is found;
                  otherwise, `None`.
        """
        match = self._CODE_LANGUAGE_PATTERN.match(input_string)
        if match:
            language = str(match.group(1))  # the captured programming language
            remainder = str(match.group(2))  # everything after the <_language_>
            return remainder, language
        else:
            return input_string, None

    def _get_code_language_enum(self, value: Optional[str]) -> CodeLanguageLabel:
        """
        Converts a string to a corresponding `CodeLanguageLabel` enum member.

        If the provided string does not match any value in `CodeLanguageLabel`,
        or is not a string at all (e.g. None), it defaults to
        `CodeLanguageLabel.UNKNOWN`.

        Args:
            value (Optional[str]): The string representation of the code language or None.

        Returns:
            CodeLanguageLabel: The corresponding enum member if the value is valid,
            otherwise `CodeLanguageLabel.UNKNOWN`.
        """
        if not isinstance(value, str):
            return CodeLanguageLabel.UNKNOWN

        try:
            return CodeLanguageLabel(value)
        except ValueError:
            return CodeLanguageLabel.UNKNOWN

    def __call__(
        self,
        doc: DoclingDocument,
        element_batch: Iterable[ItemAndImageEnrichmentElement],
    ) -> Iterable[NodeItem]:
        """
        Processes the given batch of elements and enriches them with predictions.

        The images and labels of all elements are sent to the predictor in a
        single call. Each prediction replaces ``item.text``. For code items, a
        leading ``<_language_>`` tag is stripped from the prediction and stored
        as ``item.code_language``.

        Parameters
        ----------
        doc : DoclingDocument
            The document being processed.
        element_batch : Iterable[ItemAndImageEnrichmentElement]
            A batch of elements to be processed. Every item must be a TextItem
            (code items included); otherwise an AssertionError is raised.

        Returns
        -------
        Iterable[NodeItem]
            The enriched items, yielded in input order. When the model is
            disabled, the items are yielded unchanged.
        """
        if not self.enabled:
            for element in element_batch:
                yield element.item
            return

        labels: List[str] = []
        images: List[Image.Image] = []
        elements: List[TextItem] = []
        for el in element_batch:
            # Type narrowing; code items are expected to pass this check too.
            assert isinstance(el.item, TextItem)
            elements.append(el.item)
            labels.append(el.item.label)
            images.append(el.image)

        # Nothing to enrich: do not invoke the predictor with an empty batch.
        if not elements:
            return

        outputs = self.code_formula_model.predict(images, labels)

        for item, output in zip(elements, outputs):
            if isinstance(item, CodeItem):
                output, code_language = self._extract_code_language(output)
                item.code_language = self._get_code_language_enum(code_language)
            item.text = output

            yield item
|