perf: New revision code formula model and document picture classifier (#1140)

* new version code formula model

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>

* new version document picture classifier

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>

* new code formula model

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>

* restored original code formula test pdf

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>

---------

Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>
Co-authored-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>
This commit is contained in:
Matteo 2025-03-11 09:15:28 +00:00 committed by GitHub
parent 4d64c4c0b6
commit 5e30381c0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 85 additions and 6 deletions

View File

@ -1,4 +1,5 @@
import re
from collections import Counter
from pathlib import Path
from typing import Iterable, List, Literal, Optional, Tuple, Union
@ -11,7 +12,7 @@ from docling_core.types.doc import (
TextItem,
)
from docling_core.types.doc.labels import CodeLanguageLabel
from PIL import Image
from PIL import Image, ImageOps
from pydantic import BaseModel
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
@ -65,7 +66,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
_model_repo_folder = "ds4sd--CodeFormula"
elements_batch_size = 5
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
expansion_factor = 0.03
expansion_factor = 0.18
def __init__(
self,
@ -124,7 +125,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
repo_id="ds4sd/CodeFormula",
force_download=force,
local_dir=local_dir,
revision="v1.0.1",
revision="v1.0.2",
)
return Path(download_path)
@ -175,7 +176,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
- The second element is the extracted language if a match is found;
otherwise, `None`.
"""
pattern = r"^<_([^>]+)_>\s*(.*)"
pattern = r"^<_([^_>]+)_>\s(.*)"
match = re.match(pattern, input_string, flags=re.DOTALL)
if match:
language = str(match.group(1)) # the captured programming language
@ -206,6 +207,82 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
except ValueError:
return CodeLanguageLabel.UNKNOWN
def _get_most_frequent_edge_color(
    self, pil_img: Image.Image
) -> Union[int, Tuple[int, ...]]:
    """
    Compute the most frequent color along the outer edges of a PIL image.

    Every border pixel is counted exactly once. Corner pixels belong to two
    edges, and for 1-pixel-high or 1-pixel-wide images the opposite edges
    coincide, so the border is selected with a boolean mask instead of
    concatenating four overlapping slices.

    Parameters
    ----------
    pil_img : Image.Image
        A PIL Image in any mode (L, RGB, RGBA, etc.).

    Returns
    -------
    (int) or (tuple): The most common edge color as a scalar (for grayscale) or
    tuple of plain Python ints (for RGB/RGBA). Ties are resolved in favour
    of the color met first when scanning the border in row-major order.
    """
    # Convert to NumPy array for easy pixel access
    img_np = np.array(pil_img)

    # Mark the outer frame of the (H, W) plane; overlapping assignments are
    # harmless, so each border pixel ends up selected a single time.
    border = np.zeros(img_np.shape[:2], dtype=bool)
    border[0, :] = True
    border[-1, :] = True
    border[:, 0] = True
    border[:, -1] = True

    # Boolean indexing yields shape (N,) for grayscale and (N, C) for color.
    # ``tolist`` converts NumPy scalars to native Python values, so the
    # returned color never contains NumPy scalar types.
    edge_pixels = img_np[border].tolist()

    if img_np.ndim == 2:
        # Grayscale-like image: every entry is a single channel value
        most_common_value, _ = Counter(edge_pixels).most_common(1)[0]
        return int(most_common_value)

    # Color image: each entry is a list of channels; tuples are hashable
    most_common_value, _ = Counter(map(tuple, edge_pixels)).most_common(1)[0]
    return most_common_value  # e.g. (R, G, B) or (R, G, B, A)
def _pad_with_most_frequent_edge_color(
    self, img: Union[Image.Image, np.ndarray], padding: Tuple[int, int, int, int]
):
    """
    Surround an image with a border filled with its dominant edge color.

    Parameters
    ----------
    img : Union[Image.Image, np.ndarray]
        The image to pad. A NumPy array is first converted to a PIL image
        with ``Image.fromarray``; a PIL image is used as-is.
    padding : tuple
        Border widths in pixels, ordered (left, top, right, bottom).

    Returns
    -------
    Image.Image: The image returned by ``ImageOps.expand`` for the given
    border and fill color.
    """
    source = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    # The fill color is whatever color dominates the image's outer edges.
    fill_color = self._get_most_frequent_edge_color(source)
    return ImageOps.expand(source, border=padding, fill=fill_color)
def __call__(
self,
doc: DoclingDocument,
@ -238,7 +315,9 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
assert isinstance(el.item, TextItem)
elements.append(el.item)
labels.append(el.item.label)
images.append(el.image)
images.append(
self._pad_with_most_frequent_edge_color(el.image, (20, 10, 20, 10))
)
outputs = self.code_formula_model.predict(images, labels)

View File

@ -113,7 +113,7 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
repo_id="ds4sd/DocumentFigureClassifier",
force_download=force,
local_dir=local_dir,
revision="v1.0.0",
revision="v1.0.1",
)
return Path(download_path)