refactor: allow the usage of backends in the enrich models and generalize the interface (#742)
* fix get image with cropbox Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * allow the usage of backends in the enrich models and generalize the interface Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * move logic in BaseTextImageEnrichmentModel Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * renaming Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> --------- Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
f7e1cbf629
commit
57fc28d3d8
@ -163,7 +163,7 @@ class DoclingParsePageBackend(PdfPageBackend):
|
||||
l=0, r=0, t=0, b=0, coord_origin=CoordOrigin.BOTTOMLEFT
|
||||
)
|
||||
else:
|
||||
padbox = cropbox.to_bottom_left_origin(page_size.height)
|
||||
padbox = cropbox.to_bottom_left_origin(page_size.height).model_copy()
|
||||
padbox.r = page_size.width - padbox.r
|
||||
padbox.t = page_size.height - padbox.t
|
||||
|
||||
|
@ -178,7 +178,7 @@ class DoclingParseV2PageBackend(PdfPageBackend):
|
||||
l=0, r=0, t=0, b=0, coord_origin=CoordOrigin.BOTTOMLEFT
|
||||
)
|
||||
else:
|
||||
padbox = cropbox.to_bottom_left_origin(page_size.height)
|
||||
padbox = cropbox.to_bottom_left_origin(page_size.height).model_copy()
|
||||
padbox.r = page_size.width - padbox.r
|
||||
padbox.t = page_size.height - padbox.t
|
||||
|
||||
|
@ -210,7 +210,7 @@ class PyPdfiumPageBackend(PdfPageBackend):
|
||||
l=0, r=0, t=0, b=0, coord_origin=CoordOrigin.BOTTOMLEFT
|
||||
)
|
||||
else:
|
||||
padbox = cropbox.to_bottom_left_origin(page_size.height)
|
||||
padbox = cropbox.to_bottom_left_origin(page_size.height).model_copy()
|
||||
padbox.r = page_size.width - padbox.r
|
||||
padbox.t = page_size.height - padbox.t
|
||||
|
||||
|
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
from docling_core.types.doc import (
|
||||
BoundingBox,
|
||||
DocItemLabel,
|
||||
NodeItem,
|
||||
PictureDataType,
|
||||
Size,
|
||||
TableCell,
|
||||
@ -201,6 +202,13 @@ class AssembledUnit(BaseModel):
|
||||
headers: List[PageElement] = []
|
||||
|
||||
|
||||
class ItemAndImageEnrichmentElement(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
item: NodeItem
|
||||
image: Image
|
||||
|
||||
|
||||
class Page(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@ -219,12 +227,28 @@ class Page(BaseModel):
|
||||
{}
|
||||
) # Cache of images in different scales. By default it is cleared during assembling.
|
||||
|
||||
def get_image(self, scale: float = 1.0) -> Optional[Image]:
|
||||
def get_image(
|
||||
self, scale: float = 1.0, cropbox: Optional[BoundingBox] = None
|
||||
) -> Optional[Image]:
|
||||
if self._backend is None:
|
||||
return self._image_cache.get(scale, None)
|
||||
|
||||
if not scale in self._image_cache:
|
||||
self._image_cache[scale] = self._backend.get_page_image(scale=scale)
|
||||
return self._image_cache[scale]
|
||||
if cropbox is None:
|
||||
self._image_cache[scale] = self._backend.get_page_image(scale=scale)
|
||||
else:
|
||||
return self._backend.get_page_image(scale=scale, cropbox=cropbox)
|
||||
|
||||
if cropbox is None:
|
||||
return self._image_cache[scale]
|
||||
else:
|
||||
page_im = self._image_cache[scale]
|
||||
assert self.size is not None
|
||||
return page_im.crop(
|
||||
cropbox.to_top_left_origin(page_height=self.size.height)
|
||||
.scaled(scale=scale)
|
||||
.as_tuple()
|
||||
)
|
||||
|
||||
@property
|
||||
def image(self) -> Optional[Image]:
|
||||
|
@ -1,9 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Iterable
|
||||
from typing import Any, Generic, Iterable, Optional
|
||||
|
||||
from docling_core.types.doc import DoclingDocument, NodeItem
|
||||
from docling_core.types.doc import DoclingDocument, NodeItem, TextItem
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from docling.datamodel.base_models import Page
|
||||
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
|
||||
from docling.datamodel.document import ConversionResult
|
||||
|
||||
|
||||
@ -15,14 +16,54 @@ class BasePageModel(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class BaseEnrichmentModel(ABC):
|
||||
EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
|
||||
|
||||
|
||||
class GenericEnrichmentModel(ABC, Generic[EnrichElementT]):
|
||||
|
||||
@abstractmethod
|
||||
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self, doc: DoclingDocument, element_batch: Iterable[NodeItem]
|
||||
) -> Iterable[Any]:
|
||||
def prepare_element(
|
||||
self, conv_res: ConversionResult, element: NodeItem
|
||||
) -> Optional[EnrichElementT]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self, doc: DoclingDocument, element_batch: Iterable[EnrichElementT]
|
||||
) -> Iterable[NodeItem]:
|
||||
pass
|
||||
|
||||
|
||||
class BaseEnrichmentModel(GenericEnrichmentModel[NodeItem]):
|
||||
|
||||
def prepare_element(
|
||||
self, conv_res: ConversionResult, element: NodeItem
|
||||
) -> Optional[NodeItem]:
|
||||
if self.is_processable(doc=conv_res.document, element=element):
|
||||
return element
|
||||
return None
|
||||
|
||||
|
||||
class BaseItemAndImageEnrichmentModel(
|
||||
GenericEnrichmentModel[ItemAndImageEnrichmentElement]
|
||||
):
|
||||
|
||||
images_scale: float
|
||||
|
||||
def prepare_element(
|
||||
self, conv_res: ConversionResult, element: NodeItem
|
||||
) -> Optional[ItemAndImageEnrichmentElement]:
|
||||
if not self.is_processable(doc=conv_res.document, element=element):
|
||||
return None
|
||||
|
||||
assert isinstance(element, TextItem)
|
||||
element_prov = element.prov[0]
|
||||
page_ix = element_prov.page_no - 1
|
||||
cropped_image = conv_res.pages[page_ix].get_image(
|
||||
scale=self.images_scale, cropbox=element_prov.bbox
|
||||
)
|
||||
return ItemAndImageEnrichmentElement(item=element, image=cropped_image)
|
||||
|
@ -22,7 +22,7 @@ _log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PageAssembleOptions(BaseModel):
|
||||
keep_images: bool = False
|
||||
pass
|
||||
|
||||
|
||||
class PageAssembleModel(BasePageModel):
|
||||
@ -174,11 +174,4 @@ class PageAssembleModel(BasePageModel):
|
||||
elements=elements, headers=headers, body=body
|
||||
)
|
||||
|
||||
# Remove page images (can be disabled)
|
||||
if not self.options.keep_images:
|
||||
page._image_cache = {}
|
||||
|
||||
# Unload backend
|
||||
page._backend.unload()
|
||||
|
||||
yield page
|
||||
|
@ -28,6 +28,7 @@ _log = logging.getLogger(__name__)
|
||||
class BasePipeline(ABC):
|
||||
def __init__(self, pipeline_options: PipelineOptions):
|
||||
self.pipeline_options = pipeline_options
|
||||
self.keep_images = False
|
||||
self.build_pipe: List[Callable] = []
|
||||
self.enrichment_pipe: List[BaseEnrichmentModel] = []
|
||||
|
||||
@ -40,7 +41,7 @@ class BasePipeline(ABC):
|
||||
conv_res, "pipeline_total", scope=ProfilingScope.DOCUMENT
|
||||
):
|
||||
# These steps are building and assembling the structure of the
|
||||
# output DoclingDocument
|
||||
# output DoclingDocument.
|
||||
conv_res = self._build_document(conv_res)
|
||||
conv_res = self._assemble_document(conv_res)
|
||||
# From this stage, all operations should rely only on conv_res.output
|
||||
@ -50,6 +51,8 @@ class BasePipeline(ABC):
|
||||
conv_res.status = ConversionStatus.FAILURE
|
||||
if raises_on_error:
|
||||
raise e
|
||||
finally:
|
||||
self._unload(conv_res)
|
||||
|
||||
return conv_res
|
||||
|
||||
@ -62,21 +65,22 @@ class BasePipeline(ABC):
|
||||
|
||||
def _enrich_document(self, conv_res: ConversionResult) -> ConversionResult:
|
||||
|
||||
def _filter_elements(
|
||||
doc: DoclingDocument, model: BaseEnrichmentModel
|
||||
def _prepare_elements(
|
||||
conv_res: ConversionResult, model: BaseEnrichmentModel
|
||||
) -> Iterable[NodeItem]:
|
||||
for element, _level in doc.iterate_items():
|
||||
if model.is_processable(doc=doc, element=element):
|
||||
yield element
|
||||
for doc_element, _level in conv_res.document.iterate_items():
|
||||
prepared_element = model.prepare_element(
|
||||
conv_res=conv_res, element=doc_element
|
||||
)
|
||||
if prepared_element is not None:
|
||||
yield prepared_element
|
||||
|
||||
with TimeRecorder(conv_res, "doc_enrich", scope=ProfilingScope.DOCUMENT):
|
||||
for model in self.enrichment_pipe:
|
||||
for element_batch in chunkify(
|
||||
_filter_elements(conv_res.document, model),
|
||||
_prepare_elements(conv_res, model),
|
||||
settings.perf.elements_batch_size,
|
||||
):
|
||||
# TODO: currently we assume the element itself is modified, because
|
||||
# we don't have an interface to save the element back to the document
|
||||
for element in model(
|
||||
doc=conv_res.document, element_batch=element_batch
|
||||
): # Must exhaust!
|
||||
@ -88,6 +92,9 @@ class BasePipeline(ABC):
|
||||
def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
|
||||
pass
|
||||
|
||||
def _unload(self, conv_res: ConversionResult):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_default_options(cls) -> PipelineOptions:
|
||||
@ -107,6 +114,10 @@ class BasePipeline(ABC):
|
||||
|
||||
class PaginatedPipeline(BasePipeline): # TODO this is a bad name.
|
||||
|
||||
def __init__(self, pipeline_options: PipelineOptions):
|
||||
super().__init__(pipeline_options)
|
||||
self.keep_backend = False
|
||||
|
||||
def _apply_on_pages(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
@ -148,7 +159,14 @@ class PaginatedPipeline(BasePipeline): # TODO this is a bad name.
|
||||
pipeline_pages = self._apply_on_pages(conv_res, init_pages)
|
||||
|
||||
for p in pipeline_pages: # Must exhaust!
|
||||
pass
|
||||
|
||||
# Cleanup cached images
|
||||
if not self.keep_images:
|
||||
p._image_cache = {}
|
||||
|
||||
# Cleanup page backends
|
||||
if not self.keep_backend and p._backend is not None:
|
||||
p._backend.unload()
|
||||
|
||||
end_batch_time = time.monotonic()
|
||||
total_elapsed_time += end_batch_time - start_batch_time
|
||||
@ -177,10 +195,15 @@ class PaginatedPipeline(BasePipeline): # TODO this is a bad name.
|
||||
)
|
||||
raise e
|
||||
|
||||
finally:
|
||||
# Always unload the PDF backend, even in case of failure
|
||||
if conv_res.input._backend:
|
||||
conv_res.input._backend.unload()
|
||||
return conv_res
|
||||
|
||||
def _unload(self, conv_res: ConversionResult) -> ConversionResult:
|
||||
for page in conv_res.pages:
|
||||
if page._backend is not None:
|
||||
page._backend.unload()
|
||||
|
||||
if conv_res.input._backend:
|
||||
conv_res.input._backend.unload()
|
||||
|
||||
return conv_res
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem
|
||||
|
||||
@ -17,6 +17,7 @@ from docling.datamodel.pipeline_options import (
|
||||
TesseractCliOcrOptions,
|
||||
TesseractOcrOptions,
|
||||
)
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.models.ds_glm_model import GlmModel, GlmOptions
|
||||
from docling.models.easyocr_model import EasyOcrModel
|
||||
@ -50,7 +51,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
else:
|
||||
self.artifacts_path = Path(pipeline_options.artifacts_path)
|
||||
|
||||
keep_images = (
|
||||
self.keep_images = (
|
||||
self.pipeline_options.generate_page_images
|
||||
or self.pipeline_options.generate_picture_images
|
||||
or self.pipeline_options.generate_table_images
|
||||
@ -87,7 +88,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
),
|
||||
# Page assemble
|
||||
PageAssembleModel(options=PageAssembleOptions(keep_images=keep_images)),
|
||||
PageAssembleModel(options=PageAssembleOptions()),
|
||||
]
|
||||
|
||||
self.enrichment_pipe = [
|
||||
|
88
docs/examples/develop_formula_understanding.py
Normal file
88
docs/examples/develop_formula_understanding.py
Normal file
@ -0,0 +1,88 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from docling_core.types.doc import DocItemLabel, DoclingDocument, NodeItem, TextItem
|
||||
|
||||
from docling.datamodel.base_models import InputFormat, ItemAndImageEnrichmentElement
|
||||
from docling.datamodel.pipeline_options import PdfPipelineOptions
|
||||
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||
from docling.models.base_model import BaseItemAndImageEnrichmentModel
|
||||
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
|
||||
|
||||
|
||||
class ExampleFormulaUnderstandingPipelineOptions(PdfPipelineOptions):
|
||||
do_formula_understanding: bool = True
|
||||
|
||||
|
||||
# A new enrichment model using both the document element and its image as input
|
||||
class ExampleFormulaUnderstandingEnrichmentModel(BaseItemAndImageEnrichmentModel):
|
||||
images_scale = 2.6
|
||||
|
||||
def __init__(self, enabled: bool):
|
||||
self.enabled = enabled
|
||||
|
||||
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
|
||||
return (
|
||||
self.enabled
|
||||
and isinstance(element, TextItem)
|
||||
and element.label == DocItemLabel.FORMULA
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
doc: DoclingDocument,
|
||||
element_batch: Iterable[ItemAndImageEnrichmentElement],
|
||||
) -> Iterable[NodeItem]:
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
for enrich_element in element_batch:
|
||||
enrich_element.image.show()
|
||||
|
||||
yield enrich_element.item
|
||||
|
||||
|
||||
# How the pipeline can be extended.
|
||||
class ExampleFormulaUnderstandingPipeline(StandardPdfPipeline):
|
||||
|
||||
def __init__(self, pipeline_options: ExampleFormulaUnderstandingPipelineOptions):
|
||||
super().__init__(pipeline_options)
|
||||
self.pipeline_options: ExampleFormulaUnderstandingPipelineOptions
|
||||
|
||||
self.enrichment_pipe = [
|
||||
ExampleFormulaUnderstandingEnrichmentModel(
|
||||
enabled=self.pipeline_options.do_formula_understanding
|
||||
)
|
||||
]
|
||||
|
||||
if self.pipeline_options.do_formula_understanding:
|
||||
self.keep_backend = True
|
||||
|
||||
@classmethod
|
||||
def get_default_options(cls) -> ExampleFormulaUnderstandingPipelineOptions:
|
||||
return ExampleFormulaUnderstandingPipelineOptions()
|
||||
|
||||
|
||||
# Example main. In the final version, we simply have to set do_formula_understanding to true.
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
input_doc_path = Path("./tests/data/2203.01017v2.pdf")
|
||||
|
||||
pipeline_options = ExampleFormulaUnderstandingPipelineOptions()
|
||||
pipeline_options.do_formula_understanding = True
|
||||
|
||||
doc_converter = DocumentConverter(
|
||||
format_options={
|
||||
InputFormat.PDF: PdfFormatOption(
|
||||
pipeline_cls=ExampleFormulaUnderstandingPipeline,
|
||||
pipeline_options=pipeline_options,
|
||||
)
|
||||
}
|
||||
)
|
||||
result = doc_converter.convert(input_doc_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user