Initial commit

This commit is contained in:
Christoph Auer
2024-07-15 09:42:42 +02:00
commit e2d996753b
38 changed files with 8767 additions and 0 deletions

View File

View File

@@ -0,0 +1,82 @@
import copy
import random
from deepsearch_glm.nlp_utils import init_nlp_model
from deepsearch_glm.utils.ds_utils import to_legacy_document_format
from deepsearch_glm.utils.load_pretrained_models import load_pretrained_nlp_models
from docling_core.types import BaseText
from docling_core.types import Document as DsDocument
from docling_core.types import Ref
from PIL import ImageDraw
from docling.datamodel.base_models import BoundingBox, Cluster, CoordOrigin
from docling.datamodel.document import ConvertedDocument
class GlmModel:
def __init__(self, config):
self.config = config
load_pretrained_nlp_models()
model = init_nlp_model(model_names="language;term;reference")
self.model = model
def __call__(self, document: ConvertedDocument) -> DsDocument:
ds_doc = document.to_ds_document()
ds_doc_dict = ds_doc.model_dump(by_alias=True)
glm_doc = self.model.apply_on_doc(ds_doc_dict)
ds_doc_dict = to_legacy_document_format(
glm_doc, ds_doc_dict, update_name_label=True
)
exported_doc = DsDocument.model_validate(ds_doc_dict)
# DEBUG code:
def draw_clusters_and_cells(ds_document, page_no):
clusters_to_draw = []
image = copy.deepcopy(document.pages[page_no].image)
for ix, elem in enumerate(ds_document.main_text):
if isinstance(elem, BaseText):
prov = elem.prov[0]
elif isinstance(elem, Ref):
_, arr, index = elem.ref.split("/")
index = int(index)
if arr == "tables":
prov = ds_document.tables[index].prov[0]
elif arr == "figures":
prov = ds_document.figures[index].prov[0]
else:
prov = None
if prov and prov.page == page_no:
clusters_to_draw.append(
Cluster(
id=ix,
label=elem.name,
bbox=BoundingBox.from_tuple(
coord=prov.bbox,
origin=CoordOrigin.BOTTOMLEFT,
).to_top_left_origin(document.pages[page_no].size.height),
)
)
draw = ImageDraw.Draw(image)
for c in clusters_to_draw:
x0, y0, x1, y1 = c.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="red")
draw.text((x0 + 2, y0 + 2), f"{c.id}:{c.label}", fill=(255, 0, 0, 255))
cell_color = (
random.randint(30, 140),
random.randint(30, 140),
random.randint(30, 140),
)
for tc in c.cells: # [:1]:
x0, y0, x1, y1 = tc.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
image.show()
# draw_clusters_and_cells(ds_doc, 0)
# draw_clusters_and_cells(exported_doc, 0)
return exported_doc

View File

@@ -0,0 +1,77 @@
import copy
import logging
import random
from typing import Iterable
import numpy
from PIL import ImageDraw
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
_log = logging.getLogger(__name__)
class EasyOcrModel:
def __init__(self, config):
self.config = config
self.enabled = config["enabled"]
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
if self.enabled:
import easyocr
self.reader = easyocr.Reader(config["lang"])
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
if not self.enabled:
yield from page_batch
return
for page in page_batch:
# rects = page._fpage.
high_res_image = page._backend.get_page_image(scale=self.scale)
im = numpy.array(high_res_image)
result = self.reader.readtext(im)
del high_res_image
del im
cells = [
OcrCell(
id=ix,
text=line[1],
confidence=line[2],
bbox=BoundingBox.from_tuple(
coord=(
line[0][0][0] / self.scale,
line[0][0][1] / self.scale,
line[0][2][0] / self.scale,
line[0][2][1] / self.scale,
),
origin=CoordOrigin.TOPLEFT,
),
)
for ix, line in enumerate(result)
]
page.cells = cells # For now, just overwrites all digital cells.
# DEBUG code:
def draw_clusters_and_cells():
image = copy.deepcopy(page.image)
draw = ImageDraw.Draw(image)
cell_color = (
random.randint(30, 140),
random.randint(30, 140),
random.randint(30, 140),
)
for tc in cells:
x0, y0, x1, y1 = tc.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
image.show()
# draw_clusters_and_cells()
yield page

