perf: New revision code formula model and document picture classifier (#1140)
* new version code formula model Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com> * new version document picture classifier Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com> * new code formula model Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com> * restored original code formula test pdf Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com> --------- Signed-off-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com> Co-authored-by: Matteo-Omenetti <Matteo.Omenetti1@ibm.com>
This commit is contained in:
parent
4d64c4c0b6
commit
5e30381c0d
@ -1,4 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
|
from collections import Counter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable, List, Literal, Optional, Tuple, Union
|
from typing import Iterable, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
@ -11,7 +12,7 @@ from docling_core.types.doc import (
|
|||||||
TextItem,
|
TextItem,
|
||||||
)
|
)
|
||||||
from docling_core.types.doc.labels import CodeLanguageLabel
|
from docling_core.types.doc.labels import CodeLanguageLabel
|
||||||
from PIL import Image
|
from PIL import Image, ImageOps
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
|
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
|
||||||
@ -65,7 +66,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
|||||||
_model_repo_folder = "ds4sd--CodeFormula"
|
_model_repo_folder = "ds4sd--CodeFormula"
|
||||||
elements_batch_size = 5
|
elements_batch_size = 5
|
||||||
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
|
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
|
||||||
expansion_factor = 0.03
|
expansion_factor = 0.18
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -124,7 +125,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
|||||||
repo_id="ds4sd/CodeFormula",
|
repo_id="ds4sd/CodeFormula",
|
||||||
force_download=force,
|
force_download=force,
|
||||||
local_dir=local_dir,
|
local_dir=local_dir,
|
||||||
revision="v1.0.1",
|
revision="v1.0.2",
|
||||||
)
|
)
|
||||||
|
|
||||||
return Path(download_path)
|
return Path(download_path)
|
||||||
@ -175,7 +176,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
|||||||
- The second element is the extracted language if a match is found;
|
- The second element is the extracted language if a match is found;
|
||||||
otherwise, `None`.
|
otherwise, `None`.
|
||||||
"""
|
"""
|
||||||
pattern = r"^<_([^>]+)_>\s*(.*)"
|
pattern = r"^<_([^_>]+)_>\s(.*)"
|
||||||
match = re.match(pattern, input_string, flags=re.DOTALL)
|
match = re.match(pattern, input_string, flags=re.DOTALL)
|
||||||
if match:
|
if match:
|
||||||
language = str(match.group(1)) # the captured programming language
|
language = str(match.group(1)) # the captured programming language
|
||||||
@ -206,6 +207,82 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return CodeLanguageLabel.UNKNOWN
|
return CodeLanguageLabel.UNKNOWN
|
||||||
|
|
||||||
|
def _get_most_frequent_edge_color(self, pil_img: Image.Image):
|
||||||
|
"""
|
||||||
|
Compute the most frequent color along the outer edges of a PIL image.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pil_img : Image.Image
|
||||||
|
A PIL Image in any mode (L, RGB, RGBA, etc.).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
(int) or (tuple): The most common edge color as a scalar (for grayscale) or
|
||||||
|
tuple (for RGB/RGBA).
|
||||||
|
"""
|
||||||
|
# Convert to NumPy array for easy pixel access
|
||||||
|
img_np = np.array(pil_img)
|
||||||
|
|
||||||
|
if img_np.ndim == 2:
|
||||||
|
# Grayscale-like image: shape (H, W)
|
||||||
|
# Extract edges: top row, bottom row, left col, right col
|
||||||
|
top = img_np[0, :] # shape (W,)
|
||||||
|
bottom = img_np[-1, :] # shape (W,)
|
||||||
|
left = img_np[:, 0] # shape (H,)
|
||||||
|
right = img_np[:, -1] # shape (H,)
|
||||||
|
|
||||||
|
# Concatenate all edges
|
||||||
|
edges = np.concatenate([top, bottom, left, right])
|
||||||
|
|
||||||
|
# Count frequencies
|
||||||
|
freq = Counter(edges.tolist())
|
||||||
|
most_common_value, _ = freq.most_common(1)[0]
|
||||||
|
return int(most_common_value) # single channel color
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Color image: shape (H, W, C)
|
||||||
|
top = img_np[0, :, :] # shape (W, C)
|
||||||
|
bottom = img_np[-1, :, :] # shape (W, C)
|
||||||
|
left = img_np[:, 0, :] # shape (H, C)
|
||||||
|
right = img_np[:, -1, :] # shape (H, C)
|
||||||
|
|
||||||
|
# Concatenate edges along first axis
|
||||||
|
edges = np.concatenate([top, bottom, left, right], axis=0)
|
||||||
|
|
||||||
|
# Convert each color to a tuple for counting
|
||||||
|
edges_as_tuples = [tuple(pixel) for pixel in edges]
|
||||||
|
freq = Counter(edges_as_tuples)
|
||||||
|
most_common_value, _ = freq.most_common(1)[0]
|
||||||
|
return most_common_value # e.g. (R, G, B) or (R, G, B, A)
|
||||||
|
|
||||||
|
def _pad_with_most_frequent_edge_color(
|
||||||
|
self, img: Union[Image.Image, np.ndarray], padding: Tuple[int, int, int, int]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Pads an image (PIL or NumPy array) using the most frequent edge color.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
img : Union[Image.Image, np.ndarray]
|
||||||
|
The original image.
|
||||||
|
padding : tuple
|
||||||
|
Padding (left, top, right, bottom) in pixels.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Image.Image: A new PIL image with the specified padding.
|
||||||
|
"""
|
||||||
|
if isinstance(img, np.ndarray):
|
||||||
|
pil_img = Image.fromarray(img)
|
||||||
|
else:
|
||||||
|
pil_img = img
|
||||||
|
|
||||||
|
most_freq_color = self._get_most_frequent_edge_color(pil_img)
|
||||||
|
|
||||||
|
padded_img = ImageOps.expand(pil_img, border=padding, fill=most_freq_color)
|
||||||
|
return padded_img
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
doc: DoclingDocument,
|
doc: DoclingDocument,
|
||||||
@ -238,7 +315,9 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
|||||||
assert isinstance(el.item, TextItem)
|
assert isinstance(el.item, TextItem)
|
||||||
elements.append(el.item)
|
elements.append(el.item)
|
||||||
labels.append(el.item.label)
|
labels.append(el.item.label)
|
||||||
images.append(el.image)
|
images.append(
|
||||||
|
self._pad_with_most_frequent_edge_color(el.image, (20, 10, 20, 10))
|
||||||
|
)
|
||||||
|
|
||||||
outputs = self.code_formula_model.predict(images, labels)
|
outputs = self.code_formula_model.predict(images, labels)
|
||||||
|
|
||||||
|
@ -113,7 +113,7 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
|
|||||||
repo_id="ds4sd/DocumentFigureClassifier",
|
repo_id="ds4sd/DocumentFigureClassifier",
|
||||||
force_download=force,
|
force_download=force,
|
||||||
local_dir=local_dir,
|
local_dir=local_dir,
|
||||||
revision="v1.0.0",
|
revision="v1.0.1",
|
||||||
)
|
)
|
||||||
|
|
||||||
return Path(download_path)
|
return Path(download_path)
|
||||||
|
Loading…
Reference in New Issue
Block a user