Initial commit
This commit is contained in:
0
docling/models/__init__.py
Normal file
0
docling/models/__init__.py
Normal file
82
docling/models/ds_glm_model.py
Normal file
82
docling/models/ds_glm_model.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import copy
|
||||
import random
|
||||
|
||||
from deepsearch_glm.nlp_utils import init_nlp_model
|
||||
from deepsearch_glm.utils.ds_utils import to_legacy_document_format
|
||||
from deepsearch_glm.utils.load_pretrained_models import load_pretrained_nlp_models
|
||||
from docling_core.types import BaseText
|
||||
from docling_core.types import Document as DsDocument
|
||||
from docling_core.types import Ref
|
||||
from PIL import ImageDraw
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, Cluster, CoordOrigin
|
||||
from docling.datamodel.document import ConvertedDocument
|
||||
|
||||
|
||||
class GlmModel:
    """Apply the deepsearch-glm NLP models to a converted document.

    The instance is callable: it takes a ``ConvertedDocument``, runs the
    NLP model over its dictionary form and returns the result re-validated
    as a ``DsDocument``.
    """

    def __init__(self, config):
        """Store ``config`` and initialise the NLP model.

        Args:
            config: Kept on the instance as ``self.config``; this class does
                not read any key from it.
        """
        self.config = config
        # NOTE(review): presumably fetches/caches the pretrained model
        # artifacts before the model is initialised -- confirm in deepsearch_glm.
        load_pretrained_nlp_models()
        self.model = init_nlp_model(model_names="language;term;reference")

    def __call__(self, document: ConvertedDocument) -> DsDocument:
        """Run the NLP model over ``document`` and return the enriched document.

        Args:
            document: The converted document; it must provide
                ``to_ds_document()`` (and ``pages`` if the debug helper
                below is enabled).

        Returns:
            A ``DsDocument`` validated from the legacy-format dictionary
            produced by ``to_legacy_document_format``.
        """
        ds_doc = document.to_ds_document()
        ds_doc_dict = ds_doc.model_dump(by_alias=True)

        glm_doc = self.model.apply_on_doc(ds_doc_dict)
        ds_doc_dict = to_legacy_document_format(
            glm_doc, ds_doc_dict, update_name_label=True
        )

        exported_doc = DsDocument.model_validate(ds_doc_dict)

        # DEBUG code: visual inspection helper, only used when one of the
        # commented-out calls below is re-enabled.
        def draw_clusters_and_cells(ds_document, page_no):
            """Draw the provenance boxes of ``ds_document`` on page ``page_no``."""
            clusters_to_draw = []
            image = copy.deepcopy(document.pages[page_no].image)
            for ix, elem in enumerate(ds_document.main_text):
                # Reset for every element: otherwise an element matching no
                # branch below would reuse the previous element's provenance
                # (or raise UnboundLocalError on the very first element).
                prov = None
                if isinstance(elem, BaseText):
                    prov = elem.prov[0]
                elif isinstance(elem, Ref):
                    # References look like "<prefix>/<collection>/<index>".
                    _, arr, index = elem.ref.split("/")
                    index = int(index)
                    if arr == "tables":
                        prov = ds_document.tables[index].prov[0]
                    elif arr == "figures":
                        prov = ds_document.figures[index].prov[0]

                if prov and prov.page == page_no:
                    clusters_to_draw.append(
                        Cluster(
                            id=ix,
                            label=elem.name,
                            bbox=BoundingBox.from_tuple(
                                coord=prov.bbox,
                                origin=CoordOrigin.BOTTOMLEFT,
                            ).to_top_left_origin(document.pages[page_no].size.height),
                        )
                    )

            draw = ImageDraw.Draw(image)
            for c in clusters_to_draw:
                x0, y0, x1, y1 = c.bbox.as_tuple()
                draw.rectangle([(x0, y0), (x1, y1)], outline="red")
                draw.text((x0 + 2, y0 + 2), f"{c.id}:{c.label}", fill=(255, 0, 0, 255))

                # One random (darkish) colour per cluster for its cells.
                cell_color = (
                    random.randint(30, 140),
                    random.randint(30, 140),
                    random.randint(30, 140),
                )
                for tc in c.cells:  # [:1]:
                    x0, y0, x1, y1 = tc.bbox.as_tuple()
                    draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
            image.show()

        # draw_clusters_and_cells(ds_doc, 0)
        # draw_clusters_and_cells(exported_doc, 0)

        return exported_doc
