Docling/docling/models/base_ocr_model.py
William Easton 3089cf2d26
perf: Move expensive imports closer to usage (#1863)
* Move expensive imports closer to usage

Signed-off-by: William Easton <bill.easton@elastic.co>

* DCO Remediation Commit for William Easton <bill.easton@elastic.co>

I, William Easton <bill.easton@elastic.co>, hereby add my Signed-off-by to this commit: 8a7412ce5bb131a01bb6403067aeb948c9093b0b

Signed-off-by: William Easton <bill.easton@elastic.co>

* formatting fixes

Signed-off-by: William Easton <bill.easton@elastic.co>

* DCO Remediation Commit for William Easton <bill.easton@elastic.co>

I, William Easton <bill.easton@elastic.co>, hereby add my Signed-off-by to this commit: 8a7412ce5bb131a01bb6403067aeb948c9093b0b
I, William Easton <bill.easton@elastic.co>, hereby add my Signed-off-by to this commit: 963e34325071db5e844841f10c27b396a054a0a1

Signed-off-by: William Easton <bill.easton@elastic.co>

* Fix baseocrmodel test issue

Signed-off-by: William Easton <bill.easton@elastic.co>

---------

Signed-off-by: William Easton <bill.easton@elastic.co>
2025-07-01 22:27:17 +02:00

228 lines
8.0 KiB
Python

import copy
import logging
from abc import abstractmethod
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Type
import numpy as np
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import TextCell
from PIL import Image, ImageDraw
from rtree import index
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import OcrOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BaseModelWithOptions, BasePageModel
# Module-level logger, named after this module via __name__.
_log = logging.getLogger(__name__)
class BaseOcrModel(BasePageModel, BaseModelWithOptions):
    """Common base for OCR page models.

    Bundles the helpers that concrete OCR engines share: selecting the page
    regions to run OCR on, merging OCR cells with the cells already present
    on the page, and rendering a debug overlay. Subclasses must implement
    ``__call__`` and ``get_options_type``.
    """

    def __init__(
        self,
        *,
        enabled: bool,
        artifacts_path: Optional[Path],
        options: OcrOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """Store the enabled flag and the OCR options.

        Args:
            enabled: Stored as ``self.enabled``.
            artifacts_path: Accepted for interface compatibility; not used by
                this base initializer.
            options: Stored as ``self.options``.
            accelerator_options: Accepted for interface compatibility; not
                used by this base initializer.
        """
        # Make sure any delay/error from import occurs on ocr model init and not first use
        from scipy.ndimage import binary_dilation, find_objects, label  # noqa: F401

        self.enabled = enabled
        self.options = options

    # Computes the optimum amount and coordinates of rectangles to OCR on a given page
    def get_ocr_rects(self, page: Page) -> List[BoundingBox]:
        """Return the rectangles of ``page`` that should be sent to OCR.

        The bitmap rectangles reported by the page backend are rasterized
        into a binary mask, dilated so that nearby rectangles merge, and
        split into connected components. Depending on how much of the page
        the dilated mask covers, the result is:

        * one full-page rectangle, if ``options.force_full_page_ocr`` is set
          or the coverage exceeds ``max(0.75, options.bitmap_area_threshold)``;
        * the bounding boxes of the connected components, if the coverage
          exceeds ``options.bitmap_area_threshold``;
        * an empty list otherwise.

        All returned boxes use a top-left coordinate origin.
        """
        from scipy.ndimage import binary_dilation, find_objects, label

        BITMAP_COVERAGE_THRESHOLD = 0.75
        assert page.size is not None

        def find_ocr_rects(size, bitmap_rects):
            """Return ``(coverage_fraction, merged_boxes)`` for the bitmap rects."""
            image = Image.new(
                "1", (round(size.width), round(size.height))
            )  # '1' mode is binary

            # Draw all bitmap rects into a binary image
            draw = ImageDraw.Draw(image)
            for rect in bitmap_rects:
                x0, y0, x1, y1 = rect.as_tuple()
                x0, y0, x1, y1 = round(x0), round(y0), round(x1), round(y1)
                draw.rectangle([(x0, y0), (x1, y1)], fill=1)

            np_image = np.array(image)

            # Dilate the image by 10 pixels to merge nearby bitmap rectangles
            structure = np.ones(
                (20, 20)
            )  # Create a 20x20 structure element (10 pixels in all directions)
            np_image = binary_dilation(np_image > 0, structure=structure)

            # Find the connected components
            labeled_image, num_features = label(
                np_image > 0
            )  # Label connected non-zero (bitmap-covered) regions

            # Find enclosing bounding boxes for each connected component.
            # Slice stops are exclusive, hence the "- 1" on right/bottom.
            slices = find_objects(labeled_image)
            bounding_boxes = [
                BoundingBox(
                    l=slc[1].start,
                    t=slc[0].start,
                    r=slc[1].stop - 1,
                    b=slc[0].stop - 1,
                    coord_origin=CoordOrigin.TOPLEFT,
                )
                for slc in slices
            ]

            # Compute area fraction on page covered by bitmaps
            area_frac = np.sum(np_image > 0) / (size.width * size.height)

            return (area_frac, bounding_boxes)  # fraction covered  # boxes

        if page._backend is not None:
            bitmap_rects = page._backend.get_bitmap_rects()
        else:
            bitmap_rects = []

        coverage, ocr_rects = find_ocr_rects(page.size, bitmap_rects)

        # return full-page rectangle if page is dominantly covered with bitmaps
        if self.options.force_full_page_ocr or coverage > max(
            BITMAP_COVERAGE_THRESHOLD, self.options.bitmap_area_threshold
        ):
            return [
                BoundingBox(
                    l=0,
                    t=0,
                    r=page.size.width,
                    b=page.size.height,
                    coord_origin=CoordOrigin.TOPLEFT,
                )
            ]
        # return individual rectangles if the bitmap coverage is above the threshold
        elif coverage > self.options.bitmap_area_threshold:
            return ocr_rects
        else:  # overall coverage of bitmaps is too low, drop all bitmap rectangles.
            return []

    # Filters OCR cells by dropping any OCR cell that intersects with an existing programmatic cell.
    def _filter_ocr_cells(
        self, ocr_cells: List[TextCell], programmatic_cells: List[TextCell]
    ) -> List[TextCell]:
        """Return the OCR cells that do not intersect any programmatic cell.

        The programmatic cells' bounding boxes are loaded into a 2-D R-tree;
        an OCR cell is dropped as soon as the index reports at least one
        intersecting entry for its bounding box.
        """
        # Create R-tree index for programmatic cells
        p = index.Property()
        p.dimension = 2
        idx = index.Index(properties=p)
        for i, cell in enumerate(programmatic_cells):
            idx.insert(i, cell.rect.to_bounding_box().as_tuple())

        def is_overlapping_with_existing_cells(ocr_cell):
            # Query the R-tree to get overlapping rectangles
            possible_matches_index = list(
                idx.intersection(ocr_cell.rect.to_bounding_box().as_tuple())
            )

            return (
                len(possible_matches_index) > 0
            )  # this is a weak criterion but it works.

        filtered_ocr_cells = [
            rect for rect in ocr_cells if not is_overlapping_with_existing_cells(rect)
        ]
        return filtered_ocr_cells

    def post_process_cells(self, ocr_cells: List[TextCell], page: Page) -> None:
        r"""
        Post-process the OCR cells and update the page object.
        Updates parsed_page.textline_cells directly since page.cells is now read-only.

        Also sets parsed_page.has_lines to whether any cell remains after merging.
        """
        # Get existing cells from the read-only property
        existing_cells = page.cells

        # Combine existing and OCR cells with overlap filtering
        final_cells = self._combine_cells(existing_cells, ocr_cells)

        assert page.parsed_page is not None

        # Update parsed_page.textline_cells directly
        page.parsed_page.textline_cells = final_cells
        page.parsed_page.has_lines = len(final_cells) > 0

    def _combine_cells(
        self, existing_cells: List[TextCell], ocr_cells: List[TextCell]
    ) -> List[TextCell]:
        """Combine existing and OCR cells with filtering and re-indexing.

        With ``options.force_full_page_ocr`` only the OCR cells are kept;
        otherwise the existing cells come first, followed by the OCR cells
        that do not overlap any of them. The ``index`` attribute of every
        returned cell is overwritten with its position in the result.
        """
        if self.options.force_full_page_ocr:
            combined = ocr_cells
        else:
            filtered_ocr_cells = self._filter_ocr_cells(ocr_cells, existing_cells)
            combined = list(existing_cells) + filtered_ocr_cells

        # Re-index in-place
        for i, cell in enumerate(combined):
            cell.index = i

        return combined

    def draw_ocr_rects_and_cells(self, conv_res, page, ocr_rects, show: bool = False):
        """Render a debug overlay of OCR regions and text cells for ``page``.

        Works on a deep copy of ``page.image``. OCR rectangles are filled in
        translucent yellow; cells are outlined in magenta when ``from_ocr``
        is set and in gray otherwise. When ``show`` is true the image is
        displayed; otherwise it is saved as ``ocr_page_<page_no>.png`` in a
        ``debug_<input file stem>`` directory under
        ``settings.debug.debug_output_path``.
        """
        image = copy.deepcopy(page.image)
        # Page coordinates -> image pixels, independently per axis.
        scale_x = image.width / page.size.width
        scale_y = image.height / page.size.height

        draw = ImageDraw.Draw(image, "RGBA")

        # Draw OCR rectangles as yellow filled rect
        for rect in ocr_rects:
            x0, y0, x1, y1 = rect.as_tuple()
            # Vertical coordinates must use the vertical scale factor.
            y0 *= scale_y
            y1 *= scale_y
            x0 *= scale_x
            x1 *= scale_x

            shade_color = (255, 255, 0, 40)  # transparent yellow
            draw.rectangle([(x0, y0), (x1, y1)], fill=shade_color, outline=None)

        # Draw OCR and programmatic cells
        for tc in page.cells:
            x0, y0, x1, y1 = tc.rect.to_bounding_box().as_tuple()
            y0 *= scale_y
            y1 *= scale_y
            x0 *= scale_x
            x1 *= scale_x

            # Keep the corners ordered top-to-bottom before drawing.
            if y1 <= y0:
                y1, y0 = y0, y1

            color = "magenta" if tc.from_ocr else "gray"

            draw.rectangle([(x0, y0), (x1, y1)], outline=color)

        if show:
            image.show()
        else:
            out_path: Path = (
                Path(settings.debug.debug_output_path)
                / f"debug_{conv_res.input.file.stem}"
            )
            out_path.mkdir(parents=True, exist_ok=True)

            out_file = out_path / f"ocr_page_{page.page_no:05}.png"
            image.save(str(out_file), format="png")

    @abstractmethod
    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Process a batch of pages; implemented by concrete OCR models."""
        pass

    @classmethod
    @abstractmethod
    def get_options_type(cls) -> Type[OcrOptions]:
        """Return the ``OcrOptions`` subclass this model is configured with."""
        pass