Initial commit
This commit is contained in:
114
docling/models/table_structure_model.py
Normal file
114
docling/models/table_structure_model.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import Iterable
|
||||
|
||||
import numpy
|
||||
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
|
||||
|
||||
from docling.datamodel.base_models import (
|
||||
BoundingBox,
|
||||
Page,
|
||||
TableCell,
|
||||
TableElement,
|
||||
TableStructurePrediction,
|
||||
)
|
||||
|
||||
|
||||
class TableStructureModel:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.do_cell_matching = config["do_cell_matching"]
|
||||
|
||||
self.enabled = config["enabled"]
|
||||
if self.enabled:
|
||||
artifacts_path = config["artifacts_path"]
|
||||
# Third Party
|
||||
import docling_ibm_models.tableformer.common as c
|
||||
|
||||
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
|
||||
self.tm_config["model"]["save_dir"] = artifacts_path
|
||||
self.tm_model_type = self.tm_config["model"]["type"]
|
||||
|
||||
self.tf_predictor = TFPredictor(self.tm_config)
|
||||
|
||||
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||
|
||||
if not self.enabled:
|
||||
yield from page_batch
|
||||
return
|
||||
|
||||
for page in page_batch:
|
||||
page.predictions.tablestructure = TableStructurePrediction() # dummy
|
||||
|
||||
in_tables = [
|
||||
(
|
||||
cluster,
|
||||
[
|
||||
round(cluster.bbox.l),
|
||||
round(cluster.bbox.t),
|
||||
round(cluster.bbox.r),
|
||||
round(cluster.bbox.b),
|
||||
],
|
||||
)
|
||||
for cluster in page.predictions.layout.clusters
|
||||
if cluster.label == "Table"
|
||||
]
|
||||
if not len(in_tables):
|
||||
yield page
|
||||
continue
|
||||
|
||||
tokens = []
|
||||
for c in page.cells:
|
||||
for cluster, _ in in_tables:
|
||||
if c.bbox.area() > 0:
|
||||
if (
|
||||
c.bbox.intersection_area_with(cluster.bbox) / c.bbox.area()
|
||||
> 0.2
|
||||
):
|
||||
# Only allow non empty stings (spaces) into the cells of a table
|
||||
if len(c.text.strip()) > 0:
|
||||
tokens.append(c.model_dump())
|
||||
|
||||
iocr_page = {
|
||||
"image": numpy.asarray(page.image),
|
||||
"tokens": tokens,
|
||||
"width": page.size.width,
|
||||
"height": page.size.height,
|
||||
}
|
||||
|
||||
table_clusters, table_bboxes = zip(*in_tables)
|
||||
|
||||
if len(table_bboxes):
|
||||
tf_output = self.tf_predictor.multi_table_predict(
|
||||
iocr_page, table_bboxes, do_matching=self.do_cell_matching
|
||||
)
|
||||
|
||||
for table_cluster, table_out in zip(table_clusters, tf_output):
|
||||
table_cells = []
|
||||
for element in table_out["tf_responses"]:
|
||||
|
||||
if not self.do_cell_matching:
|
||||
the_bbox = BoundingBox.model_validate(element["bbox"])
|
||||
text_piece = page._backend.get_text_in_rect(the_bbox)
|
||||
element["bbox"]["token"] = text_piece
|
||||
|
||||
tc = TableCell.model_validate(element)
|
||||
table_cells.append(tc)
|
||||
|
||||
# Retrieving cols/rows, after post processing:
|
||||
num_rows = table_out["predict_details"]["num_rows"]
|
||||
num_cols = table_out["predict_details"]["num_cols"]
|
||||
otsl_seq = table_out["predict_details"]["prediction"]["rs_seq"]
|
||||
|
||||
tbl = TableElement(
|
||||
otsl_seq=otsl_seq,
|
||||
table_cells=table_cells,
|
||||
num_rows=num_rows,
|
||||
num_cols=num_cols,
|
||||
id=table_cluster.id,
|
||||
page_no=page.page_no,
|
||||
cluster=table_cluster,
|
||||
label="Table",
|
||||
)
|
||||
|
||||
page.predictions.tablestructure.table_map[table_cluster.id] = tbl
|
||||
|
||||
yield page
|
||||
Reference in New Issue
Block a user