|
||||
77
docling/models/easyocr_model.py
Normal file
77
docling/models/easyocr_model.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import copy
|
||||
import logging
|
||||
import random
|
||||
from typing import Iterable
|
||||
|
||||
import numpy
|
||||
from PIL import ImageDraw
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EasyOcrModel:
    """Pipeline stage that replaces page cells with EasyOCR results.

    When disabled via ``config["enabled"]`` the stage passes pages through
    untouched and never imports ``easyocr``.
    """

    def __init__(self, config):
        """Read the ``enabled`` / ``lang`` settings and build the OCR reader.

        Args:
            config: Mapping with an ``"enabled"`` flag and, when enabled, a
                ``"lang"`` entry that is handed to ``easyocr.Reader``.
        """
        self.config = config
        self.enabled = config["enabled"]
        self.scale = 3  # multiplier for 72 dpi == 216 dpi.

        if not self.enabled:
            return

        # Imported lazily so the dependency is only needed when OCR is on.
        import easyocr

        self.reader = easyocr.Reader(config["lang"])

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Yield every page of ``page_batch`` with ``page.cells`` set from OCR.

        Args:
            page_batch: Pages to process; each must expose
                ``_backend.get_page_image``.

        Yields:
            The same page objects, in order. With OCR disabled they are
            yielded unchanged.
        """
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:
            # rects = page._fpage.
            # Render at `self.scale` times the base resolution for the OCR run.
            high_res_image = page._backend.get_page_image(scale=self.scale)
            im = numpy.array(high_res_image)
            result = self.reader.readtext(im)

            # Drop the large intermediates before building the cells.
            del high_res_image, im

            ocr_cells = []
            for ix, line in enumerate(result):
                corners = line[0]
                first_corner = corners[0]
                third_corner = corners[2]
                # Scale the detected box back to page coordinates.
                bbox = BoundingBox.from_tuple(
                    coord=(
                        first_corner[0] / self.scale,
                        first_corner[1] / self.scale,
                        third_corner[0] / self.scale,
                        third_corner[1] / self.scale,
                    ),
                    origin=CoordOrigin.TOPLEFT,
                )
                ocr_cells.append(
                    OcrCell(id=ix, text=line[1], confidence=line[2], bbox=bbox)
                )

            page.cells = ocr_cells  # For now, just overwrites all digital cells.

            # DEBUG code: only used when the commented-out call is re-enabled.
            def draw_clusters_and_cells():
                """Show the page image with every OCR cell outlined."""
                image = copy.deepcopy(page.image)
                draw = ImageDraw.Draw(image)

                cell_color = (
                    random.randint(30, 140),
                    random.randint(30, 140),
                    random.randint(30, 140),
                )
                for tc in ocr_cells:
                    x0, y0, x1, y1 = tc.bbox.as_tuple()
                    draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
                image.show()

            # draw_clusters_and_cells()

            yield page
|
||||
318
docling/models/layout_model.py
Normal file
318
docling/models/layout_model.py
Normal file
@@ -0,0 +1,318 @@
|
||||
import copy
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import Iterable, List
|
||||
|
||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||
from PIL import ImageDraw
|
||||
|
||||
from docling.datamodel.base_models import (
|
||||
BoundingBox,
|
||||
Cell,
|
||||
Cluster,
|
||||
CoordOrigin,
|
||||
LayoutPrediction,
|
||||
Page,
|
||||
)
|
||||
from docling.utils import layout_utils as lu
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LayoutModel:
    """Pipeline stage that predicts and post-processes the page layout.

    For every page the ``LayoutPredictor`` is run on ``page.image``; the
    predicted clusters are then reconciled with the page's text cells in
    ``postprocess`` and stored as ``page.predictions.layout``.
    """

    # Layout labels regarded as text elements.
    TEXT_ELEM_LABELS = [
        "Text",
        "Footnote",
        "Caption",
        "Checkbox-Unselected",
        "Checkbox-Selected",
        "Section-header",
        "Page-header",
        "Page-footer",
        "Code",
        "List-item",
        # "Formula",
    ]
    # Subset of the text labels that denote page furniture.
    PAGE_HEADER_LABELS = ["Page-header", "Page-footer"]

    TABLE_LABEL = "Table"
    FIGURE_LABEL = "Picture"
    FORMULA_LABEL = "Formula"

    def __init__(self, config):
        """Store ``config`` and build the predictor from ``config["artifacts_path"]``."""
        self.config = config
        self.layout_predictor = LayoutPredictor(
            config["artifacts_path"]
        )  # TODO temporary

    def postprocess(self, clusters: List[Cluster], cells: List[Cell], page_height):
        """Filter predicted clusters and assign the page cells to them.

        The work is done on plain dictionaries in bottom-left coordinates
        (the representation handed to the ``layout_utils`` helpers) and
        converted back to ``Cluster`` / ``Cell`` objects with a top-left
        origin at the end.

        Args:
            clusters: Raw layout predictions for one page.
            cells: Text cells of the same page.
            page_height: Page height used for the coordinate-origin flips.

        Returns:
            A tuple ``(clusters, cells)``: the post-processed clusters (each
            carrying the cells whose id is listed in its ``cell_ids``) and
            the rebuilt cell list.

        Raises:
            KeyError: If a cluster label has no entry in ``CLASS_THRESHOLDS``.
        """
        # Minimum intersection passed to the cell-to-cluster assignment helper.
        MIN_INTERSECTION = 0.2
        # Per-label minimum confidence for a prediction to be kept up front.
        CLASS_THRESHOLDS = {
            "Caption": 0.35,
            "Footnote": 0.35,
            "Formula": 0.35,
            "List-item": 0.35,
            "Page-footer": 0.35,
            "Page-header": 0.35,
            "Picture": 0.2,  # low threshold adjust to capture chemical structures for examples.
            "Section-header": 0.45,
            "Table": 0.35,
            "Text": 0.45,
            "Title": 0.45,
            "Document Index": 0.45,
            "Code": 0.45,
            "Checkbox-Selected": 0.45,
            "Checkbox-Unselected": 0.45,
            "Form": 0.45,
            "Key-Value Region": 0.45,
        }

        _log.debug("================= Start postprocess function ====================")
        start_time = time.time()
        # Apply Confidence Threshold to cluster predictions
        # confidence = self.conf_threshold
        clusters_out = []

        for cluster in clusters:
            confidence = CLASS_THRESHOLDS[cluster.label]
            if cluster.confidence >= confidence:
                # annotation["created_by"] = "high_conf_pred"
                clusters_out.append(cluster)

        # map to dictionary clusters and cells, with bottom left origin
        # `clusters` keeps ALL predictions; it is later offered to the orphan
        # handling helpers. `clusters_out` holds only the high-confidence ones.
        clusters = [
            {
                "id": c.id,
                "bbox": list(
                    c.bbox.to_bottom_left_origin(page_height).as_tuple()
                ),  # TODO
                "confidence": c.confidence,
                "cell_ids": [],
                "type": c.label,
            }
            for c in clusters
        ]

        clusters_out = [
            {
                "id": c.id,
                "bbox": list(
                    c.bbox.to_bottom_left_origin(page_height).as_tuple()
                ),  # TODO
                "confidence": c.confidence,
                "created_by": "high_conf_pred",
                "cell_ids": [],
                "type": c.label,
            }
            for c in clusters_out
        ]

        raw_cells = [
            {
                "id": c.id,
                "bbox": list(
                    c.bbox.to_bottom_left_origin(page_height).as_tuple()
                ),  # TODO
                "text": c.text,
            }
            for c in cells
        ]
        cell_count = len(raw_cells)

        _log.debug("---- 0. Treat cluster overlaps ------")
        clusters_out = lu.remove_cluster_duplicates_by_conf(clusters_out, 0.8)

        _log.debug(
            "---- 1. Initially assign cells to clusters based on minimum intersection ------"
        )
        ## Check for cells included in or touched by clusters:
        clusters_out = lu.assigning_cell_ids_to_clusters(
            clusters_out, raw_cells, MIN_INTERSECTION
        )

        _log.debug("---- 2. Assign Orphans with Low Confidence Detections")
        # Creates a map of cell_id->cluster_id
        (
            clusters_around_cells,
            orphan_cell_indices,
            ambiguous_cell_indices,
        ) = lu.cell_id_state_map(clusters_out, cell_count)

        # Assign orphan cells with lower confidence predictions
        clusters_out, orphan_cell_indices = lu.assign_orphans_with_low_conf_pred(
            clusters_out, clusters, raw_cells, orphan_cell_indices
        )

        # Refresh the cell_ids assignment, after creating new clusters using low conf predictions
        clusters_out = lu.assigning_cell_ids_to_clusters(
            clusters_out, raw_cells, MIN_INTERSECTION
        )

        _log.debug("---- 3. Settle Ambigous Cells")
        # Creates an update map after assignment of cell_id->cluster_id
        (
            clusters_around_cells,
            orphan_cell_indices,
            ambiguous_cell_indices,
        ) = lu.cell_id_state_map(clusters_out, cell_count)

        # Settle pdf cells that belong to multiple clusters
        clusters_out, ambiguous_cell_indices = lu.remove_ambigous_pdf_cell_by_conf(
            clusters_out, raw_cells, ambiguous_cell_indices
        )

        _log.debug("---- 4. Set Orphans as Text")
        (
            clusters_around_cells,
            orphan_cell_indices,
            ambiguous_cell_indices,
        ) = lu.cell_id_state_map(clusters_out, cell_count)

        clusters_out, orphan_cell_indices = lu.set_orphan_as_text(
            clusters_out, clusters, raw_cells, orphan_cell_indices
        )

        _log.debug("---- 5. Merge Cells & and adapt the bounding boxes")
        # Merge cells orphan cells
        clusters_out = lu.merge_cells(clusters_out)

        # Clean up clusters that remain from merged and unreasonable clusters
        clusters_out = lu.clean_up_clusters(
            clusters_out,
            raw_cells,
            merge_cells=True,
            img_table=True,
            one_cell_table=True,
        )

        new_clusters = lu.adapt_bboxes(raw_cells, clusters_out, orphan_cell_indices)
        clusters_out = new_clusters

        ## We first rebuild where every cell is now:
        ## Now we write into a prediction cells list, not into the raw cells list.
        ## As we don't need previous labels, we best overwrite any old list, because that might
        ## have been sorted differently.
        (
            clusters_around_cells,
            orphan_cell_indices,
            ambiguous_cell_indices,
        ) = lu.cell_id_state_map(clusters_out, cell_count)

        target_cells = []
        for ix, cell in enumerate(raw_cells):
            # Cells are re-numbered by position; the label stays the string
            # "None" when no cluster surrounds the cell.
            new_cell = {
                "id": ix,
                "rawcell_id": ix,
                "label": "None",
                "bbox": cell["bbox"],
                "text": cell["text"],
            }
            for cluster_index in clusters_around_cells[
                ix
            ]:  # By previous analysis, this is always 1 cluster.
                new_cell["label"] = clusters_out[cluster_index]["type"]
            target_cells.append(new_cell)
            # _log.debug("New label of cell " + str(ix) + " is " + str(new_cell["label"]))
        cells_out = target_cells

        ## -------------------------------
        ## Sort clusters into reasonable reading order, and sort the cells inside each cluster
        # NOTE(review): this debug message reuses step number "5." from the
        # merge step above; the string is left untouched here.
        _log.debug("---- 5. Sort clusters in reading order ------")
        sorted_clusters = lu.produce_reading_order(
            clusters_out, "raw_cell_ids", "raw_cell_ids", True
        )
        clusters_out = sorted_clusters

        # end_time = timer()
        _log.debug("---- End of postprocessing function ------")
        # Despite its name, `end_time` holds the elapsed duration in seconds.
        end_time = time.time() - start_time
        _log.debug(f"Finished post processing in seconds={end_time:.3f}")

        # Convert the dictionaries back into model objects (top-left origin).
        cells_out = [
            Cell(
                id=c["id"],
                bbox=BoundingBox.from_tuple(
                    coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
                ).to_top_left_origin(page_height),
                text=c["text"],
            )
            for c in cells_out
        ]
        clusters_out_new = []
        for c in clusters_out:
            # Attach the Cell objects whose id is listed for this cluster.
            cluster_cells = [ccell for ccell in cells_out if ccell.id in c["cell_ids"]]
            c_new = Cluster(
                id=c["id"],
                bbox=BoundingBox.from_tuple(
                    coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
                ).to_top_left_origin(page_height),
                confidence=c["confidence"],
                label=c["type"],
                cells=cluster_cells,
            )
            clusters_out_new.append(c_new)

        return clusters_out_new, cells_out

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Predict and post-process the layout of every page in ``page_batch``.

        Each page gets ``page.cells`` replaced by the post-processed cells
        and ``page.predictions.layout`` set to the resulting clusters.

        Yields:
            The same page objects, in order.
        """
        for page in page_batch:
            clusters = []
            for ix, pred_item in enumerate(self.layout_predictor.predict(page.image)):
                # The prediction item itself is validated as the bounding box.
                cluster = Cluster(
                    id=ix,
                    label=pred_item["label"],
                    confidence=pred_item["confidence"],
                    bbox=BoundingBox.model_validate(pred_item),
                    cells=[],
                )
                clusters.append(cluster)

            # Map cells to clusters
            # TODO: Remove, postprocess should take care of it anyway.
            for cell in page.cells:
                for cluster in clusters:
                    # Guard against division by zero for empty cell boxes.
                    if not cell.bbox.area() > 0:
                        overlap_frac = 0.0
                    else:
                        overlap_frac = (
                            cell.bbox.intersection_area_with(cluster.bbox)
                            / cell.bbox.area()
                        )

                    if overlap_frac > 0.5:
                        cluster.cells.append(cell)

            # Pre-sort clusters
            # clusters = self.sort_clusters_by_cell_order(clusters)

            # DEBUG code: only used when one of the commented-out calls below
            # is re-enabled; it draws whatever `clusters` holds at call time.
            def draw_clusters_and_cells():
                image = copy.deepcopy(page.image)
                draw = ImageDraw.Draw(image)
                for c in clusters:
                    x0, y0, x1, y1 = c.bbox.as_tuple()
                    draw.rectangle([(x0, y0), (x1, y1)], outline="green")

                    cell_color = (
                        random.randint(30, 140),
                        random.randint(30, 140),
                        random.randint(30, 140),
                    )
                    for tc in c.cells:  # [:1]:
                        x0, y0, x1, y1 = tc.bbox.as_tuple()
                        draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
                image.show()

            # draw_clusters_and_cells()

            clusters, page.cells = self.postprocess(
                clusters, page.cells, page.size.height
            )

            # draw_clusters_and_cells()

            page.predictions.layout = LayoutPrediction(clusters=clusters)

            yield page
