feat: Introduce support for GPU Accelerators (#593)
* Upgraded Layout Postprocessing, sending old code back to ERZ Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Implement hierachical cluster layout processing Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Pass nested cluster processing through full pipeline Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Pass nested clusters through GLM as payload Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Move to_docling_document from ds-glm to this repo Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Clean up imports again Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * feat(Accelerator): Introduce options to control the num_threads and device from API, envvars, CLI. - Introduce the AcceleratorOptions, AcceleratorDevice and use them to set the device where the models run. - Introduce the accelerator_utils with function to decide the device and resolve the AUTO setting. - Refactor the way how the docling-ibm-models are called to match the new init signature of models. - Translate the accelerator options to the specific inputs for third-party models. - Extend the docling CLI with parameters to set the num_threads and device. - Add new unit tests. - Write new example how to use the accelerator options. * fix: Improve the pydantic objects in the pipeline_options and imports. Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com> * fix: TableStructureModel: Refactor the artifacts path to use the new structure for fast/accurate model Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com> * Updated test ground-truth Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Updated test ground-truth (again), bugfix for empty layout Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * fix: Do proper check to set the device in EasyOCR, RapidOCR. Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com> * Rollback changes from main Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Update test gt Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Remove unused debug settings Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Review fixes Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Nail the accelerator defaults for MPS Signed-off-by: Christoph Auer <cau@zurich.ibm.com> --------- Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com> Co-authored-by: Christoph Auer <cau@zurich.ibm.com> Co-authored-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from docling_core.types.doc import CoordOrigin, DocItemLabel
|
||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||
from PIL import ImageDraw
|
||||
|
||||
import docling.utils.layout_utils as lu
|
||||
from docling.datamodel.base_models import (
|
||||
BoundingBox,
|
||||
Cell,
|
||||
@@ -17,9 +18,10 @@ from docling.datamodel.base_models import (
|
||||
Page,
|
||||
)
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.utils import layout_utils as lu
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
@@ -46,8 +48,16 @@ class LayoutModel(BasePageModel):
|
||||
FIGURE_LABEL = DocItemLabel.PICTURE
|
||||
FORMULA_LABEL = DocItemLabel.FORMULA
|
||||
|
||||
def __init__(self, artifacts_path: Path):
|
||||
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
|
||||
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
||||
device = decide_device(accelerator_options.device)
|
||||
|
||||
self.layout_predictor = LayoutPredictor(
|
||||
artifact_path=str(artifacts_path),
|
||||
device=device,
|
||||
num_threads=accelerator_options.num_threads,
|
||||
base_threshold=0.6,
|
||||
blacklist_classes={"Form", "Key-Value Region"},
|
||||
)
|
||||
|
||||
def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height):
|
||||
MIN_INTERSECTION = 0.2
|
||||
|
||||
Reference in New Issue
Block a user