Ensure all models work only on valid pages (#158)

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2024-10-18 08:54:06 +02:00 committed by GitHub
parent 034a411057
commit a00c937e19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 413 additions and 376 deletions

View File

@ -202,6 +202,7 @@ class GlmModel:
page_dimensions = [ page_dimensions = [
PageDimensions(page=p.page_no + 1, height=p.size.height, width=p.size.width) PageDimensions(page=p.page_no + 1, height=p.size.height, width=p.size.width)
for p in conv_res.pages for p in conv_res.pages
if p.size is not None
] ]
ds_doc: DsDocument = DsDocument( ds_doc: DsDocument = DsDocument(

View File

@ -41,7 +41,9 @@ class EasyOcrModel(BaseOcrModel):
for page in page_batch: for page in page_batch:
assert page._backend is not None assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
ocr_rects = self.get_ocr_rects(page) ocr_rects = self.get_ocr_rects(page)
all_ocr_cells = [] all_ocr_cells = []

View File

@ -273,6 +273,10 @@ class LayoutModel(BasePageModel):
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
for page in page_batch: for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
assert page.size is not None assert page.size is not None
clusters = [] clusters = []

View File

@ -54,7 +54,11 @@ class PageAssembleModel(BasePageModel):
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
for page in page_batch: for page in page_batch:
assert page._backend is not None assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
assert page.predictions.layout is not None assert page.predictions.layout is not None
# assembles some JSON output page by page. # assembles some JSON output page by page.
elements: List[PageElement] = [] elements: List[PageElement] = []
@ -108,9 +112,11 @@ class PageAssembleModel(BasePageModel):
elif cluster.label == LayoutModel.FIGURE_LABEL: elif cluster.label == LayoutModel.FIGURE_LABEL:
fig = None fig = None
if page.predictions.figures_classification: if page.predictions.figures_classification:
fig = page.predictions.figures_classification.figure_map.get( fig = (
page.predictions.figures_classification.figure_map.get(
cluster.id, None cluster.id, None
) )
)
if ( if (
not fig not fig
): # fallback: add figure without classification, if it isn't present ): # fallback: add figure without classification, if it isn't present
@ -132,7 +138,9 @@ class PageAssembleModel(BasePageModel):
cluster.id, None cluster.id, None
) )
) )
if not equation: # fallback: add empty formula, if it isn't present if (
not equation
): # fallback: add empty formula, if it isn't present
text = self.sanitize_text( text = self.sanitize_text(
[ [
cell.text.replace("\x02", "-").strip() cell.text.replace("\x02", "-").strip()

View File

@ -17,6 +17,10 @@ class PagePreprocessingModel(BasePageModel):
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
for page in page_batch: for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
page = self._populate_page_images(page) page = self._populate_page_images(page)
page = self._parse_page_cells(page) page = self._parse_page_cells(page)
yield page yield page

View File

@ -71,6 +71,10 @@ class TableStructureModel(BasePageModel):
for page in page_batch: for page in page_batch:
assert page._backend is not None assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
assert page.predictions.layout is not None assert page.predictions.layout is not None
assert page.size is not None assert page.size is not None
@ -98,13 +102,16 @@ class TableStructureModel(BasePageModel):
for cluster, _ in in_tables: for cluster, _ in in_tables:
if c.bbox.area() > 0: if c.bbox.area() > 0:
if ( if (
c.bbox.intersection_area_with(cluster.bbox) / c.bbox.area() c.bbox.intersection_area_with(cluster.bbox)
/ c.bbox.area()
> 0.2 > 0.2
): ):
# Only allow non-empty strings (spaces) into the cells of a table # Only allow non-empty strings (spaces) into the cells of a table
if len(c.text.strip()) > 0: if len(c.text.strip()) > 0:
new_cell = copy.deepcopy(c) new_cell = copy.deepcopy(c)
new_cell.bbox = new_cell.bbox.scaled(scale=self.scale) new_cell.bbox = new_cell.bbox.scaled(
scale=self.scale
)
tokens.append(new_cell.model_dump()) tokens.append(new_cell.model_dump())
@ -154,7 +161,9 @@ class TableStructureModel(BasePageModel):
label=DocItemLabel.TABLE, label=DocItemLabel.TABLE,
) )
page.predictions.tablestructure.table_map[table_cluster.id] = tbl page.predictions.tablestructure.table_map[table_cluster.id] = (
tbl
)
# For debugging purposes: # For debugging purposes:
# self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values()) # self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values())

View File

@ -110,7 +110,9 @@ class TesseractOcrCliModel(BaseOcrModel):
for page in page_batch: for page in page_batch:
assert page._backend is not None assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
ocr_rects = self.get_ocr_rects(page) ocr_rects = self.get_ocr_rects(page)
all_ocr_cells = [] all_ocr_cells = []
@ -122,7 +124,9 @@ class TesseractOcrCliModel(BaseOcrModel):
scale=self.scale, cropbox=ocr_rect scale=self.scale, cropbox=ocr_rect
) )
with tempfile.NamedTemporaryFile(suffix=".png", mode="w") as image_file: with tempfile.NamedTemporaryFile(
suffix=".png", mode="w"
) as image_file:
fname = image_file.name fname = image_file.name
high_res_image.save(fname) high_res_image.save(fname)

View File

@ -69,6 +69,9 @@ class TesseractOcrModel(BaseOcrModel):
for page in page_batch: for page in page_batch:
assert page._backend is not None assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
assert self.reader is not None assert self.reader is not None
ocr_rects = self.get_ocr_rects(page) ocr_rects = self.get_ocr_rects(page)
@ -84,7 +87,9 @@ class TesseractOcrModel(BaseOcrModel):
# Retrieve text snippets with their bounding boxes # Retrieve text snippets with their bounding boxes
self.reader.SetImage(high_res_image) self.reader.SetImage(high_res_image)
boxes = self.reader.GetComponentImages(self.reader_RIL.TEXTLINE, True) boxes = self.reader.GetComponentImages(
self.reader_RIL.TEXTLINE, True
)
cells = [] cells = []
for ix, (im, box, _, _) in enumerate(boxes): for ix, (im, box, _, _) in enumerate(boxes):

View File

@ -134,7 +134,7 @@ class StandardPdfPipeline(PaginatedPipeline):
all_body = [] all_body = []
for p in conv_res.pages: for p in conv_res.pages:
assert p.assembled is not None if p.assembled is not None:
for el in p.assembled.body: for el in p.assembled.body:
all_body.append(el) all_body.append(el)
for el in p.assembled.headers: for el in p.assembled.headers:

View File

@ -126,7 +126,7 @@ input_files = [
] ]
# Directly pass list of files or streams to `convert_all` # Directly pass list of files or streams to `convert_all`
conv_results_iter = doc_converter.convert_all(input_files) # previously `convert_batch` conv_results_iter = doc_converter.convert_all(input_files) # previously `convert`
``` ```
Through the `raises_on_error` argument, you can also control if the conversion should raise exceptions when first Through the `raises_on_error` argument, you can also control if the conversion should raise exceptions when first
@ -135,7 +135,7 @@ By default, any error is immediately raised and the conversion aborts (previousl
```python ```python
... ...
conv_results_iter = doc_converter.convert_all(input_files, raises_on_error=False) # previously `convert_batch` conv_results_iter = doc_converter.convert_all(input_files, raises_on_error=False) # previously `convert`
``` ```