test: improve typing definitions (part 1) (#72)

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2024-09-12 15:56:29 +02:00 committed by GitHub
parent 53569a1023
commit 8aa476ccd3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 91 additions and 29 deletions

View File

@ -1,10 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, Optional, Union from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
from PIL import Image from PIL import Image
if TYPE_CHECKING:
from docling.datamodel.base_models import BoundingBox, Cell, PageSize
class PdfPageBackend(ABC): class PdfPageBackend(ABC):
@ -17,12 +20,12 @@ class PdfPageBackend(ABC):
pass pass
@abstractmethod @abstractmethod
def get_bitmap_rects(self, scale: int = 1) -> Iterable["BoundingBox"]: def get_bitmap_rects(self, scale: float = 1) -> Iterable["BoundingBox"]:
pass pass
@abstractmethod @abstractmethod
def get_page_image( def get_page_image(
self, scale: int = 1, cropbox: Optional["BoundingBox"] = None self, scale: float = 1, cropbox: Optional["BoundingBox"] = None
) -> Image.Image: ) -> Image.Image:
pass pass

View File

@ -2,7 +2,7 @@ import logging
import random import random
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Iterable, Optional, Union from typing import Iterable, List, Optional, Union
import pypdfium2 as pdfium import pypdfium2 as pdfium
from docling_parse.docling_parse import pdf_parser from docling_parse.docling_parse import pdf_parser
@ -22,7 +22,6 @@ class DoclingParsePageBackend(PdfPageBackend):
self._ppage = page_obj self._ppage = page_obj
parsed_page = parser.parse_pdf_from_key_on_page(document_hash, page_no) parsed_page = parser.parse_pdf_from_key_on_page(document_hash, page_no)
self._dpage = None
self.valid = "pages" in parsed_page self.valid = "pages" in parsed_page
if self.valid: if self.valid:
self._dpage = parsed_page["pages"][0] self._dpage = parsed_page["pages"][0]
@ -68,7 +67,7 @@ class DoclingParsePageBackend(PdfPageBackend):
return text_piece return text_piece
def get_text_cells(self) -> Iterable[Cell]: def get_text_cells(self) -> Iterable[Cell]:
cells = [] cells: List[Cell] = []
cell_counter = 0 cell_counter = 0
if not self.valid: if not self.valid:
@ -130,7 +129,7 @@ class DoclingParsePageBackend(PdfPageBackend):
return cells return cells
def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]: def get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]:
AREA_THRESHOLD = 32 * 32 AREA_THRESHOLD = 32 * 32
for i in range(len(self._dpage["images"])): for i in range(len(self._dpage["images"])):
@ -145,7 +144,7 @@ class DoclingParsePageBackend(PdfPageBackend):
yield cropbox yield cropbox
def get_page_image( def get_page_image(
self, scale: int = 1, cropbox: Optional[BoundingBox] = None self, scale: float = 1, cropbox: Optional[BoundingBox] = None
) -> Image.Image: ) -> Image.Image:
page_size = self.get_size() page_size = self.get_size()

View File

@ -7,7 +7,7 @@ from typing import Iterable, List, Optional, Union
import pypdfium2 as pdfium import pypdfium2 as pdfium
import pypdfium2.raw as pdfium_c import pypdfium2.raw as pdfium_c
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from pypdfium2 import PdfPage from pypdfium2 import PdfPage, PdfTextPage
from pypdfium2._helpers.misc import PdfiumError from pypdfium2._helpers.misc import PdfiumError
from docling.backend.abstract_backend import PdfDocumentBackend, PdfPageBackend from docling.backend.abstract_backend import PdfDocumentBackend, PdfPageBackend
@ -29,12 +29,12 @@ class PyPdfiumPageBackend(PdfPageBackend):
exc_info=True, exc_info=True,
) )
self.valid = False self.valid = False
self.text_page = None self.text_page: Optional[PdfTextPage] = None
def is_valid(self) -> bool: def is_valid(self) -> bool:
return self.valid return self.valid
def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]: def get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]:
AREA_THRESHOLD = 32 * 32 AREA_THRESHOLD = 32 * 32
for obj in self._ppage.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_IMAGE]): for obj in self._ppage.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_IMAGE]):
pos = obj.get_pos() pos = obj.get_pos()
@ -189,7 +189,7 @@ class PyPdfiumPageBackend(PdfPageBackend):
return cells return cells
def get_page_image( def get_page_image(
self, scale: int = 1, cropbox: Optional[BoundingBox] = None self, scale: float = 1, cropbox: Optional[BoundingBox] = None
) -> Image.Image: ) -> Image.Image:
page_size = self.get_size() page_size = self.get_size()

