Docling/docling/models/base_ocr_model.py
Michele Dolfi b346faf622
feat: add coverage_threshold to skip OCR for small images (#161)
* feat: add coverage_threshold to skip OCR for small images

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* filter individual boxes

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* rename option

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

---------

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
2024-10-18 13:58:23 +02:00

138 lines
4.9 KiB
Python

import copy
import logging
from abc import abstractmethod
from typing import Iterable, List
import numpy as np
from docling_core.types.doc import BoundingBox, CoordOrigin
from PIL import Image, ImageDraw
from rtree import index
from scipy.ndimage import find_objects, label
from docling.datamodel.base_models import OcrCell, Page
from docling.datamodel.pipeline_options import OcrOptions
# Module-level logger, named after this module via ``__name__``.
_log = logging.getLogger(__name__)
class BaseOcrModel:
    """Common base for OCR models.

    Provides helpers to decide which regions of a page should be sent to OCR,
    to drop OCR cells that collide with programmatic cells, and to visualize
    the result for debugging. Subclasses are expected to implement
    ``__call__``.

    NOTE(review): ``__call__`` is decorated with ``@abstractmethod``, but this
    class does not use ``ABCMeta`` / ``ABC``, so the decorator is not enforced
    at instantiation time.
    """

    # Fraction of the page area covered by bitmaps above which the whole page
    # is OCR'd instead of the individual bitmap regions.
    BITMAP_COVERAGE_THRESHOLD = 0.75

    def __init__(self, enabled: bool, options: OcrOptions):
        """Store the enabled flag and the OCR options on the instance."""
        self.enabled = enabled
        self.options = options

    @staticmethod
    def _find_bitmap_regions(size, bitmap_rects):
        """Merge bitmap rectangles into connected regions on a page raster.

        Args:
            size: Page size object exposing ``width`` and ``height``.
            bitmap_rects: Iterable of rectangles exposing ``as_tuple()`` that
                yields ``(x0, y0, x1, y1)``.
                NOTE(review): coordinates are presumably top-left based, in
                page units - confirm against the backend.

        Returns:
            A ``(area_frac, bounding_boxes)`` tuple: the fraction of
            ``size.width * size.height`` covered by drawn pixels, and one
            top-left-origin ``BoundingBox`` per connected component.
            The caller must ensure the page area is non-zero.
        """
        # '1' mode is a binary (1 bit per pixel) image.
        image = Image.new("1", (round(size.width), round(size.height)))

        # Draw all bitmap rects into the binary image.
        draw = ImageDraw.Draw(image)
        for rect in bitmap_rects:
            x0, y0, x1, y1 = (round(v) for v in rect.as_tuple())
            draw.rectangle([(x0, y0), (x1, y1)], fill=1)

        covered = np.array(image) > 0

        # Label the connected components of the drawn (non-zero) pixels, so
        # overlapping or touching bitmaps collapse into a single region.
        labeled_image, _ = label(covered)

        # Enclosing box of each component. Slice stops are exclusive, hence
        # the ``- 1`` to get the index of the last covered pixel.
        bounding_boxes = [
            BoundingBox(
                l=slc[1].start,
                t=slc[0].start,
                r=slc[1].stop - 1,
                b=slc[0].stop - 1,
                coord_origin=CoordOrigin.TOPLEFT,
            )
            for slc in find_objects(labeled_image)
        ]

        # Fraction of the page area covered by bitmaps.
        area_frac = np.sum(covered) / (size.width * size.height)
        return area_frac, bounding_boxes

    def get_ocr_rects(self, page: Page) -> List[BoundingBox]:
        """Compute the rectangles of ``page`` that should be OCR'd.

        If the bitmap coverage of the page exceeds both
        ``BITMAP_COVERAGE_THRESHOLD`` and ``options.bitmap_area_threshold``,
        a single full-page rectangle is returned. Otherwise the individual
        bitmap regions are returned, keeping only those whose share of the
        page area is greater than ``options.bitmap_area_threshold``.

        Args:
            page: Page to inspect; ``page.size`` must not be ``None``.

        Returns:
            List of top-left-origin bounding boxes to run OCR on. Empty when
            nothing qualifies or when the page has no area.
        """
        assert page.size is not None

        page_area = page.size.width * page.size.height
        if page_area <= 0:
            # A page without area has nothing to OCR; also avoids dividing
            # by zero in the coverage computations below.
            return []

        if page._backend is not None:
            bitmap_rects = page._backend.get_bitmap_rects()
        else:
            bitmap_rects = []

        coverage, ocr_rects = self._find_bitmap_regions(page.size, bitmap_rects)
        area_threshold = self.options.bitmap_area_threshold

        # Return a full-page rectangle if sufficiently covered with bitmaps.
        if coverage > max(self.BITMAP_COVERAGE_THRESHOLD, area_threshold):
            return [
                BoundingBox(
                    l=0,
                    t=0,
                    r=page.size.width,
                    b=page.size.height,
                    coord_origin=CoordOrigin.TOPLEFT,
                )
            ]

        # Otherwise return the individual regions, skipping those whose share
        # of the page area does not exceed the options threshold.
        return [
            rect for rect in ocr_rects if rect.area() / page_area > area_threshold
        ]

    def filter_ocr_cells(self, ocr_cells, programmatic_cells):
        """Drop OCR cells that intersect any existing programmatic cell.

        Args:
            ocr_cells: Iterable of cells exposing ``bbox.as_tuple()``.
            programmatic_cells: Iterable of cells exposing ``bbox.as_tuple()``.

        Returns:
            A new list with the OCR cells whose bounding box matches no
            programmatic cell in the R-tree intersection query.
        """
        # Create a 2D R-tree index over the programmatic cells.
        props = index.Property()
        props.dimension = 2
        idx = index.Index(properties=props)
        for i, cell in enumerate(programmatic_cells):
            idx.insert(i, cell.bbox.as_tuple())

        def is_overlapping_with_existing_cells(ocr_cell):
            # Any hit at all disqualifies the OCR cell, regardless of how
            # large the overlap is. This is a weak criterion but it works.
            hits = list(idx.intersection(ocr_cell.bbox.as_tuple()))
            return bool(hits)

        return [
            cell for cell in ocr_cells if not is_overlapping_with_existing_cells(cell)
        ]

    def draw_ocr_rects_and_cells(self, page, ocr_rects):
        """Show a debug rendering of OCR rectangles and page cells.

        Works on a deep copy of ``page.image``. OCR rectangles are shaded in
        transparent yellow; cells in ``page.cells`` are outlined in magenta
        when they are ``OcrCell`` instances and in red otherwise. The result
        is displayed with ``Image.show()``; nothing is returned.

        Args:
            page: Object exposing ``image`` and ``cells``.
            ocr_rects: Iterable of rectangles exposing ``as_tuple()``.
        """
        page_image = page.image
        if page_image is None:
            # Without an image there is nothing to draw on.
            _log.warning("Cannot draw OCR rects and cells: page has no image.")
            return

        image = copy.deepcopy(page_image)
        draw = ImageDraw.Draw(image, "RGBA")

        # Draw OCR rectangles as yellow filled rects.
        shade_color = (255, 255, 0, 40)  # transparent yellow
        for rect in ocr_rects:
            x0, y0, x1, y1 = rect.as_tuple()
            draw.rectangle([(x0, y0), (x1, y1)], fill=shade_color, outline=None)

        # Draw OCR and programmatic cells.
        for tc in page.cells:
            x0, y0, x1, y1 = tc.bbox.as_tuple()
            color = "magenta" if isinstance(tc, OcrCell) else "red"
            draw.rectangle([(x0, y0), (x1, y1)], outline=color)

        image.show()

    @abstractmethod
    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Process a batch of pages; to be implemented by subclasses."""
        pass