feat: Code and equation model for PDF and code blocks in markdown (#752)
* propagated changes for new CodeItem class Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * Rebased branch on latest main. changes for CodeItem Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * removed unused files Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * chore: update lockfile Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * pin latest docling-core Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * update docling-core pinning Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * pin docling-core Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * use new add_code in backends and update typing in MD backend Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * added if statement for backend Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * removed unused import Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * removed print statements Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * gt for new pdf Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * Update docling/pipeline/standard_pdf_pipeline.py Co-authored-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com> Signed-off-by: Matteo <43417658+Matteo-Omenetti@users.noreply.github.com> * fixed doc comment of __call__ function of code_formula_model Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> * fix artifacts_path type Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * move imports Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * move expansion_factor to base class Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> --------- Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com> Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> Signed-off-by: Matteo <43417658+Matteo-Omenetti@users.noreply.github.com> Co-authored-by: Christoph Auer <cau@zurich.ibm.com> Co-authored-by: Michele Dolfi <dol@zurich.ibm.com> Co-authored-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>
This commit is contained in:
245
docling/models/code_formula_model.py
Normal file
245
docling/models/code_formula_model.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from docling_core.types.doc import (
|
||||
CodeItem,
|
||||
DocItemLabel,
|
||||
DoclingDocument,
|
||||
NodeItem,
|
||||
TextItem,
|
||||
)
|
||||
from docling_core.types.doc.labels import CodeLanguageLabel
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
|
||||
from docling.datamodel.pipeline_options import AcceleratorOptions
|
||||
from docling.models.base_model import BaseItemAndImageEnrichmentModel
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
|
||||
|
||||
class CodeFormulaModelOptions(BaseModel):
|
||||
"""
|
||||
Configuration options for the CodeFormulaModel.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
kind : str
|
||||
Type of the model. Fixed value "code_formula".
|
||||
do_code_enrichment : bool
|
||||
True if code enrichment is enabled, False otherwise.
|
||||
do_formula_enrichment : bool
|
||||
True if formula enrichment is enabled, False otherwise.
|
||||
"""
|
||||
|
||||
kind: Literal["code_formula"] = "code_formula"
|
||||
do_code_enrichment: bool = True
|
||||
do_formula_enrichment: bool = True
|
||||
|
||||
|
||||
class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
||||
"""
|
||||
Model for processing and enriching documents with code and formula predictions.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
enabled : bool
|
||||
True if the model is enabled, False otherwise.
|
||||
options : CodeFormulaModelOptions
|
||||
Configuration options for the CodeFormulaModel.
|
||||
code_formula_model : CodeFormulaPredictor
|
||||
The predictor model for code and formula processing.
|
||||
|
||||
Methods
|
||||
-------
|
||||
__init__(self, enabled, artifacts_path, accelerator_options, code_formula_options)
|
||||
Initializes the CodeFormulaModel with the given configuration options.
|
||||
is_processable(self, doc, element)
|
||||
Determines if a given element in a document can be processed by the model.
|
||||
__call__(self, doc, element_batch)
|
||||
Processes the given batch of elements and enriches them with predictions.
|
||||
"""
|
||||
|
||||
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
|
||||
expansion_factor = 0.03
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
artifacts_path: Optional[Union[Path, str]],
|
||||
options: CodeFormulaModelOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
"""
|
||||
Initializes the CodeFormulaModel with the given configuration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
enabled : bool
|
||||
True if the model is enabled, False otherwise.
|
||||
artifacts_path : Path
|
||||
Path to the directory containing the model artifacts.
|
||||
options : CodeFormulaModelOptions
|
||||
Configuration options for the model.
|
||||
accelerator_options : AcceleratorOptions
|
||||
Options specifying the device and number of threads for acceleration.
|
||||
"""
|
||||
self.enabled = enabled
|
||||
self.options = options
|
||||
|
||||
if self.enabled:
|
||||
device = decide_device(accelerator_options.device)
|
||||
|
||||
from docling_ibm_models.code_formula_model.code_formula_predictor import (
|
||||
CodeFormulaPredictor,
|
||||
)
|
||||
|
||||
if artifacts_path is None:
|
||||
artifacts_path = self.download_models_hf()
|
||||
else:
|
||||
artifacts_path = Path(artifacts_path)
|
||||
|
||||
self.code_formula_model = CodeFormulaPredictor(
|
||||
artifacts_path=artifacts_path,
|
||||
device=device,
|
||||
num_threads=accelerator_options.num_threads,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def download_models_hf(
|
||||
local_dir: Optional[Path] = None, force: bool = False
|
||||
) -> Path:
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import disable_progress_bars
|
||||
|
||||
disable_progress_bars()
|
||||
download_path = snapshot_download(
|
||||
repo_id="ds4sd/CodeFormula",
|
||||
force_download=force,
|
||||
local_dir=local_dir,
|
||||
revision="v1.0.0",
|
||||
)
|
||||
|
||||
return Path(download_path)
|
||||
|
||||
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
|
||||
"""
|
||||
Determines if a given element in a document can be processed by the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
doc : DoclingDocument
|
||||
The document being processed.
|
||||
element : NodeItem
|
||||
The element within the document to check.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the element can be processed, False otherwise.
|
||||
"""
|
||||
return self.enabled and (
|
||||
(isinstance(element, CodeItem) and self.options.do_code_enrichment)
|
||||
or (
|
||||
isinstance(element, TextItem)
|
||||
and element.label == DocItemLabel.FORMULA
|
||||
and self.options.do_formula_enrichment
|
||||
)
|
||||
)
|
||||
|
||||
def _extract_code_language(self, input_string: str) -> Tuple[str, Optional[str]]:
|
||||
"""Extracts a programming language from the beginning of a string.
|
||||
|
||||
This function checks if the input string starts with a pattern of the form
|
||||
``<_some_language_>``. If it does, it extracts the language string and returns
|
||||
a tuple of (remainder, language). Otherwise, it returns the original string
|
||||
and `None`.
|
||||
|
||||
Args:
|
||||
input_string (str): The input string, which may start with ``<_language_>``.
|
||||
|
||||
Returns:
|
||||
Tuple[str, Optional[str]]:
|
||||
A tuple where:
|
||||
- The first element is either:
|
||||
- The remainder of the string (everything after ``<_language_>``),
|
||||
if a match is found; or
|
||||
- The original string, if no match is found.
|
||||
- The second element is the extracted language if a match is found;
|
||||
otherwise, `None`.
|
||||
"""
|
||||
pattern = r"^<_([^>]+)_>\s*(.*)"
|
||||
match = re.match(pattern, input_string, flags=re.DOTALL)
|
||||
if match:
|
||||
language = str(match.group(1)) # the captured programming language
|
||||
remainder = str(match.group(2)) # everything after the <_language_>
|
||||
return remainder, language
|
||||
else:
|
||||
return input_string, None
|
||||
|
||||
def _get_code_language_enum(self, value: Optional[str]) -> CodeLanguageLabel:
|
||||
"""
|
||||
Converts a string to a corresponding `CodeLanguageLabel` enum member.
|
||||
|
||||
If the provided string does not match any value in `CodeLanguageLabel`,
|
||||
it defaults to `CodeLanguageLabel.UNKNOWN`.
|
||||
|
||||
Args:
|
||||
value (Optional[str]): The string representation of the code language or None.
|
||||
|
||||
Returns:
|
||||
CodeLanguageLabel: The corresponding enum member if the value is valid,
|
||||
otherwise `CodeLanguageLabel.UNKNOWN`.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return CodeLanguageLabel.UNKNOWN
|
||||
|
||||
try:
|
||||
return CodeLanguageLabel(value)
|
||||
except ValueError:
|
||||
return CodeLanguageLabel.UNKNOWN
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
doc: DoclingDocument,
|
||||
element_batch: Iterable[ItemAndImageEnrichmentElement],
|
||||
) -> Iterable[NodeItem]:
|
||||
"""
|
||||
Processes the given batch of elements and enriches them with predictions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
doc : DoclingDocument
|
||||
The document being processed.
|
||||
element_batch : Iterable[ItemAndImageEnrichmentElement]
|
||||
A batch of elements to be processed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable[Any]
|
||||
An iterable of enriched elements.
|
||||
"""
|
||||
if not self.enabled:
|
||||
for element in element_batch:
|
||||
yield element.item
|
||||
return
|
||||
|
||||
labels: List[str] = []
|
||||
images: List[Image.Image] = []
|
||||
elements: List[TextItem] = []
|
||||
for el in element_batch:
|
||||
assert isinstance(el.item, TextItem)
|
||||
elements.append(el.item)
|
||||
labels.append(el.item.label)
|
||||
images.append(el.image)
|
||||
|
||||
outputs = self.code_formula_model.predict(images, labels)
|
||||
|
||||
for item, output in zip(elements, outputs):
|
||||
if isinstance(item, CodeItem):
|
||||
output, code_language = self._extract_code_language(output)
|
||||
item.code_language = self._get_code_language_enum(code_language)
|
||||
item.text = output
|
||||
|
||||
yield item
|
||||
Reference in New Issue
Block a user