Docling/docling/models/base_ocr_model.py
Christoph Auer e94d317c02
feat: Add adaptive OCR, factor out treatment of OCR areas and cell filtering (#38)
* Introduce adaptive OCR

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Factor out BaseOcrModel, add docling-parse backend tests, fixes

* Make easyocr default dep

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
2024-08-20 15:28:03 +02:00

125 lines
4.4 KiB
Python

import copy
import logging
from abc import abstractmethod
from typing import Iterable, List, Tuple
import numpy
import numpy as np
from PIL import Image, ImageDraw
from rtree import index
from scipy.ndimage import find_objects, label
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
# Logger for this module, named after the module's import path.
_log = logging.getLogger(name=__name__)
class BaseOcrModel:
    """Shared base for OCR models that run over docling ``Page`` objects.

    Provides helpers used by concrete OCR implementations: choosing which page
    regions to OCR, discarding OCR cells that overlap programmatic cells, and
    a debug visualisation. Concrete models implement ``__call__``.
    """

    # Bitmap coverage (fraction of page area) above which the whole page is
    # OCR'd as one rectangle instead of the individual bitmap regions.
    # A class attribute so subclasses can tune it.
    BITMAP_COVERAGE_THRESHOLD = 0.75

    def __init__(self, config):
        """Store the model configuration.

        Args:
            config: Mapping of model settings. Must contain an ``"enabled"``
                key; indexing it raises ``KeyError`` otherwise.
        """
        self.config = config
        self.enabled = config["enabled"]

    @staticmethod
    def _find_ocr_rects(size, bitmap_rects) -> Tuple[float, List[BoundingBox]]:
        """Rasterise bitmap rectangles and merge them into connected regions.

        Args:
            size: Page size object exposing ``width`` and ``height``.
            bitmap_rects: Rectangles exposing ``as_tuple()`` -> (l, t, r, b).
                Assumed to be in top-left-origin page coordinates (the merged
                boxes are tagged ``CoordOrigin.TOPLEFT``) -- TODO confirm
                against the backend's ``get_bitmap_rects``.

        Returns:
            ``(coverage, boxes)``. ``coverage`` is the number of covered
            raster pixels divided by the page area. ``boxes`` holds one
            top-left-origin ``BoundingBox`` per connected covered region; a
            box may also enclose uncovered pixels when the region is not
            rectangular. A page with non-positive width, height or area
            yields ``(0.0, [])``.
        """
        width, height = round(size.width), round(size.height)
        page_area = size.width * size.height
        if width <= 0 or height <= 0 or page_area <= 0:
            # Degenerate page: nothing to rasterise, and the coverage ratio
            # below would otherwise divide by zero.
            return 0.0, []

        # '1' mode is a 1-bit binary image; every pixel starts at 0.
        image = Image.new("1", (width, height))
        draw = ImageDraw.Draw(image)
        for rect in bitmap_rects:
            x0, y0, x1, y1 = (round(v) for v in rect.as_tuple())
            draw.rectangle([(x0, y0), (x1, y1)], fill=1)

        # True wherever at least one bitmap rectangle was drawn.
        covered = np.array(image) > 0

        # Label connected regions of covered (non-zero) pixels. scipy's default
        # structuring element is 4-connected, so rectangles touching only at a
        # corner stay separate regions.
        labeled_image, _num_features = label(covered)

        # find_objects yields one (row_slice, col_slice) per label. Slice stops
        # are exclusive while Pillow's rectangle() includes both endpoints, so
        # ``stop - 1`` recovers the last covered pixel row/column.
        bounding_boxes = [
            BoundingBox(
                l=cols.start,
                t=rows.start,
                r=cols.stop - 1,
                b=rows.stop - 1,
                coord_origin=CoordOrigin.TOPLEFT,
            )
            for rows, cols in find_objects(labeled_image)
        ]

        # The numerator counts pixels of the rounded raster; the denominator
        # uses the unrounded page size.
        coverage = float(np.sum(covered)) / page_area
        return coverage, bounding_boxes

    def get_ocr_rects(self, page: Page) -> List[BoundingBox]:
        """Return the rectangles of ``page`` that should be OCR'd.

        The page's bitmap rectangles (``page._backend.get_bitmap_rects()``)
        are rasterised and merged into connected regions. If their coverage
        exceeds ``BITMAP_COVERAGE_THRESHOLD``, a single full-page rectangle is
        returned. Otherwise the merged regions are returned (possibly an
        empty list).

        Args:
            page: Page whose backend supplies the bitmap rectangles.

        Returns:
            List of top-left-origin ``BoundingBox`` objects; never ``None``.
        """
        bitmap_rects = page._backend.get_bitmap_rects()
        coverage, ocr_rects = self._find_ocr_rects(page.size, bitmap_rects)

        if coverage > self.BITMAP_COVERAGE_THRESHOLD:
            # Mostly-bitmap page: OCR it as one region.
            return [
                BoundingBox(
                    l=0,
                    t=0,
                    r=page.size.width,
                    b=page.size.height,
                    coord_origin=CoordOrigin.TOPLEFT,
                )
            ]
        # Every other case, including coverage exactly at the threshold,
        # OCRs the individual regions.
        return ocr_rects

    def filter_ocr_cells(self, ocr_cells, programmatic_cells):
        """Drop every OCR cell whose bbox intersects a programmatic cell's bbox.

        Any R-tree intersection counts as an overlap, however small. This is
        a weak criterion, but it works.

        Args:
            ocr_cells: Iterable of cells exposing ``bbox.as_tuple()``.
            programmatic_cells: Iterable of cells exposing ``bbox.as_tuple()``.

        Returns:
            A new list of the OCR cells that overlap no programmatic cell, in
            their original order.
        """
        # Build a 2-D R-tree over the programmatic cell boxes.
        props = index.Property()
        props.dimension = 2
        idx = index.Index(properties=props)
        for i, cell in enumerate(programmatic_cells):
            idx.insert(i, cell.bbox.as_tuple())

        def is_overlapping_with_existing_cells(ocr_cell):
            # Stop at the first match instead of materialising all of them.
            return any(True for _ in idx.intersection(ocr_cell.bbox.as_tuple()))

        return [
            cell for cell in ocr_cells if not is_overlapping_with_existing_cells(cell)
        ]

    def draw_ocr_rects_and_cells(self, page, ocr_rects):
        """Debug helper: show ``page.image`` with OCR regions and cells overlaid.

        OCR rectangles are shaded translucent yellow. ``OcrCell`` boxes are
        outlined in magenta and all other cells in red. The annotated copy is
        displayed with ``show()``; ``page.image`` itself is not modified.

        Args:
            page: Page providing ``image`` and ``cells``.
            ocr_rects: Rectangles exposing ``as_tuple()``, e.g. the result of
                ``get_ocr_rects``.
        """
        image = copy.deepcopy(page.image)
        draw = ImageDraw.Draw(image, "RGBA")

        shade_color = (255, 255, 0, 40)  # transparent yellow
        for rect in ocr_rects:
            x0, y0, x1, y1 = rect.as_tuple()
            draw.rectangle([(x0, y0), (x1, y1)], fill=shade_color, outline=None)

        for tc in page.cells:
            x0, y0, x1, y1 = tc.bbox.as_tuple()
            color = "magenta" if isinstance(tc, OcrCell) else "red"
            draw.rectangle([(x0, y0), (x1, y1)], outline=color)

        image.show()

    @abstractmethod
    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Process a batch of pages; concrete OCR models implement this.

        NOTE(review): this class does not derive from ``abc.ABC``, so
        ``@abstractmethod`` is not enforced at instantiation time. Calling
        this base implementation simply returns ``None``.
        """