Docling/docling/models/base_ocr_model.py
Christoph Auer 7d3302cb48
feat: Make Page.parsed_page the only source of truth for text cells, add OCR cells to it (#1745)
* Keep page.parsed_page.textline_cells and page.cells in sync, including OCR

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Make page.parsed_page the only source of truth for text cells

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Small fix

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Correctly compute PDF boxes from pymupdf

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Use different OCR engine order

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add type hints and fix mypy

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* One more test fix

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Remove with pypdfium2_lock from caller sites

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix typing

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
2025-06-13 19:01:55 +02:00

224 lines
7.8 KiB
Python

import copy
import logging
from abc import abstractmethod
from collections.abc import Iterable
from pathlib import Path
from typing import List, Optional, Type
import numpy as np
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import TextCell
from PIL import Image, ImageDraw
from rtree import index
from scipy.ndimage import binary_dilation, find_objects, label
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import OcrOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BaseModelWithOptions, BasePageModel
# Module-level logger, named after this module's import path.
_log = logging.getLogger(__name__)
class BaseOcrModel(BasePageModel, BaseModelWithOptions):
    """Abstract base class for OCR page models.

    Bundles the helpers shared by concrete OCR engines: selecting the page
    regions that should be OCR'ed, merging OCR cells with the cells already
    present on a page, and rendering a debug overlay. Subclasses must
    implement ``__call__`` and ``get_options_type``.
    """

    def __init__(
        self,
        *,
        enabled: bool,
        artifacts_path: Optional[Path],
        options: OcrOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """Store the ``enabled`` flag and the OCR ``options``.

        ``artifacts_path`` and ``accelerator_options`` are part of the common
        constructor signature but are not used by this base class.
        """
        self.enabled = enabled
        self.options = options

    def get_ocr_rects(self, page: Page) -> List[BoundingBox]:
        """Compute the rectangles of ``page`` that should be OCR'ed.

        Bitmap rectangles reported by the page backend are rasterised into a
        binary mask, dilated so that nearby bitmaps merge, and grouped into
        connected components.

        Returns:
            - a single full-page box when ``options.force_full_page_ocr`` is
              set, or when the covered fraction exceeds both 0.75 and
              ``options.bitmap_area_threshold``;
            - the per-component boxes when the covered fraction exceeds
              ``options.bitmap_area_threshold``;
            - an empty list otherwise.

            All returned boxes use ``CoordOrigin.TOPLEFT``.

        Raises:
            AssertionError: if ``page.size`` is ``None``.
        """
        BITMAP_COVERAGE_THRESHOLD = 0.75
        assert page.size is not None

        def find_ocr_rects(size, bitmap_rects):
            """Return ``(covered_fraction, boxes)`` for the given bitmap rects."""
            # '1' mode is a binary (1 bit per pixel) image.
            image = Image.new("1", (round(size.width), round(size.height)))

            # Draw all bitmap rects into the binary image.
            # NOTE(review): assumes the backend reports these rects in
            # top-left page coordinates - confirm against the backend.
            draw = ImageDraw.Draw(image)
            for rect in bitmap_rects:
                x0, y0, x1, y1 = (round(v) for v in rect.as_tuple())
                draw.rectangle([(x0, y0), (x1, y1)], fill=1)

            np_image = np.array(image)

            # Dilate with a 20x20 structuring element (roughly 10 pixels in
            # every direction) so that nearby bitmap rectangles merge.
            structure = np.ones((20, 20))
            np_image = binary_dilation(np_image > 0, structure=structure)

            # Label the connected non-zero (bitmap-covered) regions.
            labeled_image, _ = label(np_image > 0)

            # Enclosing bounding box of each connected component. Slice stops
            # are exclusive, hence the "- 1" on the right/bottom edges.
            bounding_boxes = [
                BoundingBox(
                    l=slc[1].start,
                    t=slc[0].start,
                    r=slc[1].stop - 1,
                    b=slc[0].stop - 1,
                    coord_origin=CoordOrigin.TOPLEFT,
                )
                for slc in find_objects(labeled_image)
            ]

            # Fraction of the page area covered by the (dilated) bitmap mask.
            area_frac = np.sum(np_image > 0) / (size.width * size.height)

            return (area_frac, bounding_boxes)

        if page._backend is not None:
            bitmap_rects = page._backend.get_bitmap_rects()
        else:
            bitmap_rects = []

        coverage, ocr_rects = find_ocr_rects(page.size, bitmap_rects)

        # Return a full-page rectangle if the page is dominantly covered with
        # bitmaps (or full-page OCR is forced).
        if self.options.force_full_page_ocr or coverage > max(
            BITMAP_COVERAGE_THRESHOLD, self.options.bitmap_area_threshold
        ):
            return [
                BoundingBox(
                    l=0,
                    t=0,
                    r=page.size.width,
                    b=page.size.height,
                    coord_origin=CoordOrigin.TOPLEFT,
                )
            ]
        # Return the individual rectangles if the bitmap coverage is above the
        # configured threshold.
        elif coverage > self.options.bitmap_area_threshold:
            return ocr_rects
        else:
            # Overall bitmap coverage is too low: drop all bitmap rectangles.
            return []

    def _filter_ocr_cells(
        self, ocr_cells: List[TextCell], programmatic_cells: List[TextCell]
    ) -> List[TextCell]:
        """Drop every OCR cell whose box intersects a programmatic cell.

        The comparison is done on bounding boxes via an R-tree built from
        ``programmatic_cells``; any intersection at all disqualifies the OCR
        cell. The relative order of the surviving OCR cells is preserved.
        """
        # Build a 2D R-tree index over the programmatic cells.
        props = index.Property()
        props.dimension = 2
        idx = index.Index(properties=props)
        for i, cell in enumerate(programmatic_cells):
            idx.insert(i, cell.rect.to_bounding_box().as_tuple())

        def is_overlapping_with_existing_cells(ocr_cell):
            # Query the R-tree for programmatic boxes intersecting this cell.
            possible_matches_index = list(
                idx.intersection(ocr_cell.rect.to_bounding_box().as_tuple())
            )
            # This is a weak criterion, but it works.
            return len(possible_matches_index) > 0

        return [
            cell
            for cell in ocr_cells
            if not is_overlapping_with_existing_cells(cell)
        ]

    def post_process_cells(self, ocr_cells: List[TextCell], page: Page) -> None:
        """Merge ``ocr_cells`` into the page's text cells.

        Reads the current cells from ``page.cells``, combines them with the
        OCR cells (see ``_combine_cells``) and writes the result to
        ``page.parsed_page.textline_cells``, updating
        ``page.parsed_page.has_lines`` accordingly.

        Raises:
            AssertionError: if ``page.parsed_page`` is ``None``.
        """
        # Existing cells are read through the page's ``cells`` attribute.
        existing_cells = page.cells

        # Combine existing and OCR cells with overlap filtering.
        final_cells = self._combine_cells(existing_cells, ocr_cells)

        assert page.parsed_page is not None

        # Update parsed_page.textline_cells directly.
        page.parsed_page.textline_cells = final_cells
        page.parsed_page.has_lines = len(final_cells) > 0

    def _combine_cells(
        self, existing_cells: List[TextCell], ocr_cells: List[TextCell]
    ) -> List[TextCell]:
        """Combine existing and OCR cells with filtering and re-indexing.

        With ``options.force_full_page_ocr`` set, only the OCR cells are kept
        (the ``ocr_cells`` list itself is returned). Otherwise the existing
        cells come first, followed by the OCR cells that do not overlap any of
        them. The ``index`` attribute of every returned cell is rewritten in
        place to its position in the result.
        """
        if self.options.force_full_page_ocr:
            combined = ocr_cells
        else:
            filtered_ocr_cells = self._filter_ocr_cells(ocr_cells, existing_cells)
            combined = list(existing_cells) + filtered_ocr_cells

        # Re-index in place so indices are contiguous in the combined list.
        for i, cell in enumerate(combined):
            cell.index = i

        return combined

    def draw_ocr_rects_and_cells(self, conv_res, page, ocr_rects, show: bool = False):
        """Render a debug overlay of OCR regions and text cells on the page image.

        OCR rectangles are drawn as translucent yellow fills; cells from
        ``page.cells`` are outlined in magenta when ``from_ocr`` is set and in
        gray otherwise. The page image itself is not modified (a deep copy is
        drawn on).

        Args:
            conv_res: conversion result; only ``conv_res.input.file.stem`` is
                read, to name the debug output directory.
            page: page providing ``image``, ``size``, ``cells`` and ``page_no``.
            ocr_rects: iterable of boxes exposing ``as_tuple()``.
            show: if true, open the image in a viewer; otherwise save it as
                ``ocr_page_<page_no>.png`` under
                ``settings.debug.debug_output_path / debug_<file stem>``.
        """
        image = copy.deepcopy(page.image)
        scale_x = image.width / page.size.width
        scale_y = image.height / page.size.height

        def to_image_coords(x0, y0, x1, y1):
            # x coordinates use the horizontal scale and y coordinates the
            # vertical one; mixing them skews boxes when the scales differ.
            return x0 * scale_x, y0 * scale_y, x1 * scale_x, y1 * scale_y

        draw = ImageDraw.Draw(image, "RGBA")

        # Draw OCR rectangles as translucent yellow filled rects.
        shade_color = (255, 255, 0, 40)
        for rect in ocr_rects:
            x0, y0, x1, y1 = to_image_coords(*rect.as_tuple())
            draw.rectangle([(x0, y0), (x1, y1)], fill=shade_color, outline=None)

        # Draw OCR and programmatic cells.
        for tc in page.cells:
            x0, y0, x1, y1 = to_image_coords(*tc.rect.to_bounding_box().as_tuple())

            # Order the y coordinates so the smaller value comes first.
            if y1 <= y0:
                y1, y0 = y0, y1

            color = "magenta" if tc.from_ocr else "gray"
            draw.rectangle([(x0, y0), (x1, y1)], outline=color)

        if show:
            image.show()
        else:
            out_path: Path = (
                Path(settings.debug.debug_output_path)
                / f"debug_{conv_res.input.file.stem}"
            )
            out_path.mkdir(parents=True, exist_ok=True)

            out_file = out_path / f"ocr_page_{page.page_no:05}.png"
            image.save(str(out_file), format="png")

    @abstractmethod
    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Process ``page_batch``; must be implemented by subclasses."""

    @classmethod
    @abstractmethod
    def get_options_type(cls) -> Type[OcrOptions]:
        """Return the options class for this model; implemented by subclasses."""