import copy
import warnings
from pathlib import Path
from typing import Iterable, Optional, Union

import numpy
from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
from PIL import ImageDraw

from docling.datamodel.base_models import Page, Table, TableStructurePrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
    AcceleratorDevice,
    AcceleratorOptions,
    TableFormerMode,
    TableStructureOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder


class TableStructureModel(BasePageModel):
    """Page model that predicts table structure for table-like layout clusters.

    For every layout cluster labelled ``TABLE`` or ``DOCUMENT_INDEX`` the
    model runs ``TFPredictor.multi_table_predict`` on an up-scaled page image
    and stores the resulting ``Table`` in
    ``page.predictions.tablestructure.table_map`` keyed by the cluster id.
    """

    _model_repo_folder = "ds4sd--docling-models"
    _model_path = "model_artifacts/tableformer"

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        options: TableStructureOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """Configure the model and, when enabled, load the predictor.

        Args:
            enabled: When False no artifacts are resolved or loaded and
                ``__call__`` passes pages through untouched.
            artifacts_path: Directory holding the model artifacts. When None
                the artifacts are fetched through ``download_models``.
            options: Supplies ``do_cell_matching`` and ``mode``.
            accelerator_options: Supplies the requested device and the
                number of threads handed to ``TFPredictor``.
        """
        self.options = options
        self.do_cell_matching = self.options.do_cell_matching
        self.mode = self.options.mode

        self.enabled = enabled
        if self.enabled:
            if artifacts_path is None:
                artifacts_path = self.download_models() / self._model_path
            else:
                # will become the default in the future
                if (artifacts_path / self._model_repo_folder).exists():
                    artifacts_path = (
                        artifacts_path / self._model_repo_folder / self._model_path
                    )
                elif (artifacts_path / self._model_path).exists():
                    warnings.warn(
                        "The usage of artifacts_path containing directly "
                        f"{self._model_path} is deprecated. Please point "
                        "the artifacts_path to the parent containing "
                        f"the {self._model_repo_folder} folder.",
                        DeprecationWarning,
                        stacklevel=3,
                    )
                    artifacts_path = artifacts_path / self._model_path

            # Each mode has its own sub-folder below the tableformer artifacts.
            if self.mode == TableFormerMode.ACCURATE:
                artifacts_path = artifacts_path / "accurate"
            else:
                artifacts_path = artifacts_path / "fast"

            # Third Party
            import docling_ibm_models.tableformer.common as c

            device = decide_device(accelerator_options.device)

            # Disable MPS here, until we know why it makes things slower.
            if device == AcceleratorDevice.MPS.value:
                device = AcceleratorDevice.CPU.value

            self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
            self.tm_config["model"]["save_dir"] = artifacts_path
            self.tm_model_type = self.tm_config["model"]["type"]

            self.tf_predictor = TFPredictor(
                self.tm_config, device, accelerator_options.num_threads
            )
            self.scale = 2.0  # Scale up table input images to 144 dpi

    @staticmethod
    def download_models(
        local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
    ) -> Path:
        """Download the ``ds4sd/docling-models`` snapshot (revision ``v2.2.0``).

        Args:
            local_dir: Forwarded to ``snapshot_download`` as ``local_dir``.
            force: Forwarded to ``snapshot_download`` as ``force_download``.
            progress: When False, huggingface_hub progress bars are disabled
                before downloading.

        Returns:
            The path returned by ``snapshot_download``, wrapped in ``Path``.
        """
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        if not progress:
            disable_progress_bars()
        download_path = snapshot_download(
            repo_id="ds4sd/docling-models",
            force_download=force,
            local_dir=local_dir,
            revision="v2.2.0",
        )

        return Path(download_path)

    @staticmethod
    def _scaled_corners(bbox: BoundingBox, scale_x: float, scale_y: float):
        """Return ``bbox`` as two corner points scaled per axis.

        x coordinates are multiplied by ``scale_x`` and y coordinates by
        ``scale_y``; the result is ``((x0, y0), (x1, y1))`` in the order
        produced by ``bbox.as_tuple()``.
        """
        x0, y0, x1, y1 = bbox.as_tuple()
        return (x0 * scale_x, y0 * scale_y), (x1 * scale_x, y1 * scale_y)

    def draw_table_and_cells(
        self,
        conv_res: ConversionResult,
        page: Page,
        tbl_list: Iterable[Table],
        show: bool = False,
    ):
        """Draw table, cluster-cell and predicted-cell boxes on the page image.

        Table cluster boxes are outlined in red, the cluster's cells in green
        and predicted table cells in blue (thicker outline for column
        headers), each annotated with its start row/column offsets.

        Args:
            conv_res: Conversion result; its input file stem names the debug
                output folder.
            page: Page whose backend provides the image to draw on.
            tbl_list: Tables to visualise.
            show: When True the image is displayed; otherwise it is saved as
                a PNG below ``settings.debug.debug_output_path``.
        """
        assert page._backend is not None
        assert page.size is not None

        image = (
            page._backend.get_page_image()
        )  # make new image to avoid drawing on the saved ones

        # Ratio between image pixels and page units, per axis.
        scale_x = image.width / page.size.width
        scale_y = image.height / page.size.height

        draw = ImageDraw.Draw(image)

        for table_element in tbl_list:
            top_left, bottom_right = self._scaled_corners(
                table_element.cluster.bbox, scale_x, scale_y
            )
            draw.rectangle([top_left, bottom_right], outline="red")

            for cell in table_element.cluster.cells:
                top_left, bottom_right = self._scaled_corners(
                    cell.bbox, scale_x, scale_y
                )
                draw.rectangle([top_left, bottom_right], outline="green")

            for tc in table_element.table_cells:
                if tc.bbox is None:
                    continue

                top_left, bottom_right = self._scaled_corners(
                    tc.bbox, scale_x, scale_y
                )
                width = 3 if tc.column_header else 1
                draw.rectangle(
                    [top_left, bottom_right], outline="blue", width=width
                )
                draw.text(
                    (top_left[0] + 3, top_left[1] + 3),
                    text=f"{tc.start_row_offset_idx}, {tc.start_col_offset_idx}",
                    fill="black",
                )

        if show:
            image.show()
        else:
            out_path: Path = (
                Path(settings.debug.debug_output_path)
                / f"debug_{conv_res.input.file.stem}"
            )
            out_path.mkdir(parents=True, exist_ok=True)

            out_file = out_path / f"table_struct_page_{page.page_no:05}.png"
            image.save(str(out_file), format="png")

    def _predict_table(self, page: Page, page_input: dict, table_cluster, tbl_box):
        """Run the predictor for one table cluster and build its ``Table``.

        ``page_input["tokens"]` is overwritten in place with the scaled,
        non-blank cells of ``table_cluster`` before the predictor is called.

        Args:
            page: Page the cluster belongs to.
            page_input: Predictor input holding the scaled page width, height
                and image.
            table_cluster: Layout cluster describing the table region.
            tbl_box: Cluster bounding box as ``[l, t, r, b]``, already
                multiplied by ``self.scale``.

        Returns:
            The ``Table`` assembled from the predictor response.
        """
        assert page._backend is not None

        tokens = []
        for cell in table_cluster.cells:
            # Only allow non empty strings (spaces) into the cells of a table
            if cell.text.strip():
                scaled_cell = copy.deepcopy(cell)
                scaled_cell.bbox = scaled_cell.bbox.scaled(scale=self.scale)
                tokens.append(scaled_cell.model_dump())
        page_input["tokens"] = tokens

        tf_output = self.tf_predictor.multi_table_predict(
            page_input, [tbl_box], do_matching=self.do_cell_matching
        )
        table_out = tf_output[0]

        table_cells = []
        for element in table_out["tf_responses"]:
            if not self.do_cell_matching:
                # Without matching, fetch the text from the backend using the
                # predicted box mapped back to page coordinates.
                the_bbox = BoundingBox.model_validate(element["bbox"]).scaled(
                    1 / self.scale
                )
                text_piece = page._backend.get_text_in_rect(the_bbox)
                element["bbox"]["token"] = text_piece

            tc = TableCell.model_validate(element)
            if self.do_cell_matching and tc.bbox is not None:
                tc.bbox = tc.bbox.scaled(1 / self.scale)
            table_cells.append(tc)

        assert "predict_details" in table_out
        predict_details = table_out["predict_details"]

        # Retrieving cols/rows, after post processing:
        num_rows = predict_details.get("num_rows", 0)
        num_cols = predict_details.get("num_cols", 0)
        otsl_seq = predict_details.get("prediction", {}).get("rs_seq", [])

        return Table(
            otsl_seq=otsl_seq,
            table_cells=table_cells,
            num_rows=num_rows,
            num_cols=num_cols,
            id=table_cluster.id,
            page_no=page.page_no,
            cluster=table_cluster,
            label=table_cluster.label,
        )

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Predict table structures for each page and yield the pages.

        Pages are passed through unchanged when the model is disabled or the
        page backend is not valid. Otherwise
        ``page.predictions.tablestructure`` is reset and filled with one
        ``Table`` per ``TABLE`` / ``DOCUMENT_INDEX`` layout cluster.

        Args:
            conv_res: Conversion result used for timing and debug output.
            page_batch: Pages to process.

        Yields:
            Every page of ``page_batch``, in order.
        """
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
                continue

            with TimeRecorder(conv_res, "table_structure"):
                assert page.predictions.layout is not None
                assert page.size is not None

                page.predictions.tablestructure = TableStructurePrediction()  # dummy

                in_tables = [
                    (
                        cluster,
                        [
                            round(cluster.bbox.l) * self.scale,
                            round(cluster.bbox.t) * self.scale,
                            round(cluster.bbox.r) * self.scale,
                            round(cluster.bbox.b) * self.scale,
                        ],
                    )
                    for cluster in page.predictions.layout.clusters
                    if cluster.label
                    in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
                ]

                # Pages without tables skip prediction and visualisation.
                if in_tables:
                    page_input = {
                        "width": page.size.width * self.scale,
                        "height": page.size.height * self.scale,
                        "image": numpy.asarray(page.get_image(scale=self.scale)),
                    }

                    for table_cluster, tbl_box in in_tables:
                        page.predictions.tablestructure.table_map[
                            table_cluster.id
                        ] = self._predict_table(
                            page, page_input, table_cluster, tbl_box
                        )

                    # For debugging purposes:
                    if settings.debug.visualize_tables:
                        self.draw_table_and_cells(
                            conv_res,
                            page,
                            page.predictions.tablestructure.table_map.values(),
                        )

            # Yield after the timing block has closed so the generator is
            # never suspended while the recorder is still open.
            yield page