feat: Optimize table extraction quality, add configuration options (#11)

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Signed-off-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Signed-off-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>
Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
Co-authored-by: Christoph Auer <cau@zurich.ibm.com>
Co-authored-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>
Co-authored-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Christoph Auer 2024-07-17 16:13:21 +02:00 committed by GitHub
parent 3e2ede8107
commit e9526bb11e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 87 additions and 27 deletions

View File

@ -47,7 +47,9 @@ python examples/convert.py
```
The output of the above command will be written to `./scratch`.
### Enable or disable pipeline features
### Adjust pipeline features
**Control pipeline options**
You can control whether table structure recognition or OCR should be performed by arguments passed to `DocumentConverter`:
```python
@ -60,6 +62,23 @@ doc_converter = DocumentConverter(
)
```
**Control table extraction options**
You can control whether table structure recognition should map the recognized structure back to PDF cells (the default) or use the text cells from the structure prediction itself.
This can improve output quality if you find that multiple columns in extracted tables are erroneously merged into one.
```python
pipeline_options = PipelineOptions(do_table_structure=True)
pipeline_options.table_structure_options.do_cell_matching = False # Uses text cells predicted from table structure model
doc_converter = DocumentConverter(
artifacts_path=artifacts_path,
pipeline_options=pipeline_options,
)
```
### Impose limits on the document size
You can limit the file size and the number of pages that are allowed to be processed per document:

View File

@ -1,3 +1,4 @@
import copy
from enum import Enum, auto
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
@ -47,6 +48,15 @@ class BoundingBox(BaseModel):
def height(self):
return abs(self.t - self.b)
def scaled(self, scale: float) -> "BoundingBox":
    """Return a new bounding box whose four coordinates are multiplied by *scale*.

    The original box is left untouched; a deep copy is scaled and returned.
    """
    scaled_box = copy.deepcopy(self)
    for side in ("l", "r", "t", "b"):
        setattr(scaled_box, side, getattr(scaled_box, side) * scale)
    return scaled_box
def as_tuple(self):
if self.coord_origin == CoordOrigin.TOPLEFT:
return (self.l, self.t, self.r, self.b)
@ -241,6 +251,17 @@ class DocumentStream(BaseModel):
stream: BytesIO
class TableStructureOptions(BaseModel):
    """Options controlling the table structure recognition step."""

    do_cell_matching: bool = (
        True
        # True: Matches predictions back to PDF cells. Can break table output if PDF cells
        # are merged across table columns.
        # False: Let table structure model define the text cells, ignore PDF cells.
    )
class PipelineOptions(BaseModel):
do_table_structure: bool = True
do_ocr: bool = False
do_table_structure: bool = True # True: perform table structure extraction
do_ocr: bool = False # True: perform OCR, replace programmatic PDF text
table_structure_options: TableStructureOptions = TableStructureOptions()

View File

@ -19,18 +19,6 @@ class PageAssembleModel:
def __init__(self, config):
    """Store the pipeline *config* for use by the page-assembly methods."""
    self.config = config
    # self.line_wrap_pattern = re.compile(r'(?<=[^\W_])- \n(?=\w)')
# def sanitize_text_poor(self, lines):
# text = '\n'.join(lines)
#
# # treat line wraps.
# sanitized_text = self.line_wrap_pattern.sub('', text)
#
# sanitized_text = sanitized_text.replace('\n', ' ')
#
# return sanitized_text
def sanitize_text(self, lines):
if len(lines) <= 1:
return " ".join(lines)

View File

@ -1,7 +1,10 @@
from typing import Iterable
import copy
import random
from typing import Iterable, List
import numpy
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
from PIL import ImageDraw
from docling.datamodel.base_models import (
BoundingBox,
@ -28,6 +31,21 @@ class TableStructureModel:
self.tm_model_type = self.tm_config["model"]["type"]
self.tf_predictor = TFPredictor(self.tm_config)
self.scale = 2.0 # Scale up table input images to 144 dpi
def draw_table_and_cells(self, page: Page, tbl_list: List[TableElement]):
    """Debugging helper: display the page image with each table's bounding
    box outlined in red and every recognized table cell outlined in blue."""
    page_image = page._backend.get_page_image()
    drawer = ImageDraw.Draw(page_image)
    for element in tbl_list:
        left, top, right, bottom = element.cluster.bbox.as_tuple()
        drawer.rectangle([(left, top), (right, bottom)], outline="red")
        for cell in element.table_cells:
            left, top, right, bottom = cell.bbox.as_tuple()
            drawer.rectangle([(left, top), (right, bottom)], outline="blue")
    # NOTE(review): image.show() appears after the loop in the upstream source;
    # indentation in this diff view is stripped, so placement was reconstructed.
    page_image.show()
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
@ -36,16 +54,17 @@ class TableStructureModel:
return
for page in page_batch:
page.predictions.tablestructure = TableStructurePrediction() # dummy
in_tables = [
(
cluster,
[
round(cluster.bbox.l),
round(cluster.bbox.t),
round(cluster.bbox.r),
round(cluster.bbox.b),
round(cluster.bbox.l) * self.scale,
round(cluster.bbox.t) * self.scale,
round(cluster.bbox.r) * self.scale,
round(cluster.bbox.b) * self.scale,
],
)
for cluster in page.predictions.layout.clusters
@ -65,20 +84,29 @@ class TableStructureModel:
):
# Only allow non-empty strings (not just whitespace) into the cells of a table
if len(c.text.strip()) > 0:
tokens.append(c.model_dump())
new_cell = copy.deepcopy(c)
new_cell.bbox = new_cell.bbox.scaled(scale=self.scale)
iocr_page = {
"image": numpy.asarray(page.image),
tokens.append(new_cell.model_dump())
page_input = {
"tokens": tokens,
"width": page.size.width,
"height": page.size.height,
"width": page.size.width * self.scale,
"height": page.size.height * self.scale,
}
# add image to page input.
if self.scale == 1.0:
page_input["image"] = numpy.asarray(page.image)
else: # render new page image on the fly at desired scale
page_input["image"] = numpy.asarray(
page._backend.get_page_image(scale=self.scale)
)
table_clusters, table_bboxes = zip(*in_tables)
if len(table_bboxes):
tf_output = self.tf_predictor.multi_table_predict(
iocr_page, table_bboxes, do_matching=self.do_cell_matching
page_input, table_bboxes, do_matching=self.do_cell_matching
)
for table_cluster, table_out in zip(table_clusters, tf_output):
@ -91,6 +119,7 @@ class TableStructureModel:
element["bbox"]["token"] = text_piece
tc = TableCell.model_validate(element)
tc.bbox = tc.bbox.scaled(1 / self.scale)
table_cells.append(tc)
# Retrieving cols/rows, after post processing:
@ -111,4 +140,7 @@ class TableStructureModel:
page.predictions.tablestructure.table_map[table_cluster.id] = tbl
# For debugging purposes:
# self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values())
yield page

View File

@ -34,7 +34,7 @@ class StandardModelPipeline(BaseModelPipeline):
"artifacts_path": artifacts_path
/ StandardModelPipeline._table_model_path,
"enabled": pipeline_options.do_table_structure,
"do_cell_matching": False,
"do_cell_matching": pipeline_options.table_structure_options.do_cell_matching,
}
),
]