feat!: Docling v2 (#117)
--------- Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Signed-off-by: Maxim Lysak <mly@zurich.ibm.com> Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> Co-authored-by: Maxim Lysak <mly@zurich.ibm.com> Co-authored-by: Michele Dolfi <dol@zurich.ibm.com> Co-authored-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
This commit is contained in:
@@ -2,8 +2,10 @@ import copy
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List
|
||||
|
||||
from docling_core.types.doc import CoordOrigin, DocItemLabel
|
||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||
from PIL import ImageDraw
|
||||
|
||||
@@ -11,74 +13,73 @@ from docling.datamodel.base_models import (
|
||||
BoundingBox,
|
||||
Cell,
|
||||
Cluster,
|
||||
CoordOrigin,
|
||||
LayoutPrediction,
|
||||
Page,
|
||||
)
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.utils import layout_utils as lu
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LayoutModel:
|
||||
class LayoutModel(BasePageModel):
|
||||
|
||||
TEXT_ELEM_LABELS = [
|
||||
"Text",
|
||||
"Footnote",
|
||||
"Caption",
|
||||
"Checkbox-Unselected",
|
||||
"Checkbox-Selected",
|
||||
"Section-header",
|
||||
"Page-header",
|
||||
"Page-footer",
|
||||
"Code",
|
||||
"List-item",
|
||||
# "Title"
|
||||
DocItemLabel.TEXT,
|
||||
DocItemLabel.FOOTNOTE,
|
||||
DocItemLabel.CAPTION,
|
||||
DocItemLabel.CHECKBOX_UNSELECTED,
|
||||
DocItemLabel.CHECKBOX_SELECTED,
|
||||
DocItemLabel.SECTION_HEADER,
|
||||
DocItemLabel.PAGE_HEADER,
|
||||
DocItemLabel.PAGE_FOOTER,
|
||||
DocItemLabel.CODE,
|
||||
DocItemLabel.LIST_ITEM,
|
||||
# "Formula",
|
||||
]
|
||||
PAGE_HEADER_LABELS = ["Page-header", "Page-footer"]
|
||||
PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
|
||||
|
||||
TABLE_LABEL = "Table"
|
||||
FIGURE_LABEL = "Picture"
|
||||
FORMULA_LABEL = "Formula"
|
||||
TABLE_LABEL = DocItemLabel.TABLE
|
||||
FIGURE_LABEL = DocItemLabel.PICTURE
|
||||
FORMULA_LABEL = DocItemLabel.FORMULA
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.layout_predictor = LayoutPredictor(
|
||||
config["artifacts_path"]
|
||||
) # TODO temporary
|
||||
def __init__(self, artifacts_path: Path):
|
||||
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
|
||||
|
||||
def postprocess(self, clusters: List[Cluster], cells: List[Cell], page_height):
|
||||
def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height):
|
||||
MIN_INTERSECTION = 0.2
|
||||
CLASS_THRESHOLDS = {
|
||||
"Caption": 0.35,
|
||||
"Footnote": 0.35,
|
||||
"Formula": 0.35,
|
||||
"List-item": 0.35,
|
||||
"Page-footer": 0.35,
|
||||
"Page-header": 0.35,
|
||||
"Picture": 0.2, # low threshold adjust to capture chemical structures for examples.
|
||||
"Section-header": 0.45,
|
||||
"Table": 0.35,
|
||||
"Text": 0.45,
|
||||
"Title": 0.45,
|
||||
"Document Index": 0.45,
|
||||
"Code": 0.45,
|
||||
"Checkbox-Selected": 0.45,
|
||||
"Checkbox-Unselected": 0.45,
|
||||
"Form": 0.45,
|
||||
"Key-Value Region": 0.45,
|
||||
DocItemLabel.CAPTION: 0.35,
|
||||
DocItemLabel.FOOTNOTE: 0.35,
|
||||
DocItemLabel.FORMULA: 0.35,
|
||||
DocItemLabel.LIST_ITEM: 0.35,
|
||||
DocItemLabel.PAGE_FOOTER: 0.35,
|
||||
DocItemLabel.PAGE_HEADER: 0.35,
|
||||
DocItemLabel.PICTURE: 0.2, # low threshold adjust to capture chemical structures for examples.
|
||||
DocItemLabel.SECTION_HEADER: 0.45,
|
||||
DocItemLabel.TABLE: 0.35,
|
||||
DocItemLabel.TEXT: 0.45,
|
||||
DocItemLabel.TITLE: 0.45,
|
||||
DocItemLabel.DOCUMENT_INDEX: 0.45,
|
||||
DocItemLabel.CODE: 0.45,
|
||||
DocItemLabel.CHECKBOX_SELECTED: 0.45,
|
||||
DocItemLabel.CHECKBOX_UNSELECTED: 0.45,
|
||||
DocItemLabel.FORM: 0.45,
|
||||
DocItemLabel.KEY_VALUE_REGION: 0.45,
|
||||
}
|
||||
|
||||
CLASS_REMAPPINGS = {"Document Index": "Table", "Title": "Section-header"}
|
||||
CLASS_REMAPPINGS = {
|
||||
DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
|
||||
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
|
||||
}
|
||||
|
||||
_log.debug("================= Start postprocess function ====================")
|
||||
start_time = time.time()
|
||||
# Apply Confidence Threshold to cluster predictions
|
||||
# confidence = self.conf_threshold
|
||||
clusters_out = []
|
||||
clusters_mod = []
|
||||
|
||||
for cluster in clusters:
|
||||
for cluster in clusters_in:
|
||||
confidence = CLASS_THRESHOLDS[cluster.label]
|
||||
if cluster.confidence >= confidence:
|
||||
# annotation["created_by"] = "high_conf_pred"
|
||||
@@ -86,10 +87,10 @@ class LayoutModel:
|
||||
# Remap class labels where needed.
|
||||
if cluster.label in CLASS_REMAPPINGS.keys():
|
||||
cluster.label = CLASS_REMAPPINGS[cluster.label]
|
||||
clusters_out.append(cluster)
|
||||
clusters_mod.append(cluster)
|
||||
|
||||
# map to dictionary clusters and cells, with bottom left origin
|
||||
clusters = [
|
||||
clusters_orig = [
|
||||
{
|
||||
"id": c.id,
|
||||
"bbox": list(
|
||||
@@ -99,7 +100,7 @@ class LayoutModel:
|
||||
"cell_ids": [],
|
||||
"type": c.label,
|
||||
}
|
||||
for c in clusters
|
||||
for c in clusters_in
|
||||
]
|
||||
|
||||
clusters_out = [
|
||||
@@ -113,9 +114,11 @@ class LayoutModel:
|
||||
"cell_ids": [],
|
||||
"type": c.label,
|
||||
}
|
||||
for c in clusters_out
|
||||
for c in clusters_mod
|
||||
]
|
||||
|
||||
del clusters_mod
|
||||
|
||||
raw_cells = [
|
||||
{
|
||||
"id": c.id,
|
||||
@@ -149,7 +152,7 @@ class LayoutModel:
|
||||
|
||||
# Assign orphan cells with lower confidence predictions
|
||||
clusters_out, orphan_cell_indices = lu.assign_orphans_with_low_conf_pred(
|
||||
clusters_out, clusters, raw_cells, orphan_cell_indices
|
||||
clusters_out, clusters_orig, raw_cells, orphan_cell_indices
|
||||
)
|
||||
|
||||
# Refresh the cell_ids assignment, after creating new clusters using low conf predictions
|
||||
@@ -178,7 +181,7 @@ class LayoutModel:
|
||||
) = lu.cell_id_state_map(clusters_out, cell_count)
|
||||
|
||||
clusters_out, orphan_cell_indices = lu.set_orphan_as_text(
|
||||
clusters_out, clusters, raw_cells, orphan_cell_indices
|
||||
clusters_out, clusters_orig, raw_cells, orphan_cell_indices
|
||||
)
|
||||
|
||||
_log.debug("---- 5. Merge Cells & and adapt the bounding boxes")
|
||||
@@ -237,46 +240,55 @@ class LayoutModel:
|
||||
end_time = time.time() - start_time
|
||||
_log.debug(f"Finished post processing in seconds={end_time:.3f}")
|
||||
|
||||
cells_out = [
|
||||
cells_out_new = [
|
||||
Cell(
|
||||
id=c["id"],
|
||||
id=c["id"], # type: ignore
|
||||
bbox=BoundingBox.from_tuple(
|
||||
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
|
||||
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
|
||||
).to_top_left_origin(page_height),
|
||||
text=c["text"],
|
||||
text=c["text"], # type: ignore
|
||||
)
|
||||
for c in cells_out
|
||||
]
|
||||
|
||||
del cells_out
|
||||
|
||||
clusters_out_new = []
|
||||
for c in clusters_out:
|
||||
cluster_cells = [ccell for ccell in cells_out if ccell.id in c["cell_ids"]]
|
||||
cluster_cells = [
|
||||
ccell for ccell in cells_out_new if ccell.id in c["cell_ids"] # type: ignore
|
||||
]
|
||||
c_new = Cluster(
|
||||
id=c["id"],
|
||||
id=c["id"], # type: ignore
|
||||
bbox=BoundingBox.from_tuple(
|
||||
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
|
||||
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
|
||||
).to_top_left_origin(page_height),
|
||||
confidence=c["confidence"],
|
||||
label=c["type"],
|
||||
confidence=c["confidence"], # type: ignore
|
||||
label=DocItemLabel(c["type"]),
|
||||
cells=cluster_cells,
|
||||
)
|
||||
clusters_out_new.append(c_new)
|
||||
|
||||
return clusters_out_new, cells_out
|
||||
return clusters_out_new, cells_out_new
|
||||
|
||||
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||
for page in page_batch:
|
||||
assert page.size is not None
|
||||
|
||||
clusters = []
|
||||
for ix, pred_item in enumerate(
|
||||
self.layout_predictor.predict(page.get_image(scale=1.0))
|
||||
):
|
||||
label = DocItemLabel(
|
||||
pred_item["label"].lower().replace(" ", "_").replace("-", "_")
|
||||
) # Temporary, until docling-ibm-model uses docling-core types
|
||||
cluster = Cluster(
|
||||
id=ix,
|
||||
label=pred_item["label"],
|
||||
label=label,
|
||||
confidence=pred_item["confidence"],
|
||||
bbox=BoundingBox.model_validate(pred_item),
|
||||
cells=[],
|
||||
)
|
||||
|
||||
clusters.append(cluster)
|
||||
|
||||
# Map cells to clusters
|
||||
|
||||
Reference in New Issue
Block a user