View File

@@ -0,0 +1,318 @@
import copy
import logging
import random
import time
from typing import Iterable, List
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
from PIL import ImageDraw
from docling.datamodel.base_models import (
BoundingBox,
Cell,
Cluster,
CoordOrigin,
LayoutPrediction,
Page,
)
from docling.utils import layout_utils as lu
_log = logging.getLogger(__name__)
class LayoutModel:
TEXT_ELEM_LABELS = [
"Text",
"Footnote",
"Caption",
"Checkbox-Unselected",
"Checkbox-Selected",
"Section-header",
"Page-header",
"Page-footer",
"Code",
"List-item",
# "Formula",
]
PAGE_HEADER_LABELS = ["Page-header", "Page-footer"]
TABLE_LABEL = "Table"
FIGURE_LABEL = "Picture"
FORMULA_LABEL = "Formula"
def __init__(self, config):
self.config = config
self.layout_predictor = LayoutPredictor(
config["artifacts_path"]
) # TODO temporary
def postprocess(self, clusters: List[Cluster], cells: List[Cell], page_height):
MIN_INTERSECTION = 0.2
CLASS_THRESHOLDS = {
"Caption": 0.35,
"Footnote": 0.35,
"Formula": 0.35,
"List-item": 0.35,
"Page-footer": 0.35,
"Page-header": 0.35,
"Picture": 0.2, # low threshold adjust to capture chemical structures for examples.
"Section-header": 0.45,
"Table": 0.35,
"Text": 0.45,
"Title": 0.45,
"Document Index": 0.45,
"Code": 0.45,
"Checkbox-Selected": 0.45,
"Checkbox-Unselected": 0.45,
"Form": 0.45,
"Key-Value Region": 0.45,
}
_log.debug("================= Start postprocess function ====================")
start_time = time.time()
# Apply Confidence Threshold to cluster predictions
# confidence = self.conf_threshold
clusters_out = []
for cluster in clusters:
confidence = CLASS_THRESHOLDS[cluster.label]
if cluster.confidence >= confidence:
# annotation["created_by"] = "high_conf_pred"
clusters_out.append(cluster)
# map to dictionary clusters and cells, with bottom left origin
clusters = [
{
"id": c.id,
"bbox": list(
c.bbox.to_bottom_left_origin(page_height).as_tuple()
), # TODO
"confidence": c.confidence,
"cell_ids": [],
"type": c.label,
}
for c in clusters
]
clusters_out = [
{
"id": c.id,
"bbox": list(
c.bbox.to_bottom_left_origin(page_height).as_tuple()
), # TODO
"confidence": c.confidence,
"created_by": "high_conf_pred",
"cell_ids": [],
"type": c.label,
}
for c in clusters_out
]
raw_cells = [
{
"id": c.id,
"bbox": list(
c.bbox.to_bottom_left_origin(page_height).as_tuple()
), # TODO
"text": c.text,
}
for c in cells
]
cell_count = len(raw_cells)
_log.debug("---- 0. Treat cluster overlaps ------")
clusters_out = lu.remove_cluster_duplicates_by_conf(clusters_out, 0.8)
_log.debug(
"---- 1. Initially assign cells to clusters based on minimum intersection ------"
)
## Check for cells included in or touched by clusters:
clusters_out = lu.assigning_cell_ids_to_clusters(
clusters_out, raw_cells, MIN_INTERSECTION
)
_log.debug("---- 2. Assign Orphans with Low Confidence Detections")
# Creates a map of cell_id->cluster_id
(
clusters_around_cells,
orphan_cell_indices,
ambiguous_cell_indices,
) = lu.cell_id_state_map(clusters_out, cell_count)
# Assign orphan cells with lower confidence predictions
clusters_out, orphan_cell_indices = lu.assign_orphans_with_low_conf_pred(
clusters_out, clusters, raw_cells, orphan_cell_indices
)
# Refresh the cell_ids assignment, after creating new clusters using low conf predictions
clusters_out = lu.assigning_cell_ids_to_clusters(
clusters_out, raw_cells, MIN_INTERSECTION
)
_log.debug("---- 3. Settle Ambigous Cells")
# Creates an update map after assignment of cell_id->cluster_id
(
clusters_around_cells,
orphan_cell_indices,
ambiguous_cell_indices,
) = lu.cell_id_state_map(clusters_out, cell_count)
# Settle pdf cells that belong to multiple clusters
clusters_out, ambiguous_cell_indices = lu.remove_ambigous_pdf_cell_by_conf(
clusters_out, raw_cells, ambiguous_cell_indices
)
_log.debug("---- 4. Set Orphans as Text")
(
clusters_around_cells,
orphan_cell_indices,
ambiguous_cell_indices,
) = lu.cell_id_state_map(clusters_out, cell_count)
clusters_out, orphan_cell_indices = lu.set_orphan_as_text(
clusters_out, clusters, raw_cells, orphan_cell_indices
)
_log.debug("---- 5. Merge Cells & and adapt the bounding boxes")
# Merge cells orphan cells
clusters_out = lu.merge_cells(clusters_out)
# Clean up clusters that remain from merged and unreasonable clusters
clusters_out = lu.clean_up_clusters(
clusters_out,
raw_cells,
merge_cells=True,
img_table=True,
one_cell_table=True,
)
new_clusters = lu.adapt_bboxes(raw_cells, clusters_out, orphan_cell_indices)
clusters_out = new_clusters
## We first rebuild where every cell is now:
## Now we write into a prediction cells list, not into the raw cells list.
## As we don't need previous labels, we best overwrite any old list, because that might
## have been sorted differently.
(
clusters_around_cells,
orphan_cell_indices,
ambiguous_cell_indices,
) = lu.cell_id_state_map(clusters_out, cell_count)
target_cells = []
for ix, cell in enumerate(raw_cells):
new_cell = {
"id": ix,
"rawcell_id": ix,
"label": "None",
"bbox": cell["bbox"],
"text": cell["text"],
}
for cluster_index in clusters_around_cells[
ix
]: # By previous analysis, this is always 1 cluster.
new_cell["label"] = clusters_out[cluster_index]["type"]
target_cells.append(new_cell)
# _log.debug("New label of cell " + str(ix) + " is " + str(new_cell["label"]))
cells_out = target_cells
## -------------------------------
## Sort clusters into reasonable reading order, and sort the cells inside each cluster
_log.debug("---- 5. Sort clusters in reading order ------")
sorted_clusters = lu.produce_reading_order(
clusters_out, "raw_cell_ids", "raw_cell_ids", True
)
clusters_out = sorted_clusters
# end_time = timer()
_log.debug("---- End of postprocessing function ------")
end_time = time.time() - start_time
_log.debug(f"Finished post processing in seconds={end_time:.3f}")
cells_out = [
Cell(
id=c["id"],
bbox=BoundingBox.from_tuple(
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
).to_top_left_origin(page_height),
text=c["text"],
)
for c in cells_out
]
clusters_out_new = []
for c in clusters_out:
cluster_cells = [ccell for ccell in cells_out if ccell.id in c["cell_ids"]]
c_new = Cluster(
id=c["id"],
bbox=BoundingBox.from_tuple(
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
).to_top_left_origin(page_height),
confidence=c["confidence"],
label=c["type"],
cells=cluster_cells,
)
clusters_out_new.append(c_new)
return clusters_out_new, cells_out
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
for page in page_batch:
clusters = []
for ix, pred_item in enumerate(self.layout_predictor.predict(page.image)):
cluster = Cluster(
id=ix,
label=pred_item["label"],
confidence=pred_item["confidence"],
bbox=BoundingBox.model_validate(pred_item),
cells=[],
)
clusters.append(cluster)
# Map cells to clusters
# TODO: Remove, postprocess should take care of it anyway.
for cell in page.cells:
for cluster in clusters:
if not cell.bbox.area() > 0:
overlap_frac = 0.0
else:
overlap_frac = (
cell.bbox.intersection_area_with(cluster.bbox)
/ cell.bbox.area()
)
if overlap_frac > 0.5:
cluster.cells.append(cell)
# Pre-sort clusters
# clusters = self.sort_clusters_by_cell_order(clusters)
# DEBUG code:
def draw_clusters_and_cells():
image = copy.deepcopy(page.image)
draw = ImageDraw.Draw(image)
for c in clusters:
x0, y0, x1, y1 = c.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="green")
cell_color = (
random.randint(30, 140),
random.randint(30, 140),
random.randint(30, 140),
)
for tc in c.cells: # [:1]:
x0, y0, x1, y1 = tc.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
image.show()
# draw_clusters_and_cells()
clusters, page.cells = self.postprocess(
clusters, page.cells, page.size.height
)
# draw_clusters_and_cells()
page.predictions.layout = LayoutPrediction(clusters=clusters)
yield page