View File

@ -87,7 +87,7 @@ class BoundingBox(BaseModel):
return (self.l, self.b, self.r, self.t) return (self.l, self.b, self.r, self.t)
@classmethod @classmethod
def from_tuple(cls, coord: Tuple[float], origin: CoordOrigin): def from_tuple(cls, coord: Tuple[float, ...], origin: CoordOrigin):
if origin == CoordOrigin.TOPLEFT: if origin == CoordOrigin.TOPLEFT:
l, t, r, b = coord[0], coord[1], coord[2], coord[3] l, t, r, b = coord[0], coord[1], coord[2], coord[3]
if r < l: if r < l:
@ -246,7 +246,7 @@ class EquationPrediction(BaseModel):
class PagePredictions(BaseModel): class PagePredictions(BaseModel):
layout: LayoutPrediction = None layout: Optional[LayoutPrediction] = None
tablestructure: Optional[TableStructurePrediction] = None tablestructure: Optional[TableStructurePrediction] = None
figures_classification: Optional[FigureClassificationPrediction] = None figures_classification: Optional[FigureClassificationPrediction] = None
equations_prediction: Optional[EquationPrediction] = None equations_prediction: Optional[EquationPrediction] = None
@ -267,7 +267,7 @@ class Page(BaseModel):
page_no: int page_no: int
page_hash: Optional[str] = None page_hash: Optional[str] = None
size: Optional[PageSize] = None size: Optional[PageSize] = None
cells: List[Cell] = None cells: List[Cell] = []
predictions: PagePredictions = PagePredictions() predictions: PagePredictions = PagePredictions()
assembled: Optional[AssembledUnit] = None assembled: Optional[AssembledUnit] = None

View File

@ -1,12 +1,12 @@
from pathlib import Path from pathlib import Path
from typing import Iterable from typing import Callable, Iterable, List
from docling.datamodel.base_models import Page, PipelineOptions from docling.datamodel.base_models import Page, PipelineOptions
class BaseModelPipeline: class BaseModelPipeline:
def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions): def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions):
self.model_pipe = [] self.model_pipe: List[Callable] = []
self.artifacts_path = artifacts_path self.artifacts_path = artifacts_path
self.pipeline_options = pipeline_options self.pipeline_options = pipeline_options

View File

