Ensure all models work only on valid pages (#158)

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2024-10-18 08:54:06 +02:00 committed by GitHub
parent 034a411057
commit a00c937e19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 413 additions and 376 deletions

View File

@ -202,6 +202,7 @@ class GlmModel:
page_dimensions = [
PageDimensions(page=p.page_no + 1, height=p.size.height, width=p.size.width)
for p in conv_res.pages
if p.size is not None
]
ds_doc: DsDocument = DsDocument(

View File

@ -41,7 +41,9 @@ class EasyOcrModel(BaseOcrModel):
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
ocr_rects = self.get_ocr_rects(page)
all_ocr_cells = []

View File

@ -273,6 +273,10 @@ class LayoutModel(BasePageModel):
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
assert page.size is not None
clusters = []

View File

@ -54,7 +54,11 @@ class PageAssembleModel(BasePageModel):
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
assert page.predictions.layout is not None
# assembles some JSON output page by page.
elements: List[PageElement] = []
@ -108,9 +112,11 @@ class PageAssembleModel(BasePageModel):
elif cluster.label == LayoutModel.FIGURE_LABEL:
fig = None
if page.predictions.figures_classification:
fig = page.predictions.figures_classification.figure_map.get(
fig = (
page.predictions.figures_classification.figure_map.get(
cluster.id, None
)
)
if (
not fig
): # fallback: add figure without classification, if it isn't present
@ -132,7 +138,9 @@ class PageAssembleModel(BasePageModel):
cluster.id, None
)
)
if not equation: # fallback: add empty formula, if it isn't present
if (
not equation
): # fallback: add empty formula, if it isn't present
text = self.sanitize_text(
[
cell.text.replace("\x02", "-").strip()

View File

@ -17,6 +17,10 @@ class PagePreprocessingModel(BasePageModel):
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
page = self._populate_page_images(page)
page = self._parse_page_cells(page)
yield page

View File

@ -71,6 +71,10 @@ class TableStructureModel(BasePageModel):
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
assert page.predictions.layout is not None
assert page.size is not None
@ -98,13 +102,16 @@ class TableStructureModel(BasePageModel):
for cluster, _ in in_tables:
if c.bbox.area() > 0:
if (
c.bbox.intersection_area_with(cluster.bbox) / c.bbox.area()
c.bbox.intersection_area_with(cluster.bbox)
/ c.bbox.area()
> 0.2
):
# Only allow non empty stings (spaces) into the cells of a table
if len(c.text.strip()) > 0:
new_cell = copy.deepcopy(c)
new_cell.bbox = new_cell.bbox.scaled(scale=self.scale)
new_cell.bbox = new_cell.bbox.scaled(
scale=self.scale
)
tokens.append(new_cell.model_dump())
@ -154,7 +161,9 @@ class TableStructureModel(BasePageModel):
label=DocItemLabel.TABLE,
)
page.predictions.tablestructure.table_map[table_cluster.id] = tbl
page.predictions.tablestructure.table_map[table_cluster.id] = (
tbl
)
# For debugging purposes:
# self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values())

View File

@ -110,7 +110,9 @@ class TesseractOcrCliModel(BaseOcrModel):
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
ocr_rects = self.get_ocr_rects(page)
all_ocr_cells = []
@ -122,7 +124,9 @@ class TesseractOcrCliModel(BaseOcrModel):
scale=self.scale, cropbox=ocr_rect
)
with tempfile.NamedTemporaryFile(suffix=".png", mode="w") as image_file:
with tempfile.NamedTemporaryFile(
suffix=".png", mode="w"
) as image_file:
fname = image_file.name
high_res_image.save(fname)

View File

@ -69,6 +69,9 @@ class TesseractOcrModel(BaseOcrModel):
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
assert self.reader is not None
ocr_rects = self.get_ocr_rects(page)
@ -84,7 +87,9 @@ class TesseractOcrModel(BaseOcrModel):
# Retrieve text snippets with their bounding boxes
self.reader.SetImage(high_res_image)
boxes = self.reader.GetComponentImages(self.reader_RIL.TEXTLINE, True)
boxes = self.reader.GetComponentImages(
self.reader_RIL.TEXTLINE, True
)
cells = []
for ix, (im, box, _, _) in enumerate(boxes):

View File

@ -134,7 +134,7 @@ class StandardPdfPipeline(PaginatedPipeline):
all_body = []
for p in conv_res.pages:
assert p.assembled is not None
if p.assembled is not None:
for el in p.assembled.body:
all_body.append(el)
for el in p.assembled.headers:

View File

@ -126,7 +126,7 @@ input_files = [
]
# Directly pass list of files or streams to `convert_all`
conv_results_iter = doc_converter.convert_all(input_files) # previously `convert_batch`
conv_results_iter = doc_converter.convert_all(input_files) # previously `convert`
```
Through the `raises_on_error` argument, you can also control if the conversion should raise exceptions when first
@ -135,7 +135,7 @@ By default, any error is immediately raised and the conversion aborts (previousl
```python
...
conv_results_iter = doc_converter.convert_all(input_files, raises_on_error=False) # previously `convert_batch`
conv_results_iter = doc_converter.convert_all(input_files, raises_on_error=False) # previously `convert`
```