refactor: allow the usage of backends in the enrich models and generalize the interface (#742)
* fix get image with cropbox Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * allow the usage of backends in the enrich models and generalize the interface Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * move logic in BaseTextImageEnrichmentModel Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * renaming Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> --------- Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
@@ -28,6 +28,7 @@ _log = logging.getLogger(__name__)
|
||||
class BasePipeline(ABC):
|
||||
def __init__(self, pipeline_options: PipelineOptions):
|
||||
self.pipeline_options = pipeline_options
|
||||
self.keep_images = False
|
||||
self.build_pipe: List[Callable] = []
|
||||
self.enrichment_pipe: List[BaseEnrichmentModel] = []
|
||||
|
||||
@@ -40,7 +41,7 @@ class BasePipeline(ABC):
|
||||
conv_res, "pipeline_total", scope=ProfilingScope.DOCUMENT
|
||||
):
|
||||
# These steps are building and assembling the structure of the
|
||||
# output DoclingDocument
|
||||
# output DoclingDocument.
|
||||
conv_res = self._build_document(conv_res)
|
||||
conv_res = self._assemble_document(conv_res)
|
||||
# From this stage, all operations should rely only on conv_res.output
|
||||
@@ -50,6 +51,8 @@ class BasePipeline(ABC):
|
||||
conv_res.status = ConversionStatus.FAILURE
|
||||
if raises_on_error:
|
||||
raise e
|
||||
finally:
|
||||
self._unload(conv_res)
|
||||
|
||||
return conv_res
|
||||
|
||||
@@ -62,21 +65,22 @@ class BasePipeline(ABC):
|
||||
|
||||
def _enrich_document(self, conv_res: ConversionResult) -> ConversionResult:
|
||||
|
||||
def _filter_elements(
|
||||
doc: DoclingDocument, model: BaseEnrichmentModel
|
||||
def _prepare_elements(
|
||||
conv_res: ConversionResult, model: BaseEnrichmentModel
|
||||
) -> Iterable[NodeItem]:
|
||||
for element, _level in doc.iterate_items():
|
||||
if model.is_processable(doc=doc, element=element):
|
||||
yield element
|
||||
for doc_element, _level in conv_res.document.iterate_items():
|
||||
prepared_element = model.prepare_element(
|
||||
conv_res=conv_res, element=doc_element
|
||||
)
|
||||
if prepared_element is not None:
|
||||
yield prepared_element
|
||||
|
||||
with TimeRecorder(conv_res, "doc_enrich", scope=ProfilingScope.DOCUMENT):
|
||||
for model in self.enrichment_pipe:
|
||||
for element_batch in chunkify(
|
||||
_filter_elements(conv_res.document, model),
|
||||
_prepare_elements(conv_res, model),
|
||||
settings.perf.elements_batch_size,
|
||||
):
|
||||
# TODO: currently we assume the element itself is modified, because
|
||||
# we don't have an interface to save the element back to the document
|
||||
for element in model(
|
||||
doc=conv_res.document, element_batch=element_batch
|
||||
): # Must exhaust!
|
||||
@@ -88,6 +92,9 @@ class BasePipeline(ABC):
|
||||
def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
|
||||
pass
|
||||
|
||||
def _unload(self, conv_res: ConversionResult):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_default_options(cls) -> PipelineOptions:
|
||||
@@ -107,6 +114,10 @@ class BasePipeline(ABC):
|
||||
|
||||
class PaginatedPipeline(BasePipeline): # TODO this is a bad name.
|
||||
|
||||
def __init__(self, pipeline_options: PipelineOptions):
|
||||
super().__init__(pipeline_options)
|
||||
self.keep_backend = False
|
||||
|
||||
def _apply_on_pages(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
@@ -148,7 +159,14 @@ class PaginatedPipeline(BasePipeline): # TODO this is a bad name.
|
||||
pipeline_pages = self._apply_on_pages(conv_res, init_pages)
|
||||
|
||||
for p in pipeline_pages: # Must exhaust!
|
||||
pass
|
||||
|
||||
# Cleanup cached images
|
||||
if not self.keep_images:
|
||||
p._image_cache = {}
|
||||
|
||||
# Cleanup page backends
|
||||
if not self.keep_backend and p._backend is not None:
|
||||
p._backend.unload()
|
||||
|
||||
end_batch_time = time.monotonic()
|
||||
total_elapsed_time += end_batch_time - start_batch_time
|
||||
@@ -177,10 +195,15 @@ class PaginatedPipeline(BasePipeline): # TODO this is a bad name.
|
||||
)
|
||||
raise e
|
||||
|
||||
finally:
|
||||
# Always unload the PDF backend, even in case of failure
|
||||
if conv_res.input._backend:
|
||||
conv_res.input._backend.unload()
|
||||
return conv_res
|
||||
|
||||
def _unload(self, conv_res: ConversionResult) -> ConversionResult:
|
||||
for page in conv_res.pages:
|
||||
if page._backend is not None:
|
||||
page._backend.unload()
|
||||
|
||||
if conv_res.input._backend:
|
||||
conv_res.input._backend.unload()
|
||||
|
||||
return conv_res
|
||||
|
||||
|
||||
Reference in New Issue
Block a user