@ -1,10 +1,10 @@
import logging import logging
from typing import Any, Dict, Iterable, List, Tuple from typing import Any, Dict, Iterable, List, Tuple, Union
from docling_core.types.doc.base import BaseCell, Ref, Table, TableCell from docling_core.types.doc.base import BaseCell, BaseText, Ref, Table, TableCell
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell
from docling.datamodel.document import ConvertedDocument, Page from docling.datamodel.document import ConversionResult, Page
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -15,6 +15,9 @@ def _export_table_to_html(table: Table):
# to the docling-core package. # to the docling-core package.
def _get_tablecell_span(cell: TableCell, ix): def _get_tablecell_span(cell: TableCell, ix):
if cell.spans is None:
span = set()
else:
span = set([s[ix] for s in cell.spans]) span = set([s[ix] for s in cell.spans])
if len(span) == 0: if len(span) == 0:
return 1, None, None return 1, None, None
@ -24,6 +27,8 @@ def _export_table_to_html(table: Table):
nrows = table.num_rows nrows = table.num_rows
ncols = table.num_cols ncols = table.num_cols
if table.data is None:
return ""
for i in range(nrows): for i in range(nrows):
body += "<tr>" body += "<tr>"
for j in range(ncols): for j in range(ncols):
@ -66,7 +71,7 @@ def _export_table_to_html(table: Table):
def generate_multimodal_pages( def generate_multimodal_pages(
doc_result: ConvertedDocument, doc_result: ConversionResult,
) -> Iterable[Tuple[str, str, List[Dict[str, Any]], List[Dict[str, Any]], Page]]: ) -> Iterable[Tuple[str, str, List[Dict[str, Any]], List[Dict[str, Any]], Page]]:
label_to_doclaynet = { label_to_doclaynet = {
@ -94,7 +99,7 @@ def generate_multimodal_pages(
page_no = 0 page_no = 0
start_ix = 0 start_ix = 0
end_ix = 0 end_ix = 0
doc_items = [] doc_items: List[Tuple[int, Union[BaseCell, BaseText]]] = []
doc = doc_result.output doc = doc_result.output
@ -105,11 +110,11 @@ def generate_multimodal_pages(
item_type = item.obj_type item_type = item.obj_type
label = label_to_doclaynet.get(item_type, None) label = label_to_doclaynet.get(item_type, None)
if label is None: if label is None or item.prov is None or page.size is None:
continue continue
bbox = BoundingBox.from_tuple( bbox = BoundingBox.from_tuple(
item.prov[0].bbox, origin=CoordOrigin.BOTTOMLEFT tuple(item.prov[0].bbox), origin=CoordOrigin.BOTTOMLEFT
) )
new_bbox = bbox.to_top_left_origin(page_height=page.size.height).normalized( new_bbox = bbox.to_top_left_origin(page_height=page.size.height).normalized(
page_size=page.size page_size=page.size
@ -137,13 +142,15 @@ def generate_multimodal_pages(
return segments return segments
def _process_page_cells(page: Page): def _process_page_cells(page: Page):
cells = [] cells: List[dict] = []
if page.size is None:
return cells
for cell in page.cells: for cell in page.cells:
new_bbox = cell.bbox.to_top_left_origin( new_bbox = cell.bbox.to_top_left_origin(
page_height=page.size.height page_height=page.size.height
).normalized(page_size=page.size) ).normalized(page_size=page.size)
is_ocr = isinstance(cell, OcrCell) is_ocr = isinstance(cell, OcrCell)
ocr_confidence = cell.confidence if is_ocr else 1.0 ocr_confidence = cell.confidence if isinstance(cell, OcrCell) else 1.0
cells.append( cells.append(
{ {
"text": cell.text, "text": cell.text,
@ -170,6 +177,8 @@ def generate_multimodal_pages(
return content_text, content_md, content_dt, page_cells, page_segments, page return content_text, content_md, content_dt, page_cells, page_segments, page
if doc.main_text is None:
return
for ix, orig_item in enumerate(doc.main_text): for ix, orig_item in enumerate(doc.main_text):
item = doc._resolve_ref(orig_item) if isinstance(orig_item, Ref) else orig_item item = doc._resolve_ref(orig_item) if isinstance(orig_item, Ref) else orig_item

33
poetry.lock generated
View File

@ -3771,6 +3771,21 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.9.2)"] xml = ["lxml (>=4.9.2)"]
[[package]]
name = "pandas-stubs"
version = "2.2.2.240909"
description = "Type annotations for pandas"
optional = false
python-versions = ">=3.10"
files = [
{file = "pandas_stubs-2.2.2.240909-py3-none-any.whl", hash = "sha256:e230f5fa4065f9417804f4d65cd98f86c002efcc07933e8abcd48c3fad9c30a2"},
{file = "pandas_stubs-2.2.2.240909.tar.gz", hash = "sha256:3c0951a2c3e45e3475aed9d80b7147ae82f176b9e42e9fb321cfdebf3d411b3d"},
]
[package.dependencies]
numpy = ">=1.23.5"
types-pytz = ">=2022.1.1"
[[package]] [[package]]
name = "parso" name = "parso"
version = "0.8.4" version = "0.8.4"
@ -6584,6 +6599,11 @@ files = [
{file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
{file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
{file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
{file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"},
{file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"},
{file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"},
{file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"},
{file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"},
] ]
[package.dependencies] [package.dependencies]
@ -6617,6 +6637,17 @@ rfc3986 = ">=1.4.0"
tqdm = ">=4.14" tqdm = ">=4.14"
urllib3 = ">=1.26.0" urllib3 = ">=1.26.0"
[[package]]
name = "types-pytz"
version = "2024.1.0.20240417"
description = "Typing stubs for pytz"
optional = false
python-versions = ">=3.8"
files = [
{file = "types-pytz-2024.1.0.20240417.tar.gz", hash = "sha256:6810c8a1f68f21fdf0f4f374a432487c77645a0ac0b31de4bf4690cf21ad3981"},
{file = "types_pytz-2024.1.0.20240417-py3-none-any.whl", hash = "sha256:8335d443310e2db7b74e007414e74c4f53b67452c0cb0d228ca359ccfba59659"},
]
[[package]] [[package]]
name = "types-requests" name = "types-requests"
version = "2.32.0.20240907" version = "2.32.0.20240907"
@ -7169,4 +7200,4 @@ examples = ["langchain-huggingface", "langchain-milvus", "langchain-text-splitte
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "5ce8fc1e245442e355b967430e211b1378fed2e9fd20d2ddbea47f0e9f1dfcd5" content-hash = "b881ea7a3504555707e0778c7c25631cbb353b78da04bd724852c7d34f39d46d"

View File

@ -74,6 +74,7 @@ pytest-xdist = "^3.3.1"
types-requests = "^2.31.0.2" types-requests = "^2.31.0.2"
flake8-pyproject = "^1.2.3" flake8-pyproject = "^1.2.3"
pylint = "^2.17.5" pylint = "^2.17.5"
pandas-stubs = "^2.2.2.240909"
ipykernel = "^6.29.5" ipykernel = "^6.29.5"
ipywidgets = "^8.1.5" ipywidgets = "^8.1.5"
nbqa = "^1.9.0" nbqa = "^1.9.0"
@ -114,6 +115,14 @@ pretty = true
no_implicit_optional = true no_implicit_optional = true
python_version = "3.10" python_version = "3.10"
[[tool.mypy.overrides]]
module = [
"docling_parse.*",
"pypdfium2.*",
"networkx.*",
]
ignore_missing_imports = true
[tool.flake8] [tool.flake8]
max-line-length = 88 max-line-length = 88
extend-ignore = ["E203", "E501"] extend-ignore = ["E203", "E501"]

View File

@ -45,6 +45,8 @@ def verify_cells(doc_pred_pages: List[Page], doc_true_pages: List[Page]):
def verify_maintext(doc_pred: DsDocument, doc_true: DsDocument): def verify_maintext(doc_pred: DsDocument, doc_true: DsDocument):
assert doc_true.main_text is not None, "doc_true cannot be None"
assert doc_pred.main_text is not None, "doc_pred cannot be None"
assert len(doc_true.main_text) == len( assert len(doc_true.main_text) == len(
doc_pred.main_text doc_pred.main_text
@ -68,6 +70,13 @@ def verify_maintext(doc_pred: DsDocument, doc_true: DsDocument):
def verify_tables(doc_pred: DsDocument, doc_true: DsDocument): def verify_tables(doc_pred: DsDocument, doc_true: DsDocument):
if doc_true.tables is None:
# No tables to check
assert doc_pred.tables is None, "not expecting any table on this document"
return True
assert doc_pred.tables is not None, "no tables predicted, but expected in doc_true"
assert len(doc_true.tables) == len( assert len(doc_true.tables) == len(
doc_pred.tables doc_pred.tables
), "document has different count of tables than expected." ), "document has different count of tables than expected."
@ -82,6 +91,8 @@ def verify_tables(doc_pred: DsDocument, doc_true: DsDocument):
true_item.num_cols == pred_item.num_cols true_item.num_cols == pred_item.num_cols
), "table does not have the same #-cols" ), "table does not have the same #-cols"
assert true_item.data is not None, "documents are expected to have table data"
assert pred_item.data is not None, "documents are expected to have table data"
for i, row in enumerate(true_item.data): for i, row in enumerate(true_item.data):
for j, col in enumerate(true_item.data[i]): for j, col in enumerate(true_item.data[i]):
@ -135,7 +146,7 @@ def verify_conversion_result(
doc_true_pages = PageList.validate_json(fr.read()) doc_true_pages = PageList.validate_json(fr.read())
with open(json_path, "r") as fr: with open(json_path, "r") as fr:
doc_true = DsDocument.model_validate_json(fr.read()) doc_true: DsDocument = DsDocument.model_validate_json(fr.read())
with open(md_path, "r") as fr: with open(md_path, "r") as fr:
doc_true_md = fr.read() doc_true_md = fr.read()