test: improve typing definitions (part 1) (#72)

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Michele Dolfi 2024-09-12 15:56:29 +02:00 committed by GitHub
parent 53569a1023
commit 8aa476ccd3
9 changed files with 91 additions and 29 deletions

docling/backend/abstract_backend.py

@@ -1,10 +1,13 @@
 from abc import ABC, abstractmethod
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Union

 from PIL import Image

+if TYPE_CHECKING:
+    from docling.datamodel.base_models import BoundingBox, Cell, PageSize
+

 class PdfPageBackend(ABC):
@@ -17,12 +20,12 @@ class PdfPageBackend(ABC):
         pass

     @abstractmethod
-    def get_bitmap_rects(self, scale: int = 1) -> Iterable["BoundingBox"]:
+    def get_bitmap_rects(self, scale: float = 1) -> Iterable["BoundingBox"]:
         pass

     @abstractmethod
     def get_page_image(
-        self, scale: int = 1, cropbox: Optional["BoundingBox"] = None
+        self, scale: float = 1, cropbox: Optional["BoundingBox"] = None
     ) -> Image.Image:
         pass

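The TYPE_CHECKING guard above is the standard pattern for making names visible to the type checker without importing them at runtime, which avoids circular imports; the annotations that use them must then be string literals. A minimal, self-contained sketch of the pattern (not docling's actual code):

    from typing import TYPE_CHECKING, Iterable

    if TYPE_CHECKING:
        # Evaluated only by mypy/IDEs, never at runtime, so it cannot
        # introduce an import cycle.
        from docling.datamodel.base_models import BoundingBox

    class DemoBackend:
        # Quoted annotation: BoundingBox does not exist at runtime here.
        def get_bitmap_rects(self, scale: float = 1) -> Iterable["BoundingBox"]:
            ...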
docling/backend/docling_parse_backend.py

@@ -2,7 +2,7 @@ import logging
 import random
 from io import BytesIO
 from pathlib import Path
-from typing import Iterable, Optional, Union
+from typing import Iterable, List, Optional, Union

 import pypdfium2 as pdfium
 from docling_parse.docling_parse import pdf_parser
@@ -22,7 +22,6 @@ class DoclingParsePageBackend(PdfPageBackend):
         self._ppage = page_obj
         parsed_page = parser.parse_pdf_from_key_on_page(document_hash, page_no)

-        self._dpage = None
         self.valid = "pages" in parsed_page
         if self.valid:
             self._dpage = parsed_page["pages"][0]
@@ -68,7 +67,7 @@ class DoclingParsePageBackend(PdfPageBackend):
         return text_piece

     def get_text_cells(self) -> Iterable[Cell]:
-        cells = []
+        cells: List[Cell] = []
         cell_counter = 0

         if not self.valid:
@@ -130,7 +129,7 @@ class DoclingParsePageBackend(PdfPageBackend):
         return cells

-    def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]:
+    def get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]:
         AREA_THRESHOLD = 32 * 32

         for i in range(len(self._dpage["images"])):
@@ -145,7 +144,7 @@ class DoclingParsePageBackend(PdfPageBackend):
             yield cropbox

     def get_page_image(
-        self, scale: int = 1, cropbox: Optional[BoundingBox] = None
+        self, scale: float = 1, cropbox: Optional[BoundingBox] = None
     ) -> Image.Image:
         page_size = self.get_size()

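Several signatures in this commit widen scale from int to float. Under PEP 484, int is implicitly compatible with float, so the `= 1` default still type-checks while callers gain fractional zoom levels. A minimal sketch (hypothetical page size, not docling's implementation):

    from PIL import Image

    def get_page_image(scale: float = 1) -> Image.Image:
        # Hypothetical base size in points; docling derives it from the page.
        base_w, base_h = 612, 792
        return Image.new("RGB", (round(base_w * scale), round(base_h * scale)))

    get_page_image()      # int default remains valid
    get_page_image(1.5)   # fractional scales now type-check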
docling/backend/pypdfium2_backend.py

@@ -7,7 +7,7 @@ from typing import Iterable, List, Optional, Union
 import pypdfium2 as pdfium
 import pypdfium2.raw as pdfium_c
 from PIL import Image, ImageDraw
-from pypdfium2 import PdfPage
+from pypdfium2 import PdfPage, PdfTextPage
 from pypdfium2._helpers.misc import PdfiumError

 from docling.backend.abstract_backend import PdfDocumentBackend, PdfPageBackend
@@ -29,12 +29,12 @@ class PyPdfiumPageBackend(PdfPageBackend):
                 exc_info=True,
             )
             self.valid = False
-        self.text_page = None
+        self.text_page: Optional[PdfTextPage] = None

     def is_valid(self) -> bool:
         return self.valid

-    def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]:
+    def get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]:
         AREA_THRESHOLD = 32 * 32
         for obj in self._ppage.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_IMAGE]):
             pos = obj.get_pos()
@@ -189,7 +189,7 @@ class PyPdfiumPageBackend(PdfPageBackend):
         return cells

     def get_page_image(
-        self, scale: int = 1, cropbox: Optional[BoundingBox] = None
+        self, scale: float = 1, cropbox: Optional[BoundingBox] = None
     ) -> Image.Image:
         page_size = self.get_size()

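Annotating self.text_page as Optional[PdfTextPage] matters because the attribute starts out as None: without the annotation, mypy infers the type from the first assignment and rejects storing a real text page later. A sketch of the idea with a stand-in type:

    from typing import Optional

    class Holder:
        def __init__(self) -> None:
            # Declares both states up front; a plain `self.value = None`
            # would be inferred as always-None.
            self.value: Optional[str] = None

        def load(self, v: str) -> None:
            self.value = v  # OK under Optional[str]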
docling/datamodel/base_models.py

@@ -87,7 +87,7 @@ class BoundingBox(BaseModel):
         return (self.l, self.b, self.r, self.t)

     @classmethod
-    def from_tuple(cls, coord: Tuple[float], origin: CoordOrigin):
+    def from_tuple(cls, coord: Tuple[float, ...], origin: CoordOrigin):
         if origin == CoordOrigin.TOPLEFT:
             l, t, r, b = coord[0], coord[1], coord[2], coord[3]
             if r < l:
@@ -246,7 +246,7 @@ class EquationPrediction(BaseModel):

 class PagePredictions(BaseModel):
-    layout: LayoutPrediction = None
+    layout: Optional[LayoutPrediction] = None
     tablestructure: Optional[TableStructurePrediction] = None
     figures_classification: Optional[FigureClassificationPrediction] = None
     equations_prediction: Optional[EquationPrediction] = None
@@ -267,7 +267,7 @@ class Page(BaseModel):
     page_no: int
     page_hash: Optional[str] = None
     size: Optional[PageSize] = None
-    cells: List[Cell] = None
+    cells: List[Cell] = []
     predictions: PagePredictions = PagePredictions()
     assembled: Optional[AssembledUnit] = None

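Two fixes here are easy to misread. Tuple[float] means a tuple of exactly one float, so the four-element bbox coordinates need Tuple[float, ...]. And replacing `cells: List[Cell] = None` with `= []` is safe on a pydantic model, because pydantic copies field defaults per instance rather than sharing one list. A short demonstration (pydantic assumed installed):

    from typing import List, Tuple

    from pydantic import BaseModel

    def from_tuple(coord: Tuple[float, ...]) -> float:
        return coord[0]

    from_tuple((10.0, 20.0, 30.0, 40.0))  # rejected under Tuple[float]

    class Demo(BaseModel):
        cells: List[int] = []  # copied per instance by pydantic

    a, b = Demo(), Demo()
    a.cells.append(1)
    assert b.cells == []  # no shared mutable default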
docling/pipeline/base_model_pipeline.py

@@ -1,12 +1,12 @@
 from pathlib import Path
-from typing import Iterable
+from typing import Callable, Iterable, List

 from docling.datamodel.base_models import Page, PipelineOptions


 class BaseModelPipeline:
     def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions):
-        self.model_pipe = []
+        self.model_pipe: List[Callable] = []
         self.artifacts_path = artifacts_path
         self.pipeline_options = pipeline_options

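The List[Callable] annotation stops mypy from inferring an unusable type for the empty list. A bare Callable accepts any signature; if all pipeline stages share one shape, a tighter spelling, shown here only as a hypothetical refinement (not what the commit does), would be:

    from typing import Callable, Iterable, List

    from docling.datamodel.base_models import Page

    # Hypothetical: each stage consumes and produces an iterable of pages.
    PipelineStage = Callable[[Iterable[Page]], Iterable[Page]]
    model_pipe: List[PipelineStage] = []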
docling/utils/export.py

@@ -1,10 +1,10 @@
 import logging
-from typing import Any, Dict, Iterable, List, Tuple
+from typing import Any, Dict, Iterable, List, Tuple, Union

-from docling_core.types.doc.base import BaseCell, Ref, Table, TableCell
+from docling_core.types.doc.base import BaseCell, BaseText, Ref, Table, TableCell

 from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell
-from docling.datamodel.document import ConvertedDocument, Page
+from docling.datamodel.document import ConversionResult, Page

 _log = logging.getLogger(__name__)
@@ -15,7 +15,10 @@ def _export_table_to_html(table: Table):
     # to the docling-core package.
     def _get_tablecell_span(cell: TableCell, ix):
-        span = set([s[ix] for s in cell.spans])
+        if cell.spans is None:
+            span = set()
+        else:
+            span = set([s[ix] for s in cell.spans])
         if len(span) == 0:
             return 1, None, None

         return len(span), min(span), max(span)
@@ -24,6 +27,8 @@ def _export_table_to_html(table: Table):
     nrows = table.num_rows
     ncols = table.num_cols

+    if table.data is None:
+        return ""
     for i in range(nrows):
         body += "<tr>"
         for j in range(ncols):
@@ -66,7 +71,7 @@ def _export_table_to_html(table: Table):

 def generate_multimodal_pages(
-    doc_result: ConvertedDocument,
+    doc_result: ConversionResult,
 ) -> Iterable[Tuple[str, str, List[Dict[str, Any]], List[Dict[str, Any]], Page]]:

     label_to_doclaynet = {
@@ -94,7 +99,7 @@ def generate_multimodal_pages(
     page_no = 0
     start_ix = 0
     end_ix = 0
-    doc_items = []
+    doc_items: List[Tuple[int, Union[BaseCell, BaseText]]] = []

     doc = doc_result.output
@@ -105,11 +110,11 @@ def generate_multimodal_pages(
         item_type = item.obj_type
         label = label_to_doclaynet.get(item_type, None)
-        if label is None:
+        if label is None or item.prov is None or page.size is None:
             continue

         bbox = BoundingBox.from_tuple(
-            item.prov[0].bbox, origin=CoordOrigin.BOTTOMLEFT
+            tuple(item.prov[0].bbox), origin=CoordOrigin.BOTTOMLEFT
         )
         new_bbox = bbox.to_top_left_origin(page_height=page.size.height).normalized(
             page_size=page.size
@@ -137,13 +142,15 @@ def generate_multimodal_pages(
         return segments

     def _process_page_cells(page: Page):
-        cells = []
+        cells: List[dict] = []
+        if page.size is None:
+            return cells
         for cell in page.cells:
             new_bbox = cell.bbox.to_top_left_origin(
                 page_height=page.size.height
             ).normalized(page_size=page.size)
-            is_ocr = isinstance(cell, OcrCell)
-            ocr_confidence = cell.confidence if is_ocr else 1.0
+            ocr_confidence = cell.confidence if isinstance(cell, OcrCell) else 1.0
             cells.append(
                 {
                     "text": cell.text,
@@ -170,6 +177,8 @@ def generate_multimodal_pages(
         return content_text, content_md, content_dt, page_cells, page_segments, page

+    if doc.main_text is None:
+        return
     for ix, orig_item in enumerate(doc.main_text):
         item = doc._resolve_ref(orig_item) if isinstance(orig_item, Ref) else orig_item

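The recurring pattern in this file is narrowing: guard clauses such as `if label is None or item.prov is None or page.size is None: continue` convince mypy the Optionals are populated afterwards, and inlining the isinstance check into the conditional expression lets mypy narrow cell to OcrCell exactly where .confidence is accessed. A self-contained sketch of both patterns:

    from typing import Optional

    class Cell:
        text: str = "plain"

    class OcrCell(Cell):
        confidence: float = 0.87

    def describe(cell: Cell, height: Optional[float]) -> str:
        if height is None:
            return ""  # after this guard, mypy treats height as float
        # Inline isinstance: cell is narrowed to OcrCell in the true branch.
        conf = cell.confidence if isinstance(cell, OcrCell) else 1.0
        return f"{cell.text} @ {height:.0f}pt (conf={conf})"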
poetry.lock (generated)

@@ -3771,6 +3771,21 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d
 test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
 xml = ["lxml (>=4.9.2)"]

+[[package]]
+name = "pandas-stubs"
+version = "2.2.2.240909"
+description = "Type annotations for pandas"
+optional = false
+python-versions = ">=3.10"
+files = [
+    {file = "pandas_stubs-2.2.2.240909-py3-none-any.whl", hash = "sha256:e230f5fa4065f9417804f4d65cd98f86c002efcc07933e8abcd48c3fad9c30a2"},
+    {file = "pandas_stubs-2.2.2.240909.tar.gz", hash = "sha256:3c0951a2c3e45e3475aed9d80b7147ae82f176b9e42e9fb321cfdebf3d411b3d"},
+]
+
+[package.dependencies]
+numpy = ">=1.23.5"
+types-pytz = ">=2022.1.1"
+
 [[package]]
 name = "parso"
 version = "0.8.4"
@@ -6584,6 +6599,11 @@ files = [
     {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
     {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
     {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
+    {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"},
+    {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"},
+    {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"},
+    {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"},
+    {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"},
 ]

 [package.dependencies]
@@ -6617,6 +6637,17 @@ rfc3986 = ">=1.4.0"
 tqdm = ">=4.14"
 urllib3 = ">=1.26.0"

+[[package]]
+name = "types-pytz"
+version = "2024.1.0.20240417"
+description = "Typing stubs for pytz"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "types-pytz-2024.1.0.20240417.tar.gz", hash = "sha256:6810c8a1f68f21fdf0f4f374a432487c77645a0ac0b31de4bf4690cf21ad3981"},
+    {file = "types_pytz-2024.1.0.20240417-py3-none-any.whl", hash = "sha256:8335d443310e2db7b74e007414e74c4f53b67452c0cb0d228ca359ccfba59659"},
+]
+
 [[package]]
 name = "types-requests"
 version = "2.32.0.20240907"
@@ -7169,4 +7200,4 @@ examples = ["langchain-huggingface", "langchain-milvus", "langchain-text-splitte

 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "5ce8fc1e245442e355b967430e211b1378fed2e9fd20d2ddbea47f0e9f1dfcd5"
+content-hash = "b881ea7a3504555707e0778c7c25631cbb353b78da04bd724852c7d34f39d46d"

pyproject.toml

@@ -74,6 +74,7 @@ pytest-xdist = "^3.3.1"
 types-requests = "^2.31.0.2"
 flake8-pyproject = "^1.2.3"
 pylint = "^2.17.5"
+pandas-stubs = "^2.2.2.240909"
 ipykernel = "^6.29.5"
 ipywidgets = "^8.1.5"
 nbqa = "^1.9.0"
@@ -114,6 +115,14 @@ pretty = true
 no_implicit_optional = true
 python_version = "3.10"

+[[tool.mypy.overrides]]
+module = [
+    "docling_parse.*",
+    "pypdfium2.*",
+    "networkx.*",
+]
+ignore_missing_imports = true
+
 [tool.flake8]
 max-line-length = 88
 extend-ignore = ["E203", "E501"]

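pandas-stubs (with its types-pytz dependency) supplies type information for pandas, while the [[tool.mypy.overrides]] block silences missing-import errors for packages that ship no stubs at all. The per-line alternative, shown only as an illustration (the exact error code varies with the mypy version), is:

    # Without stubs or an override, mypy reports something like:
    #   error: Cannot find implementation or library stub for module named "docling_parse"
    import docling_parse  # type: ignore[import-untyped]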
tests/verify_utils.py

@@ -45,6 +45,8 @@ def verify_cells(doc_pred_pages: List[Page], doc_true_pages: List[Page]):

 def verify_maintext(doc_pred: DsDocument, doc_true: DsDocument):
+    assert doc_true.main_text is not None, "doc_true cannot be None"
+    assert doc_pred.main_text is not None, "doc_pred cannot be None"

     assert len(doc_true.main_text) == len(
         doc_pred.main_text
@@ -68,6 +70,13 @@ def verify_maintext(doc_pred: DsDocument, doc_true: DsDocument):

 def verify_tables(doc_pred: DsDocument, doc_true: DsDocument):
+    if doc_true.tables is None:
+        # No tables to check
+        assert doc_pred.tables is None, "not expecting any table on this document"
+        return True
+
+    assert doc_pred.tables is not None, "no tables predicted, but expected in doc_true"
+
     assert len(doc_true.tables) == len(
         doc_pred.tables
     ), "document has different count of tables than expected."
@@ -82,6 +91,8 @@ def verify_tables(doc_pred: DsDocument, doc_true: DsDocument):
             true_item.num_cols == pred_item.num_cols
         ), "table does not have the same #-cols"

+        assert true_item.data is not None, "documents are expected to have table data"
+        assert pred_item.data is not None, "documents are expected to have table data"
         for i, row in enumerate(true_item.data):
             for j, col in enumerate(true_item.data[i]):
@@ -135,7 +146,7 @@ def verify_conversion_result(
         doc_true_pages = PageList.validate_json(fr.read())

     with open(json_path, "r") as fr:
-        doc_true = DsDocument.model_validate_json(fr.read())
+        doc_true: DsDocument = DsDocument.model_validate_json(fr.read())

     with open(md_path, "r") as fr:
         doc_true_md = fr.read()
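
The added asserts do double duty in the tests: besides failing loudly, `assert x is not None` narrows Optional[X] to X for the rest of the function, which is why they precede the len() calls and the table-data loops. A minimal sketch:

    from typing import List, Optional

    def count_tables(tables: Optional[List[str]]) -> int:
        assert tables is not None, "expected tables in this document"
        # mypy now treats tables as List[str], so len() type-checks
        return len(tables)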