|
||||
160
docling/models/page_assemble_model.py
Normal file
160
docling/models/page_assemble_model.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Iterable, List
|
||||
|
||||
from docling.datamodel.base_models import (
|
||||
AssembledUnit,
|
||||
FigureElement,
|
||||
Page,
|
||||
PageElement,
|
||||
TableElement,
|
||||
TextElement,
|
||||
)
|
||||
from docling.models.layout_model import LayoutModel
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PageAssembleModel:
    """Pipeline stage that turns layout clusters into page elements.

    For every page the layout clusters are converted into text, table,
    figure and formula elements and stored as ``page.assembled``.
    """

    # Matches runs of word characters; compiled once instead of per line.
    _WORD_PATTERN = re.compile(r"\b[\w]+\b")

    def __init__(self, config):
        """Store ``config``; no key of it is read by this class."""
        self.config = config

    def sanitize_text(self, lines):
        """Join ``lines`` into one string, undoing end-of-line hyphenation.

        A line ending in ``-`` is glued to the next one (dropping the hyphen)
        when the last word before it and the first word of the next line are
        both alphanumeric; a hyphen-terminated line that does not qualify is
        kept as is and joined without a separator. All other lines are joined
        with a single space.

        Args:
            lines: Sequence of text lines. It is not modified.

        Returns:
            The joined text. For two or more lines the result is stripped of
            leading/trailing whitespace; zero or one line is returned as
            ``" ".join(lines)`` without stripping.
        """
        if len(lines) <= 1:
            return " ".join(lines)

        # Work on a copy so the caller's list is left untouched (this also
        # allows immutable sequences such as tuples).
        merged = list(lines)
        for ix in range(len(merged) - 1):
            prev_line = merged[ix]
            next_line = merged[ix + 1]

            if not prev_line.endswith("-"):
                merged[ix] = prev_line + " "
                continue

            prev_words = self._WORD_PATTERN.findall(prev_line)
            next_words = self._WORD_PATTERN.findall(next_line)
            if (
                prev_words
                and next_words
                and prev_words[-1].isalnum()
                and next_words[0].isalnum()
            ):
                # Hyphenated word split across lines: drop the hyphen.
                merged[ix] = prev_line[:-1]

        return "".join(merged).strip()  # Strip any leading or trailing whitespace

    def _cluster_text(self, cluster):
        """Return the sanitized text of all non-blank cells of ``cluster``."""
        # NOTE(review): "\x02" is presumably a hyphen placeholder coming from
        # the text extraction backend -- confirm; it is mapped to "-".
        textlines = [
            cell.text.replace("\x02", "-").strip()
            for cell in cluster.cells
            if len(cell.text.strip()) > 0
        ]
        return self.sanitize_text(textlines)

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Assemble ``page.assembled`` for every page in ``page_batch``.

        Clusters whose label is none of the text, table, figure or formula
        labels are skipped. Page headers/footers go to ``headers``, everything
        else to ``body``; ``elements`` holds all of them in cluster order.

        Yields:
            The same page objects, in order.
        """
        for page in page_batch:
            # assembles some JSON output page by page.

            elements: List[PageElement] = []
            headers: List[PageElement] = []
            body: List[PageElement] = []

            for cluster in page.predictions.layout.clusters:
                if cluster.label in LayoutModel.TEXT_ELEM_LABELS:
                    text_el = TextElement(
                        label=cluster.label,
                        id=cluster.id,
                        text=self._cluster_text(cluster),
                        page_no=page.page_no,
                        cluster=cluster,
                    )
                    elements.append(text_el)

                    if cluster.label in LayoutModel.PAGE_HEADER_LABELS:
                        headers.append(text_el)
                    else:
                        body.append(text_el)
                elif cluster.label == LayoutModel.TABLE_LABEL:
                    tbl = None
                    if page.predictions.tablestructure:
                        tbl = page.predictions.tablestructure.table_map.get(
                            cluster.id, None
                        )
                    if not tbl:
                        # fallback: add table without structure, if it isn't present
                        tbl = TableElement(
                            label=cluster.label,
                            id=cluster.id,
                            text="",
                            otsl_seq=[],
                            table_cells=[],
                            cluster=cluster,
                            page_no=page.page_no,
                        )

                    elements.append(tbl)
                    body.append(tbl)
                elif cluster.label == LayoutModel.FIGURE_LABEL:
                    fig = None
                    if page.predictions.figures_classification:
                        fig = page.predictions.figures_classification.figure_map.get(
                            cluster.id, None
                        )
                    if not fig:
                        # fallback: add figure without classification, if it isn't present
                        fig = FigureElement(
                            label=cluster.label,
                            id=cluster.id,
                            text="",
                            data=None,
                            cluster=cluster,
                            page_no=page.page_no,
                        )
                    elements.append(fig)
                    body.append(fig)
                elif cluster.label == LayoutModel.FORMULA_LABEL:
                    equation = None
                    if page.predictions.equations_prediction:
                        equation = (
                            page.predictions.equations_prediction.equation_map.get(
                                cluster.id, None
                            )
                        )
                    if not equation:
                        # fallback: no equation prediction, use the raw cell text
                        equation = TextElement(
                            label=cluster.label,
                            id=cluster.id,
                            cluster=cluster,
                            page_no=page.page_no,
                            text=self._cluster_text(cluster),
                        )
                    elements.append(equation)
                    body.append(equation)

            page.assembled = AssembledUnit(
                elements=elements, headers=headers, body=body
            )

            yield page
|
||||
114
docling/models/table_structure_model.py
Normal file
114
docling/models/table_structure_model.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import Iterable
|
||||
|
||||
import numpy
|
||||
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
|
||||
|
||||
from docling.datamodel.base_models import (
|
||||
BoundingBox,
|
||||
Page,
|
||||
TableCell,
|
||||
TableElement,
|
||||
TableStructurePrediction,
|
||||
)
|
||||
|
||||
|
||||
class TableStructureModel:
    """Pipeline stage that predicts the structure of detected tables.

    When disabled via ``config["enabled"]`` the stage passes pages through
    untouched and does not build a predictor.
    """

    def __init__(self, config):
        """Read the settings and, when enabled, build the table predictor.

        Args:
            config: Mapping with ``"do_cell_matching"`` and ``"enabled"``;
                when enabled, ``"artifacts_path"`` must point to a directory
                containing ``tm_config.json``.
        """
        self.config = config
        self.do_cell_matching = config["do_cell_matching"]
        self.enabled = config["enabled"]

        if not self.enabled:
            return

        artifacts_path = config["artifacts_path"]
        # Third Party
        import docling_ibm_models.tableformer.common as tf_common

        self.tm_config = tm_config = tf_common.read_config(
            f"{artifacts_path}/tm_config.json"
        )
        tm_config["model"]["save_dir"] = artifacts_path
        self.tm_model_type = tm_config["model"]["type"]

        self.tf_predictor = TFPredictor(tm_config)

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Fill ``page.predictions.tablestructure`` for every page.

        For each cluster labelled ``"Table"`` a ``TableElement`` is stored in
        the prediction's ``table_map`` under the cluster id.

        Yields:
            The same page objects, in order. With the stage disabled they
            are yielded unchanged.
        """
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:
            page.predictions.tablestructure = TableStructurePrediction()  # dummy

            # Pair every table cluster with its rounded [l, t, r, b] box.
            in_tables = []
            for cluster in page.predictions.layout.clusters:
                if cluster.label == "Table":
                    rounded_box = [
                        round(cluster.bbox.l),
                        round(cluster.bbox.t),
                        round(cluster.bbox.r),
                        round(cluster.bbox.b),
                    ]
                    in_tables.append((cluster, rounded_box))

            if not in_tables:
                yield page
                continue

            # Collect the page cells overlapping a table as predictor tokens.
            tokens = []
            for cell in page.cells:
                for cluster, _ in in_tables:
                    if not cell.bbox.area() > 0:
                        continue
                    overlap = (
                        cell.bbox.intersection_area_with(cluster.bbox)
                        / cell.bbox.area()
                    )
                    if not overlap > 0.2:
                        continue
                    # Only non-empty strings (not just spaces) become tokens.
                    if len(cell.text.strip()) > 0:
                        tokens.append(cell.model_dump())

            iocr_page = {
                "image": numpy.asarray(page.image),
                "tokens": tokens,
                "width": page.size.width,
                "height": page.size.height,
            }

            table_clusters, table_bboxes = zip(*in_tables)

            if table_bboxes:
                tf_output = self.tf_predictor.multi_table_predict(
                    iocr_page, table_bboxes, do_matching=self.do_cell_matching
                )

                for table_cluster, table_out in zip(table_clusters, tf_output):
                    table_cells = []
                    for element in table_out["tf_responses"]:
                        if not self.do_cell_matching:
                            # Without matching, look the text up in the
                            # backend for the predicted box.
                            the_bbox = BoundingBox.model_validate(element["bbox"])
                            element["bbox"]["token"] = page._backend.get_text_in_rect(
                                the_bbox
                            )

                        table_cells.append(TableCell.model_validate(element))

                    # Retrieving cols/rows, after post processing:
                    details = table_out["predict_details"]
                    num_rows = details["num_rows"]
                    num_cols = details["num_cols"]
                    otsl_seq = details["prediction"]["rs_seq"]

                    page.predictions.tablestructure.table_map[
                        table_cluster.id
                    ] = TableElement(
                        otsl_seq=otsl_seq,
                        table_cells=table_cells,
                        num_rows=num_rows,
                        num_cols=num_cols,
                        id=table_cluster.id,
                        page_no=page.page_no,
                        cluster=table_cluster,
                        label="Table",
                    )

            yield page
|
||||
Reference in New Issue
Block a user