Docling/docling/models/page_assemble_model.py
Christoph Auer 90875247e5
feat: Establish confidence estimation for document and pages (#1313)
* Establish confidence field, propagate layout confidence through

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add OCR confidence and parse confidence (stub)

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add parse quality rules, use 5% percentile for overall and parse scores

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Heuristic updates

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix garbage regex

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Move grade to page

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Introduce mean_score and low_score, consistent aggregate computations

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add confidence test

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
2025-05-21 12:32:49 +02:00

157 lines
6.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import re
from collections.abc import Iterable
from typing import List
import numpy as np
from pydantic import BaseModel
from docling.datamodel.base_models import (
AssembledUnit,
ContainerElement,
FigureElement,
Page,
PageElement,
Table,
TextElement,
)
from docling.datamodel.document import ConversionResult
from docling.models.base_model import BasePageModel
from docling.models.layout_model import LayoutModel
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
# Options accepted by PageAssembleModel.__init__; no fields are declared yet.
class PageAssembleOptions(BaseModel):
    pass
class PageAssembleModel(BasePageModel):
    """Convert the layout clusters of each page into an ``AssembledUnit``.

    For every page with a valid backend, the clusters found in
    ``page.predictions.layout`` are turned into text, table, figure or
    container elements and stored on ``page.assembled``.
    """

    def __init__(self, options: PageAssembleOptions):
        """Keep a reference to ``options``; this class does not read it further."""
        self.options = options

    def sanitize_text(self, lines: List[str]) -> str:
        """Join the text lines of one cluster into a single string.

        A line ending in ``-`` is merged with the following line without the
        hyphen when the last word before it and the first word after it are
        both alphanumeric. A line that does not end in ``-`` is followed by a
        single space. A line that ends in ``-`` but fails the word check is
        kept verbatim and joined without a separator.

        Args:
            lines: Text lines in reading order. The list is not modified.

        Returns:
            The joined text. For two or more lines the result is also
            punctuation-normalized and stripped of surrounding whitespace.
        """
        # Zero or one line: returned as-is, without the normalization and
        # strip applied on the multi-line path below.
        if len(lines) <= 1:
            return " ".join(lines)

        # Build a new list rather than editing ``lines`` in place, so the
        # caller's list is left untouched.
        pieces: List[str] = []
        for prev_line, next_line in zip(lines, lines[1:]):
            if not prev_line.endswith("-"):
                pieces.append(prev_line + " ")
                continue

            prev_words = re.findall(r"\b[\w]+\b", prev_line)
            next_words = re.findall(r"\b[\w]+\b", next_line)
            if (
                prev_words
                and next_words
                and prev_words[-1].isalnum()
                and next_words[0].isalnum()
            ):
                # Word split across the line break: drop the trailing hyphen.
                pieces.append(prev_line[:-1])
            else:
                pieces.append(prev_line)
        pieces.append(lines[-1])

        # Text normalization. The typographic characters are spelled as
        # escapes so they cannot be silently lost or confused with their
        # ASCII look-alikes; an empty search string passed to str.replace
        # would insert the replacement between every character.
        # NOTE(review): code points inferred from the ASCII replacements -
        # confirm against the canonical source.
        normalization = str.maketrans(
            {
                "\u2044": "/",  # FRACTION SLASH
                "\u2019": "'",  # RIGHT SINGLE QUOTATION MARK
                "\u2018": "'",  # LEFT SINGLE QUOTATION MARK
                "\u201c": '"',  # LEFT DOUBLE QUOTATION MARK
                "\u201d": '"',  # RIGHT DOUBLE QUOTATION MARK
                "\u2022": "\u00b7",  # BULLET -> MIDDLE DOT
            }
        )
        sanitized_text = "".join(pieces).translate(normalization)

        return sanitized_text.strip()  # Strip any leading or trailing whitespace

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Assemble each page of ``page_batch`` and yield it.

        Pages whose backend reports ``is_valid()`` as false are yielded
        unchanged. For all other pages ``page.assembled`` is set to an
        ``AssembledUnit`` holding every element, the page headers, and the
        body elements. Clusters whose label matches none of the handled
        label groups are skipped.

        Args:
            conv_res: Conversion result passed to ``TimeRecorder``.
            page_batch: Pages to process; each must have ``_backend`` set and,
                if valid, a layout prediction.

        Yields:
            The same page objects, in input order.
        """
        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
                continue

            with TimeRecorder(conv_res, "page_assemble"):
                assert page.predictions.layout is not None

                # assembles some JSON output page by page.
                elements: List[PageElement] = []
                headers: List[PageElement] = []
                body: List[PageElement] = []

                for cluster in page.predictions.layout.clusters:
                    if cluster.label in LayoutModel.TEXT_ELEM_LABELS:
                        # "\x02" is mapped to "-" before joining; presumably a
                        # hyphenation marker from the backend - TODO confirm.
                        textlines = [
                            cell.text.replace("\x02", "-").strip()
                            for cell in cluster.cells
                            if len(cell.text.strip()) > 0
                        ]
                        text = self.sanitize_text(textlines)
                        text_el = TextElement(
                            label=cluster.label,
                            id=cluster.id,
                            text=text,
                            page_no=page.page_no,
                            cluster=cluster,
                        )
                        elements.append(text_el)

                        if cluster.label in LayoutModel.PAGE_HEADER_LABELS:
                            headers.append(text_el)
                        else:
                            body.append(text_el)

                    elif cluster.label in LayoutModel.TABLE_LABELS:
                        tbl = None
                        if page.predictions.tablestructure:
                            tbl = page.predictions.tablestructure.table_map.get(
                                cluster.id, None
                            )
                        if not tbl:
                            # fallback: add table without structure, if it
                            # isn't present
                            tbl = Table(
                                label=cluster.label,
                                id=cluster.id,
                                text="",
                                otsl_seq=[],
                                table_cells=[],
                                cluster=cluster,
                                page_no=page.page_no,
                            )
                        elements.append(tbl)
                        body.append(tbl)

                    elif cluster.label == LayoutModel.FIGURE_LABEL:
                        fig = None
                        if page.predictions.figures_classification:
                            fig = (
                                page.predictions.figures_classification.figure_map.get(
                                    cluster.id, None
                                )
                            )
                        if not fig:
                            # fallback: add figure without classification, if
                            # it isn't present
                            fig = FigureElement(
                                label=cluster.label,
                                id=cluster.id,
                                text="",
                                data=None,
                                cluster=cluster,
                                page_no=page.page_no,
                            )
                        elements.append(fig)
                        body.append(fig)

                    elif cluster.label in LayoutModel.CONTAINER_LABELS:
                        container_el = ContainerElement(
                            label=cluster.label,
                            id=cluster.id,
                            page_no=page.page_no,
                            cluster=cluster,
                        )
                        elements.append(container_el)
                        body.append(container_el)

                page.assembled = AssembledUnit(
                    elements=elements, headers=headers, body=body
                )

            # Yield after the timing block has closed so the recorded time
            # does not include whatever the consumer does with the page.
            yield page