feat: new artifacts path and CLI utility (#876)
* fix artifacts path Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * add docling-models utility Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * missing formatting Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * rename utility to docling-tools Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * rename download methods and deprecation warnings Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * propagate artifacts path usage for ocr models Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * move function to utils Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * remove unused file Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * update docs Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> * simplify downloading specific model(s) Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> * minor refactor Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> --------- Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> Co-authored-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -17,6 +18,7 @@ from docling.datamodel.pipeline_options import (
|
||||
TesseractCliOcrOptions,
|
||||
TesseractOcrOptions,
|
||||
)
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
|
||||
from docling.models.document_picture_classifier import (
|
||||
@@ -37,23 +39,23 @@ from docling.models.table_structure_model import TableStructureModel
|
||||
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
|
||||
from docling.models.tesseract_ocr_model import TesseractOcrModel
|
||||
from docling.pipeline.base_pipeline import PaginatedPipeline
|
||||
from docling.utils.model_downloader import download_models
|
||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StandardPdfPipeline(PaginatedPipeline):
|
||||
_layout_model_path = "model_artifacts/layout"
|
||||
_table_model_path = "model_artifacts/tableformer"
|
||||
_layout_model_path = LayoutModel._model_path
|
||||
_table_model_path = TableStructureModel._model_path
|
||||
|
||||
def __init__(self, pipeline_options: PdfPipelineOptions):
|
||||
super().__init__(pipeline_options)
|
||||
self.pipeline_options: PdfPipelineOptions
|
||||
|
||||
if pipeline_options.artifacts_path is None:
|
||||
self.artifacts_path = self.download_models_hf()
|
||||
else:
|
||||
self.artifacts_path = Path(pipeline_options.artifacts_path)
|
||||
artifacts_path: Optional[Path] = None
|
||||
if pipeline_options.artifacts_path is not None:
|
||||
artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
|
||||
|
||||
self.keep_images = (
|
||||
self.pipeline_options.generate_page_images
|
||||
@@ -63,7 +65,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
|
||||
self.glm_model = GlmModel(options=GlmOptions())
|
||||
|
||||
if (ocr_model := self.get_ocr_model()) is None:
|
||||
if (ocr_model := self.get_ocr_model(artifacts_path=artifacts_path)) is None:
|
||||
raise RuntimeError(
|
||||
f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}."
|
||||
)
|
||||
@@ -79,15 +81,13 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
ocr_model,
|
||||
# Layout model
|
||||
LayoutModel(
|
||||
artifacts_path=self.artifacts_path
|
||||
/ StandardPdfPipeline._layout_model_path,
|
||||
artifacts_path=artifacts_path,
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
),
|
||||
# Table structure model
|
||||
TableStructureModel(
|
||||
enabled=pipeline_options.do_table_structure,
|
||||
artifacts_path=self.artifacts_path
|
||||
/ StandardPdfPipeline._table_model_path,
|
||||
artifacts_path=artifacts_path,
|
||||
options=pipeline_options.table_structure_options,
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
),
|
||||
@@ -101,7 +101,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
CodeFormulaModel(
|
||||
enabled=pipeline_options.do_code_enrichment
|
||||
or pipeline_options.do_formula_enrichment,
|
||||
artifacts_path=pipeline_options.artifacts_path,
|
||||
artifacts_path=artifacts_path,
|
||||
options=CodeFormulaModelOptions(
|
||||
do_code_enrichment=pipeline_options.do_code_enrichment,
|
||||
do_formula_enrichment=pipeline_options.do_formula_enrichment,
|
||||
@@ -111,7 +111,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
# Document Picture Classifier
|
||||
DocumentPictureClassifier(
|
||||
enabled=pipeline_options.do_picture_classification,
|
||||
artifacts_path=pipeline_options.artifacts_path,
|
||||
artifacts_path=artifacts_path,
|
||||
options=DocumentPictureClassifierOptions(),
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
),
|
||||
@@ -127,23 +127,24 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
def download_models_hf(
|
||||
local_dir: Optional[Path] = None, force: bool = False
|
||||
) -> Path:
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import disable_progress_bars
|
||||
|
||||
disable_progress_bars()
|
||||
download_path = snapshot_download(
|
||||
repo_id="ds4sd/docling-models",
|
||||
force_download=force,
|
||||
local_dir=local_dir,
|
||||
revision="v2.1.0",
|
||||
warnings.warn(
|
||||
"The usage of StandardPdfPipeline.download_models_hf() is deprecated "
|
||||
"use instead the utility `docling-tools models download`, or "
|
||||
"the upstream method docling.utils.models_downloader.download_all()",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
|
||||
return Path(download_path)
|
||||
output_dir = download_models(output_dir=local_dir, force=force, progress=False)
|
||||
return output_dir
|
||||
|
||||
def get_ocr_model(self) -> Optional[BaseOcrModel]:
|
||||
def get_ocr_model(
|
||||
self, artifacts_path: Optional[Path] = None
|
||||
) -> Optional[BaseOcrModel]:
|
||||
if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions):
|
||||
return EasyOcrModel(
|
||||
enabled=self.pipeline_options.do_ocr,
|
||||
artifacts_path=artifacts_path,
|
||||
options=self.pipeline_options.ocr_options,
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user