import logging
import re
from collections.abc import Iterable
from typing import List

from pydantic import BaseModel

from docling.datamodel.base_models import (
    AssembledUnit,
    ContainerElement,
    FigureElement,
    Page,
    PageElement,
    Table,
    TextElement,
)
from docling.datamodel.document import ConversionResult
from docling.models.base_model import BasePageModel
from docling.models.layout_model import LayoutModel
from docling.utils.profiling import TimeRecorder

_log = logging.getLogger(__name__)
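
# Options for the page-assembly stage; currently there are no configurable fields.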
class PageAssembleOptions(BaseModel):
    pass
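
# Pipeline stage that turns the layout clusters predicted for a page into typed
# page elements (text, tables, figures, containers) and collects them into an
# AssembledUnit on the page.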
class PageAssembleModel(BasePageModel):
    def __init__(self, options: PageAssembleOptions):
        self.options = options
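
    # Join the text lines of a cluster into one string: merge hyphenated line
    # breaks and normalize a few typographic characters (quotes, bullet, fraction slash).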
    def sanitize_text(self, lines):
        if len(lines) <= 1:
            return " ".join(lines)

        for ix, line in enumerate(lines[1:]):
            prev_line = lines[ix]

            if prev_line.endswith("-"):
                prev_words = re.findall(r"\b[\w]+\b", prev_line)
                line_words = re.findall(r"\b[\w]+\b", line)

                if (
                    len(prev_words)
                    and len(line_words)
                    and prev_words[-1].isalnum()
                    and line_words[0].isalnum()
                ):
                    lines[ix] = prev_line[:-1]
            else:
                lines[ix] += " "

        sanitized_text = "".join(lines)

        # Text normalization
        sanitized_text = sanitized_text.replace("⁄", "/")  # noqa: RUF001
        sanitized_text = sanitized_text.replace("’", "'")  # noqa: RUF001
        sanitized_text = sanitized_text.replace("‘", "'")  # noqa: RUF001
        sanitized_text = sanitized_text.replace("“", '"')
        sanitized_text = sanitized_text.replace("”", '"')
        sanitized_text = sanitized_text.replace("•", "·")

        return sanitized_text.strip()  # Strip any leading or trailing whitespace
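
    # Iterate over the page batch: pages whose backend is not valid are passed
    # through unchanged; all others are assembled from their layout predictions.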
    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
            else:
                with TimeRecorder(conv_res, "page_assemble"):
                    assert page.predictions.layout is not None

                    # Assemble the page elements from the predicted layout clusters.

                    elements: List[PageElement] = []
                    headers: List[PageElement] = []
                    body: List[PageElement] = []
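
                    # Build one element per cluster, dispatching on the layout label.
                    # Every element is appended to `elements`; header/footer text is
                    # also tracked in `headers`, all other elements in `body`.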
                    for cluster in page.predictions.layout.clusters:
                        # _log.info("Cluster label seen:", cluster.label)
                        if cluster.label in LayoutModel.TEXT_ELEM_LABELS:
                            textlines = [
                                cell.text.replace("\x02", "-").strip()
                                for cell in cluster.cells
                                if len(cell.text.strip()) > 0
                            ]
                            text = self.sanitize_text(textlines)
                            text_el = TextElement(
                                label=cluster.label,
                                id=cluster.id,
                                text=text,
                                page_no=page.page_no,
                                cluster=cluster,
                            )
                            elements.append(text_el)

                            if cluster.label in LayoutModel.PAGE_HEADER_LABELS:
                                headers.append(text_el)
                            else:
                                body.append(text_el)
                        elif cluster.label in LayoutModel.TABLE_LABELS:
                            tbl = None
                            if page.predictions.tablestructure:
                                tbl = page.predictions.tablestructure.table_map.get(
                                    cluster.id, None
                                )
                            if not tbl:  # fallback: add table without structure, if it isn't present
                                tbl = Table(
                                    label=cluster.label,
                                    id=cluster.id,
                                    text="",
                                    otsl_seq=[],
                                    table_cells=[],
                                    cluster=cluster,
                                    page_no=page.page_no,
                                )

                            elements.append(tbl)
                            body.append(tbl)
                        elif cluster.label == LayoutModel.FIGURE_LABEL:
                            fig = None
                            if page.predictions.figures_classification:
                                fig = page.predictions.figures_classification.figure_map.get(
                                    cluster.id, None
                                )
                            if not fig:  # fallback: add figure without classification, if it isn't present
                                fig = FigureElement(
                                    label=cluster.label,
                                    id=cluster.id,
                                    text="",
                                    data=None,
                                    cluster=cluster,
                                    page_no=page.page_no,
                                )
                            elements.append(fig)
                            body.append(fig)
                        elif cluster.label in LayoutModel.CONTAINER_LABELS:
                            container_el = ContainerElement(
                                label=cluster.label,
                                id=cluster.id,
                                page_no=page.page_no,
                                cluster=cluster,
                            )
                            elements.append(container_el)
                            body.append(container_el)
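
                    # Attach the assembled result to the page, keeping the full
                    # element list alongside the header/body split.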
                    page.assembled = AssembledUnit(
                        elements=elements, headers=headers, body=body
                    )

                yield page