Ensure all models work only on valid pages (#158)
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
034a411057
commit
a00c937e19
@ -202,6 +202,7 @@ class GlmModel:
|
||||
page_dimensions = [
|
||||
PageDimensions(page=p.page_no + 1, height=p.size.height, width=p.size.width)
|
||||
for p in conv_res.pages
|
||||
if p.size is not None
|
||||
]
|
||||
|
||||
ds_doc: DsDocument = DsDocument(
|
||||
|
@ -41,7 +41,9 @@ class EasyOcrModel(BaseOcrModel):
|
||||
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
else:
|
||||
ocr_rects = self.get_ocr_rects(page)
|
||||
|
||||
all_ocr_cells = []
|
||||
|
@ -273,6 +273,10 @@ class LayoutModel(BasePageModel):
|
||||
|
||||
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
else:
|
||||
assert page.size is not None
|
||||
|
||||
clusters = []
|
||||
|
@ -54,7 +54,11 @@ class PageAssembleModel(BasePageModel):
|
||||
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
else:
|
||||
assert page.predictions.layout is not None
|
||||
|
||||
# assembles some JSON output page by page.
|
||||
|
||||
elements: List[PageElement] = []
|
||||
@ -108,9 +112,11 @@ class PageAssembleModel(BasePageModel):
|
||||
elif cluster.label == LayoutModel.FIGURE_LABEL:
|
||||
fig = None
|
||||
if page.predictions.figures_classification:
|
||||
fig = page.predictions.figures_classification.figure_map.get(
|
||||
fig = (
|
||||
page.predictions.figures_classification.figure_map.get(
|
||||
cluster.id, None
|
||||
)
|
||||
)
|
||||
if (
|
||||
not fig
|
||||
): # fallback: add figure without classification, if it isn't present
|
||||
@ -132,7 +138,9 @@ class PageAssembleModel(BasePageModel):
|
||||
cluster.id, None
|
||||
)
|
||||
)
|
||||
if not equation: # fallback: add empty formula, if it isn't present
|
||||
if (
|
||||
not equation
|
||||
): # fallback: add empty formula, if it isn't present
|
||||
text = self.sanitize_text(
|
||||
[
|
||||
cell.text.replace("\x02", "-").strip()
|
||||
|
@ -17,6 +17,10 @@ class PagePreprocessingModel(BasePageModel):
|
||||
|
||||
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
else:
|
||||
page = self._populate_page_images(page)
|
||||
page = self._parse_page_cells(page)
|
||||
yield page
|
||||
|
@ -71,6 +71,10 @@ class TableStructureModel(BasePageModel):
|
||||
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
else:
|
||||
|
||||
assert page.predictions.layout is not None
|
||||
assert page.size is not None
|
||||
|
||||
@ -98,13 +102,16 @@ class TableStructureModel(BasePageModel):
|
||||
for cluster, _ in in_tables:
|
||||
if c.bbox.area() > 0:
|
||||
if (
|
||||
c.bbox.intersection_area_with(cluster.bbox) / c.bbox.area()
|
||||
c.bbox.intersection_area_with(cluster.bbox)
|
||||
/ c.bbox.area()
|
||||
> 0.2
|
||||
):
|
||||
# Only allow non-empty strings (not just spaces) into the cells of a table
|
||||
if len(c.text.strip()) > 0:
|
||||
new_cell = copy.deepcopy(c)
|
||||
new_cell.bbox = new_cell.bbox.scaled(scale=self.scale)
|
||||
new_cell.bbox = new_cell.bbox.scaled(
|
||||
scale=self.scale
|
||||
)
|
||||
|
||||
tokens.append(new_cell.model_dump())
|
||||
|
||||
@ -154,7 +161,9 @@ class TableStructureModel(BasePageModel):
|
||||
label=DocItemLabel.TABLE,
|
||||
)
|
||||
|
||||
page.predictions.tablestructure.table_map[table_cluster.id] = tbl
|
||||
page.predictions.tablestructure.table_map[table_cluster.id] = (
|
||||
tbl
|
||||
)
|
||||
|
||||
# For debugging purposes:
|
||||
# self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values())
|
||||
|
@ -110,7 +110,9 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
else:
|
||||
ocr_rects = self.get_ocr_rects(page)
|
||||
|
||||
all_ocr_cells = []
|
||||
@ -122,7 +124,9 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
scale=self.scale, cropbox=ocr_rect
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", mode="w") as image_file:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".png", mode="w"
|
||||
) as image_file:
|
||||
fname = image_file.name
|
||||
high_res_image.save(fname)
|
||||
|
||||
|
@ -69,6 +69,9 @@ class TesseractOcrModel(BaseOcrModel):
|
||||
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
else:
|
||||
assert self.reader is not None
|
||||
|
||||
ocr_rects = self.get_ocr_rects(page)
|
||||
@ -84,7 +87,9 @@ class TesseractOcrModel(BaseOcrModel):
|
||||
|
||||
# Retrieve text snippets with their bounding boxes
|
||||
self.reader.SetImage(high_res_image)
|
||||
boxes = self.reader.GetComponentImages(self.reader_RIL.TEXTLINE, True)
|
||||
boxes = self.reader.GetComponentImages(
|
||||
self.reader_RIL.TEXTLINE, True
|
||||
)
|
||||
|
||||
cells = []
|
||||
for ix, (im, box, _, _) in enumerate(boxes):
|
||||
|
@ -134,7 +134,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
all_body = []
|
||||
|
||||
for p in conv_res.pages:
|
||||
assert p.assembled is not None
|
||||
if p.assembled is not None:
|
||||
for el in p.assembled.body:
|
||||
all_body.append(el)
|
||||
for el in p.assembled.headers:
|
||||
|
@ -126,7 +126,7 @@ input_files = [
|
||||
]
|
||||
|
||||
# Directly pass list of files or streams to `convert_all`
|
||||
conv_results_iter = doc_converter.convert_all(input_files) # previously `convert_batch`
|
||||
conv_results_iter = doc_converter.convert_all(input_files) # previously `convert`
|
||||
|
||||
```
|
||||
Through the `raises_on_error` argument, you can also control if the conversion should raise exceptions when first
|
||||
@ -135,7 +135,7 @@ By default, any error is immediately raised and the conversion aborts (previousl
|
||||
|
||||
```python
|
||||
...
|
||||
conv_results_iter = doc_converter.convert_all(input_files, raises_on_error=False) # previously `convert_batch`
|
||||
conv_results_iter = doc_converter.convert_all(input_files, raises_on_error=False) # previously `convert`
|
||||
|
||||
```
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user