View File

@@ -0,0 +1,160 @@
import logging
import re
from typing import Iterable, List
from docling.datamodel.base_models import (
AssembledUnit,
FigureElement,
Page,
PageElement,
TableElement,
TextElement,
)
from docling.models.layout_model import LayoutModel
_log = logging.getLogger(__name__)
class PageAssembleModel:
def __init__(self, config):
self.config = config
# self.line_wrap_pattern = re.compile(r'(?<=[^\W_])- \n(?=\w)')
# def sanitize_text_poor(self, lines):
# text = '\n'.join(lines)
#
# # treat line wraps.
# sanitized_text = self.line_wrap_pattern.sub('', text)
#
# sanitized_text = sanitized_text.replace('\n', ' ')
#
# return sanitized_text
def sanitize_text(self, lines):
if len(lines) <= 1:
return " ".join(lines)
for ix, line in enumerate(lines[1:]):
prev_line = lines[ix]
if prev_line.endswith("-"):
prev_words = re.findall(r"\b[\w]+\b", prev_line)
line_words = re.findall(r"\b[\w]+\b", line)
if (
len(prev_words)
and len(line_words)
and prev_words[-1].isalnum()
and line_words[0].isalnum()
):
lines[ix] = prev_line[:-1]
else:
lines[ix] += " "
sanitized_text = "".join(lines)
return sanitized_text.strip() # Strip any leading or trailing whitespace
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
for page in page_batch:
# assembles some JSON output page by page.
elements: List[PageElement] = []
headers: List[PageElement] = []
body: List[PageElement] = []
for cluster in page.predictions.layout.clusters:
# _log.info("Cluster label seen:", cluster.label)
if cluster.label in LayoutModel.TEXT_ELEM_LABELS:
textlines = [
cell.text.replace("\x02", "-").strip()
for cell in cluster.cells
if len(cell.text.strip()) > 0
]
text = self.sanitize_text(textlines)
text_el = TextElement(
label=cluster.label,
id=cluster.id,
text=text,
page_no=page.page_no,
cluster=cluster,
)
elements.append(text_el)
if cluster.label in LayoutModel.PAGE_HEADER_LABELS:
headers.append(text_el)
else:
body.append(text_el)
elif cluster.label == LayoutModel.TABLE_LABEL:
tbl = None
if page.predictions.tablestructure:
tbl = page.predictions.tablestructure.table_map.get(
cluster.id, None
)
if (
not tbl
): # fallback: add table without structure, if it isn't present
tbl = TableElement(
label=cluster.label,
id=cluster.id,
text="",
otsl_seq=[],
table_cells=[],
cluster=cluster,
page_no=page.page_no,
)
elements.append(tbl)
body.append(tbl)
elif cluster.label == LayoutModel.FIGURE_LABEL:
fig = None
if page.predictions.figures_classification:
fig = page.predictions.figures_classification.figure_map.get(
cluster.id, None
)
if (
not fig
): # fallback: add figure without classification, if it isn't present
fig = FigureElement(
label=cluster.label,
id=cluster.id,
text="",
data=None,
cluster=cluster,
page_no=page.page_no,
)
elements.append(fig)
body.append(fig)
elif cluster.label == LayoutModel.FORMULA_LABEL:
equation = None
if page.predictions.equations_prediction:
equation = (
page.predictions.equations_prediction.equation_map.get(
cluster.id, None
)
)
if not equation: # fallback: add empty formula, if it isn't present
text = self.sanitize_text(
[
cell.text.replace("\x02", "-").strip()
for cell in cluster.cells
if len(cell.text.strip()) > 0
]
)
equation = TextElement(
label=cluster.label,
id=cluster.id,
cluster=cluster,
page_no=page.page_no,
text=text,
)
elements.append(equation)
body.append(equation)
page.assembled = AssembledUnit(
elements=elements, headers=headers, body=body
)
yield page

