Docling/docling/models/code_formula_model.py
Michele Dolfi 9114ada7bc
fix: Test cases for RTL programmatic PDFs and fixes for the formula model (#903)
fix: Support for RTL programmatic documents
fix(parser): detect and handle rotated pages
fix(parser): fix bug causing duplicated text
fix(formula): improve stopping criteria
chore: update lock file
fix: temporary constrain beautifulsoup


* switch to code formula model v1.0.1 and new test pdf

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>

* switch to code formula model v1.0.1 and new test pdf

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>

* cleaned up the data folder in the tests

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* switch to code formula model v1.0.1 and new test pdf

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>

* added three test-files for right-to-left

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* fix black

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>

* added new gt for test_e2e_conversion

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>

* added new gt for test_e2e_conversion

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>

* Add code to expose text direction of cell

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* new test file

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>

* update lock

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* fix mypy reports

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* fix example filepaths

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add test data results

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* pin wheel of latest docling-parse release

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* use latest docling-core

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* remove debugging code

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* fix path to files in example

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* Revert unwanted RTL additions

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix test data paths in examples

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>
Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Co-authored-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>
Co-authored-by: Peter Staar <taa@zurich.ibm.com>
Co-authored-by: Christoph Auer <cau@zurich.ibm.com>
2025-02-07 08:43:31 +01:00

252 lines
8.4 KiB
Python

import re
from pathlib import Path
from typing import Iterable, List, Literal, Optional, Tuple, Union
import numpy as np
from docling_core.types.doc import (
CodeItem,
DocItemLabel,
DoclingDocument,
NodeItem,
TextItem,
)
from docling_core.types.doc.labels import CodeLanguageLabel
from PIL import Image
from pydantic import BaseModel
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.models.base_model import BaseItemAndImageEnrichmentModel
from docling.utils.accelerator_utils import decide_device
class CodeFormulaModelOptions(BaseModel):
    """
    Settings that select which enrichments the CodeFormulaModel performs.

    Attributes
    ----------
    kind : Literal["code_formula"]
        Identifier of the model type; the only accepted value is
        "code_formula", which is also the default.
    do_code_enrichment : bool
        Enables enrichment of code items. Defaults to True.
    do_formula_enrichment : bool
        Enables enrichment of formula items. Defaults to True.
    """

    kind: Literal["code_formula"] = "code_formula"
    do_code_enrichment: bool = True
    do_formula_enrichment: bool = True
class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
    """
    Model for processing and enriching documents with code and formula predictions.

    Attributes
    ----------
    enabled : bool
        True if the model is enabled, False otherwise.
    options : CodeFormulaModelOptions
        Configuration options for the CodeFormulaModel.
    code_formula_model : CodeFormulaPredictor
        The predictor used for code and formula processing. Only assigned
        when the model is enabled.

    Methods
    -------
    __init__(self, enabled, artifacts_path, options, accelerator_options)
        Initializes the CodeFormulaModel with the given configuration options.
    download_models(local_dir, force, progress)
        Downloads the model artifacts and returns the local snapshot path.
    is_processable(self, doc, element)
        Determines if a given element in a document can be processed by the model.
    __call__(self, doc, element_batch)
        Processes the given batch of elements and enriches them with predictions.
    """

    _model_repo_folder = "CodeFormula"
    elements_batch_size = 5
    images_scale = 1.66  # = 120 dpi, aligned with training data resolution
    expansion_factor = 0.03

    # Matches a leading ``<_language_>`` tag. DOTALL lets the remainder group
    # span multiple lines. Compiled once here instead of on every call.
    _code_language_pattern = re.compile(r"^<_([^>]+)_>\s*(.*)", flags=re.DOTALL)

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        options: CodeFormulaModelOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """
        Initializes the CodeFormulaModel with the given configuration.

        Parameters
        ----------
        enabled : bool
            True if the model is enabled, False otherwise.
        artifacts_path : Optional[Path]
            Directory containing the model artifacts; the predictor is loaded
            from its ``CodeFormula`` sub-folder. If None, the artifacts are
            fetched via `download_models`.
        options : CodeFormulaModelOptions
            Configuration options for the model.
        accelerator_options : AcceleratorOptions
            Options specifying the device and number of threads for acceleration.
        """
        self.enabled = enabled
        self.options = options

        if self.enabled:
            device = decide_device(accelerator_options.device)

            # Imported here so the predictor package is only loaded when the
            # model is actually enabled.
            from docling_ibm_models.code_formula_model.code_formula_predictor import (
                CodeFormulaPredictor,
            )

            if artifacts_path is None:
                artifacts_path = self.download_models()
            else:
                artifacts_path = artifacts_path / self._model_repo_folder

            self.code_formula_model = CodeFormulaPredictor(
                artifacts_path=str(artifacts_path),
                device=device,
                num_threads=accelerator_options.num_threads,
            )

    @staticmethod
    def download_models(
        local_dir: Optional[Path] = None,
        force: bool = False,
        progress: bool = False,
    ) -> Path:
        """
        Downloads revision ``v1.0.1`` of the ``ds4sd/CodeFormula`` repository.

        Parameters
        ----------
        local_dir : Optional[Path]
            Forwarded to ``snapshot_download`` as ``local_dir``.
        force : bool
            Forwarded to ``snapshot_download`` as ``force_download``.
        progress : bool
            If False, huggingface_hub progress bars are disabled before the
            download starts.

        Returns
        -------
        Path
            The path returned by ``snapshot_download``.
        """
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        if not progress:
            # NOTE: the progress bars are not re-enabled afterwards.
            disable_progress_bars()

        download_path = snapshot_download(
            repo_id="ds4sd/CodeFormula",
            force_download=force,
            local_dir=local_dir,
            revision="v1.0.1",
        )
        return Path(download_path)

    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        """
        Determines if a given element in a document can be processed by the model.

        An element qualifies when the model is enabled and it is either a
        `CodeItem` (with code enrichment switched on) or a `TextItem` labelled
        as a formula (with formula enrichment switched on).

        Parameters
        ----------
        doc : DoclingDocument
            The document being processed.
        element : NodeItem
            The element within the document to check.

        Returns
        -------
        bool
            True if the element can be processed, False otherwise.
        """
        return self.enabled and (
            (isinstance(element, CodeItem) and self.options.do_code_enrichment)
            or (
                isinstance(element, TextItem)
                and element.label == DocItemLabel.FORMULA
                and self.options.do_formula_enrichment
            )
        )

    def _extract_code_language(self, input_string: str) -> Tuple[str, Optional[str]]:
        """
        Extracts a programming language tag from the beginning of a string.

        Checks whether the input starts with a tag of the form
        ``<_some_language_>``. If so, the tag and any whitespace directly
        after it are removed.

        Parameters
        ----------
        input_string : str
            The input string, which may start with ``<_language_>``.

        Returns
        -------
        Tuple[str, Optional[str]]
            ``(remainder, language)`` when a tag is found, where ``remainder``
            is everything after the tag; otherwise ``(input_string, None)``.
        """
        match = self._code_language_pattern.match(input_string)
        if match:
            language = str(match.group(1))  # the captured programming language
            remainder = str(match.group(2))  # everything after the <_language_>
            return remainder, language
        else:
            return input_string, None

    def _get_code_language_enum(self, value: Optional[str]) -> CodeLanguageLabel:
        """
        Converts a string to the corresponding `CodeLanguageLabel` enum member.

        Parameters
        ----------
        value : Optional[str]
            The string representation of the code language, or None.

        Returns
        -------
        CodeLanguageLabel
            The matching enum member, or `CodeLanguageLabel.UNKNOWN` when
            ``value`` is not a string or does not match any member.
        """
        if not isinstance(value, str):
            return CodeLanguageLabel.UNKNOWN

        try:
            return CodeLanguageLabel(value)
        except ValueError:
            return CodeLanguageLabel.UNKNOWN

    def __call__(
        self,
        doc: DoclingDocument,
        element_batch: Iterable[ItemAndImageEnrichmentElement],
    ) -> Iterable[NodeItem]:
        """
        Processes the given batch of elements and enriches them with predictions.

        When the model is disabled, the items are yielded unchanged. Otherwise
        each item's ``text`` is replaced by the predictor output; for
        `CodeItem` instances a leading ``<_language_>`` tag is stripped from
        the output and stored as ``code_language``.

        Parameters
        ----------
        doc : DoclingDocument
            The document being processed.
        element_batch : Iterable[ItemAndImageEnrichmentElement]
            A batch of elements to be processed.

        Returns
        -------
        Iterable[NodeItem]
            An iterable of the (possibly enriched) items.
        """
        if not self.enabled:
            for element in element_batch:
                yield element.item
            return

        labels: List[str] = []
        images: List[Union[Image.Image, np.ndarray]] = []
        elements: List[TextItem] = []
        for el in element_batch:
            # Narrows el.item to TextItem for the attribute accesses below.
            assert isinstance(el.item, TextItem)
            elements.append(el.item)
            labels.append(el.item.label)
            images.append(el.image)

        # Nothing to enrich: do not invoke the predictor on an empty batch.
        if not elements:
            return

        outputs = self.code_formula_model.predict(images, labels)

        for item, output in zip(elements, outputs):
            if isinstance(item, CodeItem):
                output, code_language = self._extract_code_language(output)
                item.code_language = self._get_code_language_enum(code_language)
            item.text = output

            yield item