Ensure all models work only on valid pages (#158)
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
@@ -71,92 +71,101 @@ class TableStructureModel(BasePageModel):
|
||||
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
assert page.predictions.layout is not None
|
||||
assert page.size is not None
|
||||
|
||||
page.predictions.tablestructure = TableStructurePrediction() # dummy
|
||||
|
||||
in_tables = [
|
||||
(
|
||||
cluster,
|
||||
[
|
||||
round(cluster.bbox.l) * self.scale,
|
||||
round(cluster.bbox.t) * self.scale,
|
||||
round(cluster.bbox.r) * self.scale,
|
||||
round(cluster.bbox.b) * self.scale,
|
||||
],
|
||||
)
|
||||
for cluster in page.predictions.layout.clusters
|
||||
if cluster.label == DocItemLabel.TABLE
|
||||
]
|
||||
if not len(in_tables):
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
continue
|
||||
else:
|
||||
|
||||
tokens = []
|
||||
for c in page.cells:
|
||||
for cluster, _ in in_tables:
|
||||
if c.bbox.area() > 0:
|
||||
if (
|
||||
c.bbox.intersection_area_with(cluster.bbox) / c.bbox.area()
|
||||
> 0.2
|
||||
):
|
||||
# Only allow non-empty strings (i.e. not whitespace-only) into the cells of a table
|
||||
if len(c.text.strip()) > 0:
|
||||
new_cell = copy.deepcopy(c)
|
||||
new_cell.bbox = new_cell.bbox.scaled(scale=self.scale)
|
||||
assert page.predictions.layout is not None
|
||||
assert page.size is not None
|
||||
|
||||
tokens.append(new_cell.model_dump())
|
||||
page.predictions.tablestructure = TableStructurePrediction() # dummy
|
||||
|
||||
page_input = {
|
||||
"tokens": tokens,
|
||||
"width": page.size.width * self.scale,
|
||||
"height": page.size.height * self.scale,
|
||||
}
|
||||
page_input["image"] = numpy.asarray(page.get_image(scale=self.scale))
|
||||
in_tables = [
|
||||
(
|
||||
cluster,
|
||||
[
|
||||
round(cluster.bbox.l) * self.scale,
|
||||
round(cluster.bbox.t) * self.scale,
|
||||
round(cluster.bbox.r) * self.scale,
|
||||
round(cluster.bbox.b) * self.scale,
|
||||
],
|
||||
)
|
||||
for cluster in page.predictions.layout.clusters
|
||||
if cluster.label == DocItemLabel.TABLE
|
||||
]
|
||||
if not len(in_tables):
|
||||
yield page
|
||||
continue
|
||||
|
||||
table_clusters, table_bboxes = zip(*in_tables)
|
||||
tokens = []
|
||||
for c in page.cells:
|
||||
for cluster, _ in in_tables:
|
||||
if c.bbox.area() > 0:
|
||||
if (
|
||||
c.bbox.intersection_area_with(cluster.bbox)
|
||||
/ c.bbox.area()
|
||||
> 0.2
|
||||
):
|
||||
# Only allow non-empty strings (i.e. not whitespace-only) into the cells of a table
|
||||
if len(c.text.strip()) > 0:
|
||||
new_cell = copy.deepcopy(c)
|
||||
new_cell.bbox = new_cell.bbox.scaled(
|
||||
scale=self.scale
|
||||
)
|
||||
|
||||
if len(table_bboxes):
|
||||
tf_output = self.tf_predictor.multi_table_predict(
|
||||
page_input, table_bboxes, do_matching=self.do_cell_matching
|
||||
)
|
||||
tokens.append(new_cell.model_dump())
|
||||
|
||||
for table_cluster, table_out in zip(table_clusters, tf_output):
|
||||
table_cells = []
|
||||
for element in table_out["tf_responses"]:
|
||||
page_input = {
|
||||
"tokens": tokens,
|
||||
"width": page.size.width * self.scale,
|
||||
"height": page.size.height * self.scale,
|
||||
}
|
||||
page_input["image"] = numpy.asarray(page.get_image(scale=self.scale))
|
||||
|
||||
if not self.do_cell_matching:
|
||||
the_bbox = BoundingBox.model_validate(
|
||||
element["bbox"]
|
||||
).scaled(1 / self.scale)
|
||||
text_piece = page._backend.get_text_in_rect(the_bbox)
|
||||
element["bbox"]["token"] = text_piece
|
||||
table_clusters, table_bboxes = zip(*in_tables)
|
||||
|
||||
tc = TableCell.model_validate(element)
|
||||
if self.do_cell_matching and tc.bbox is not None:
|
||||
tc.bbox = tc.bbox.scaled(1 / self.scale)
|
||||
table_cells.append(tc)
|
||||
|
||||
# Retrieving cols/rows, after post processing:
|
||||
num_rows = table_out["predict_details"]["num_rows"]
|
||||
num_cols = table_out["predict_details"]["num_cols"]
|
||||
otsl_seq = table_out["predict_details"]["prediction"]["rs_seq"]
|
||||
|
||||
tbl = Table(
|
||||
otsl_seq=otsl_seq,
|
||||
table_cells=table_cells,
|
||||
num_rows=num_rows,
|
||||
num_cols=num_cols,
|
||||
id=table_cluster.id,
|
||||
page_no=page.page_no,
|
||||
cluster=table_cluster,
|
||||
label=DocItemLabel.TABLE,
|
||||
if len(table_bboxes):
|
||||
tf_output = self.tf_predictor.multi_table_predict(
|
||||
page_input, table_bboxes, do_matching=self.do_cell_matching
|
||||
)
|
||||
|
||||
page.predictions.tablestructure.table_map[table_cluster.id] = tbl
|
||||
for table_cluster, table_out in zip(table_clusters, tf_output):
|
||||
table_cells = []
|
||||
for element in table_out["tf_responses"]:
|
||||
|
||||
# For debugging purposes:
|
||||
# self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values())
|
||||
if not self.do_cell_matching:
|
||||
the_bbox = BoundingBox.model_validate(
|
||||
element["bbox"]
|
||||
).scaled(1 / self.scale)
|
||||
text_piece = page._backend.get_text_in_rect(the_bbox)
|
||||
element["bbox"]["token"] = text_piece
|
||||
|
||||
yield page
|
||||
tc = TableCell.model_validate(element)
|
||||
if self.do_cell_matching and tc.bbox is not None:
|
||||
tc.bbox = tc.bbox.scaled(1 / self.scale)
|
||||
table_cells.append(tc)
|
||||
|
||||
# Retrieving cols/rows, after post processing:
|
||||
num_rows = table_out["predict_details"]["num_rows"]
|
||||
num_cols = table_out["predict_details"]["num_cols"]
|
||||
otsl_seq = table_out["predict_details"]["prediction"]["rs_seq"]
|
||||
|
||||
tbl = Table(
|
||||
otsl_seq=otsl_seq,
|
||||
table_cells=table_cells,
|
||||
num_rows=num_rows,
|
||||
num_cols=num_cols,
|
||||
id=table_cluster.id,
|
||||
page_no=page.page_no,
|
||||
cluster=table_cluster,
|
||||
label=DocItemLabel.TABLE,
|
||||
)
|
||||
|
||||
page.predictions.tablestructure.table_map[table_cluster.id] = (
|
||||
tbl
|
||||
)
|
||||
|
||||
# For debugging purposes:
|
||||
# self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values())
|
||||
|
||||
yield page
|
||||
|
||||
Reference in New Issue
Block a user