View File

@@ -0,0 +1,114 @@
from typing import Iterable
import numpy
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
from docling.datamodel.base_models import (
BoundingBox,
Page,
TableCell,
TableElement,
TableStructurePrediction,
)
class TableStructureModel:
def __init__(self, config):
self.config = config
self.do_cell_matching = config["do_cell_matching"]
self.enabled = config["enabled"]
if self.enabled:
artifacts_path = config["artifacts_path"]
# Third Party
import docling_ibm_models.tableformer.common as c
self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
self.tm_config["model"]["save_dir"] = artifacts_path
self.tm_model_type = self.tm_config["model"]["type"]
self.tf_predictor = TFPredictor(self.tm_config)
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
if not self.enabled:
yield from page_batch
return
for page in page_batch:
page.predictions.tablestructure = TableStructurePrediction() # dummy
in_tables = [
(
cluster,
[
round(cluster.bbox.l),
round(cluster.bbox.t),
round(cluster.bbox.r),
round(cluster.bbox.b),
],
)
for cluster in page.predictions.layout.clusters
if cluster.label == "Table"
]
if not len(in_tables):
yield page
continue
tokens = []
for c in page.cells:
for cluster, _ in in_tables:
if c.bbox.area() > 0:
if (
c.bbox.intersection_area_with(cluster.bbox) / c.bbox.area()
> 0.2
):
# Only allow non empty stings (spaces) into the cells of a table
if len(c.text.strip()) > 0:
tokens.append(c.model_dump())
iocr_page = {
"image": numpy.asarray(page.image),
"tokens": tokens,
"width": page.size.width,
"height": page.size.height,
}
table_clusters, table_bboxes = zip(*in_tables)
if len(table_bboxes):
tf_output = self.tf_predictor.multi_table_predict(
iocr_page, table_bboxes, do_matching=self.do_cell_matching
)
for table_cluster, table_out in zip(table_clusters, tf_output):
table_cells = []
for element in table_out["tf_responses"]:
if not self.do_cell_matching:
the_bbox = BoundingBox.model_validate(element["bbox"])
text_piece = page._backend.get_text_in_rect(the_bbox)
element["bbox"]["token"] = text_piece
tc = TableCell.model_validate(element)
table_cells.append(tc)
# Retrieving cols/rows, after post processing:
num_rows = table_out["predict_details"]["num_rows"]
num_cols = table_out["predict_details"]["num_cols"]
otsl_seq = table_out["predict_details"]["prediction"]["rs_seq"]
tbl = TableElement(
otsl_seq=otsl_seq,
table_cells=table_cells,
num_rows=num_rows,
num_cols=num_cols,
id=table_cluster.id,
page_no=page.page_no,
cluster=table_cluster,
label="Table",
)
page.predictions.tablestructure.table_map[table_cluster.id] = tbl
yield page