diff --git a/docling/cli/main.py b/docling/cli/main.py index 6ba0d61..7f0f20b 100644 --- a/docling/cli/main.py +++ b/docling/cli/main.py @@ -9,6 +9,7 @@ import warnings from pathlib import Path from typing import Annotated, Dict, Iterable, List, Optional, Type +import rich.table import typer from docling_core.types.doc import ImageRefMode from docling_core.utils.file import resolve_source_to_path @@ -30,18 +31,14 @@ from docling.datamodel.pipeline_options import ( AcceleratorDevice, AcceleratorOptions, EasyOcrOptions, - OcrEngine, - OcrMacOptions, OcrOptions, PdfBackend, PdfPipelineOptions, - RapidOcrOptions, TableFormerMode, - TesseractCliOcrOptions, - TesseractOcrOptions, ) from docling.datamodel.settings import settings from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption +from docling.models.factories import get_ocr_factory warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch") warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr") @@ -49,8 +46,11 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr _log = logging.getLogger(__name__) from rich.console import Console +console = Console() err_console = Console(stderr=True) +ocr_factory_internal = get_ocr_factory(allow_external_plugins=False) +ocr_engines_enum_internal = ocr_factory_internal.get_enum() app = typer.Typer( name="Docling", @@ -78,6 +78,24 @@ def version_callback(value: bool): raise typer.Exit() +def show_external_plugins_callback(value: bool): + if value: + ocr_factory_all = get_ocr_factory(allow_external_plugins=True) + table = rich.table.Table(title="Available OCR engines") + table.add_column("Name", justify="right") + table.add_column("Plugin") + table.add_column("Package") + for meta in ocr_factory_all.registered_meta.values(): + if not meta.module.startswith("docling."): + table.add_row( + f"[bold]{meta.kind}[/bold]", + meta.plugin_name, + meta.module.split(".")[0], + ) + rich.print(table) + raise typer.Exit() + + def export_documents( conv_results: Iterable[ConversionResult], output_dir: Path, @@ -196,8 +214,16 @@ def convert( ), ] = False, ocr_engine: Annotated[ - OcrEngine, typer.Option(..., help="The OCR engine to use.") - ] = OcrEngine.EASYOCR, + str, + typer.Option( + ..., + help=( + f"The OCR engine to use. When --allow-external-plugins is *not* set, the available values are: " + f"{', '.join((o.value for o in ocr_engines_enum_internal))}. " + f"Use the option --show-external-plugins to see the options allowed with external plugins." + ), + ), + ] = EasyOcrOptions.kind, ocr_lang: Annotated[ Optional[str], typer.Option( @@ -241,6 +267,21 @@ def convert( ..., help="Must be enabled when using models connecting to remote services." ), ] = False, + allow_external_plugins: Annotated[ + bool, + typer.Option( + ..., help="Must be enabled for loading modules from third-party plugins." + ), + ] = False, + show_external_plugins: Annotated[ + bool, + typer.Option( + ..., + help="List the third-party plugins which are available when the option --allow-external-plugins is set.", + callback=show_external_plugins_callback, + is_eager=True, + ), + ] = False, abort_on_error: Annotated[ bool, typer.Option( @@ -368,18 +409,11 @@ def convert( export_txt = OutputFormat.TEXT in to_formats export_doctags = OutputFormat.DOCTAGS in to_formats - if ocr_engine == OcrEngine.EASYOCR: - ocr_options: OcrOptions = EasyOcrOptions(force_full_page_ocr=force_ocr) - elif ocr_engine == OcrEngine.TESSERACT_CLI: - ocr_options = TesseractCliOcrOptions(force_full_page_ocr=force_ocr) - elif ocr_engine == OcrEngine.TESSERACT: - ocr_options = TesseractOcrOptions(force_full_page_ocr=force_ocr) - elif ocr_engine == OcrEngine.OCRMAC: - ocr_options = OcrMacOptions(force_full_page_ocr=force_ocr) - elif ocr_engine == OcrEngine.RAPIDOCR: - ocr_options = RapidOcrOptions(force_full_page_ocr=force_ocr) - else: - raise RuntimeError(f"Unexpected OCR engine type {ocr_engine}") + ocr_factory = get_ocr_factory(allow_external_plugins=allow_external_plugins) + ocr_options: OcrOptions = ocr_factory.create_options( # type: ignore + kind=ocr_engine, + force_full_page_ocr=force_ocr, + ) ocr_lang_list = _split_list(ocr_lang) if ocr_lang_list is not None: @@ -387,6 +421,7 @@ def convert( accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device) pipeline_options = PdfPipelineOptions( + allow_external_plugins=allow_external_plugins, enable_remote_services=enable_remote_services, accelerator_options=accelerator_options, do_ocr=ocr, diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index ee9985a..d28b582 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -1,10 +1,9 @@ import logging import os import re -import warnings from enum import Enum from pathlib import Path -from typing import Annotated, Any, Dict, List, Literal, Optional, Union +from typing import Any, ClassVar, Dict, List, Literal, Optional, Union from pydantic import ( AnyUrl, @@ -13,13 +12,8 @@ from pydantic import ( Field, field_validator, model_validator, - validator, -) -from pydantic_settings import ( - BaseSettings, - PydanticBaseSettingsSource, - SettingsConfigDict, ) +from pydantic_settings import BaseSettings, SettingsConfigDict from typing_extensions import deprecated _log = logging.getLogger(__name__) @@ -83,6 +77,12 @@ class AcceleratorOptions(BaseSettings): return data +class BaseOptions(BaseModel): + """Base class for options.""" + + kind: ClassVar[str] + + class TableFormerMode(str, Enum): """Modes for the TableFormer model.""" @@ -102,10 +102,9 @@ class TableStructureOptions(BaseModel): mode: TableFormerMode = TableFormerMode.ACCURATE -class OcrOptions(BaseModel): +class OcrOptions(BaseOptions): """OCR options.""" - kind: str lang: List[str] force_full_page_ocr: bool = False # If enabled a full page OCR is always applied bitmap_area_threshold: float = ( @@ -116,7 +115,7 @@ class OcrOptions(BaseModel): class RapidOcrOptions(OcrOptions): """Options for the RapidOCR engine.""" - kind: Literal["rapidocr"] = "rapidocr" + kind: ClassVar[Literal["rapidocr"]] = "rapidocr" # English and chinese are the most commly used models and have been tested with RapidOCR. lang: List[str] = [ @@ -155,7 +154,7 @@ class RapidOcrOptions(OcrOptions): class EasyOcrOptions(OcrOptions): """Options for the EasyOCR engine.""" - kind: Literal["easyocr"] = "easyocr" + kind: ClassVar[Literal["easyocr"]] = "easyocr" lang: List[str] = ["fr", "de", "es", "en"] use_gpu: Optional[bool] = None @@ -175,7 +174,7 @@ class EasyOcrOptions(OcrOptions): class TesseractCliOcrOptions(OcrOptions): """Options for the TesseractCli engine.""" - kind: Literal["tesseract"] = "tesseract" + kind: ClassVar[Literal["tesseract"]] = "tesseract" lang: List[str] = ["fra", "deu", "spa", "eng"] tesseract_cmd: str = "tesseract" path: Optional[str] = None @@ -188,7 +187,7 @@ class TesseractCliOcrOptions(OcrOptions): class TesseractOcrOptions(OcrOptions): """Options for the Tesseract engine.""" - kind: Literal["tesserocr"] = "tesserocr" + kind: ClassVar[Literal["tesserocr"]] = "tesserocr" lang: List[str] = ["fra", "deu", "spa", "eng"] path: Optional[str] = None @@ -200,7 +199,7 @@ class TesseractOcrOptions(OcrOptions): class OcrMacOptions(OcrOptions): """Options for the Mac OCR engine.""" - kind: Literal["ocrmac"] = "ocrmac" + kind: ClassVar[Literal["ocrmac"]] = "ocrmac" lang: List[str] = ["fr-FR", "de-DE", "es-ES", "en-US"] recognition: str = "accurate" framework: str = "vision" @@ -210,8 +209,7 @@ class OcrMacOptions(OcrOptions): ) -class PictureDescriptionBaseOptions(BaseModel): - kind: str +class PictureDescriptionBaseOptions(BaseOptions): batch_size: int = 8 scale: float = 2 @@ -221,7 +219,7 @@ class PictureDescriptionBaseOptions(BaseModel): class PictureDescriptionApiOptions(PictureDescriptionBaseOptions): - kind: Literal["api"] = "api" + kind: ClassVar[Literal["api"]] = "api" url: AnyUrl = AnyUrl("http://localhost:8000/v1/chat/completions") headers: Dict[str, str] = {} @@ -233,7 +231,7 @@ class PictureDescriptionApiOptions(PictureDescriptionBaseOptions): class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions): - kind: Literal["vlm"] = "vlm" + kind: ClassVar[Literal["vlm"]] = "vlm" repo_id: str prompt: str = "Describe this image in a few sentences." @@ -305,6 +303,7 @@ class PdfBackend(str, Enum): # Define an enum for the ocr engines +@deprecated("Use ocr_factory.registered_enum") class OcrEngine(str, Enum): """Enum of valid OCR engines.""" @@ -324,6 +323,7 @@ class PipelineOptions(BaseModel): document_timeout: Optional[float] = None accelerator_options: AcceleratorOptions = AcceleratorOptions() enable_remote_services: bool = False + allow_external_plugins: bool = False class PaginatedPipelineOptions(PipelineOptions): @@ -359,17 +359,10 @@ class PdfPipelineOptions(PaginatedPipelineOptions): # If True, text from backend will be used instead of generated text table_structure_options: TableStructureOptions = TableStructureOptions() - ocr_options: Union[ - EasyOcrOptions, - TesseractCliOcrOptions, - TesseractOcrOptions, - OcrMacOptions, - RapidOcrOptions, - ] = Field(EasyOcrOptions(), discriminator="kind") - picture_description_options: Annotated[ - Union[PictureDescriptionApiOptions, PictureDescriptionVlmOptions], - Field(discriminator="kind"), - ] = smolvlm_picture_description + ocr_options: OcrOptions = EasyOcrOptions() + picture_description_options: PictureDescriptionBaseOptions = ( + smolvlm_picture_description + ) images_scale: float = 1.0 generate_page_images: bool = False diff --git a/docling/models/base_model.py b/docling/models/base_model.py index 9cdc0ec..712d329 100644 --- a/docling/models/base_model.py +++ b/docling/models/base_model.py @@ -1,14 +1,22 @@ from abc import ABC, abstractmethod -from typing import Any, Generic, Iterable, Optional +from typing import Any, Generic, Iterable, Optional, Protocol, Type from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem from typing_extensions import TypeVar from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page from docling.datamodel.document import ConversionResult +from docling.datamodel.pipeline_options import BaseOptions from docling.datamodel.settings import settings +class BaseModelWithOptions(Protocol): + @classmethod + def get_options_type(cls) -> Type[BaseOptions]: ... + + def __init__(self, *, options: BaseOptions, **kwargs): ... + + class BasePageModel(ABC): @abstractmethod def __call__( diff --git a/docling/models/base_ocr_model.py b/docling/models/base_ocr_model.py index 1c82264..c823580 100644 --- a/docling/models/base_ocr_model.py +++ b/docling/models/base_ocr_model.py @@ -2,7 +2,7 @@ import copy import logging from abc import abstractmethod from pathlib import Path -from typing import Iterable, List +from typing import Iterable, List, Optional, Type import numpy as np from docling_core.types.doc import BoundingBox, CoordOrigin @@ -13,15 +13,22 @@ from scipy.ndimage import binary_dilation, find_objects, label from docling.datamodel.base_models import Page from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import OcrOptions +from docling.datamodel.pipeline_options import AcceleratorOptions, OcrOptions from docling.datamodel.settings import settings -from docling.models.base_model import BasePageModel +from docling.models.base_model import BaseModelWithOptions, BasePageModel _log = logging.getLogger(__name__) -class BaseOcrModel(BasePageModel): - def __init__(self, enabled: bool, options: OcrOptions): +class BaseOcrModel(BasePageModel, BaseModelWithOptions): + def __init__( + self, + *, + enabled: bool, + artifacts_path: Optional[Path], + options: OcrOptions, + accelerator_options: AcceleratorOptions, + ): self.enabled = enabled self.options = options @@ -186,3 +193,8 @@ class BaseOcrModel(BasePageModel): self, conv_res: ConversionResult, page_batch: Iterable[Page] ) -> Iterable[Page]: pass + + @classmethod + @abstractmethod + def get_options_type(cls) -> Type[OcrOptions]: + pass diff --git a/docling/models/easyocr_model.py b/docling/models/easyocr_model.py index 232b6cc..13eb33c 100644 --- a/docling/models/easyocr_model.py +++ b/docling/models/easyocr_model.py @@ -2,7 +2,7 @@ import logging import warnings import zipfile from pathlib import Path -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Type import numpy from docling_core.types.doc import BoundingBox, CoordOrigin @@ -14,6 +14,7 @@ from docling.datamodel.pipeline_options import ( AcceleratorDevice, AcceleratorOptions, EasyOcrOptions, + OcrOptions, ) from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel @@ -34,7 +35,12 @@ class EasyOcrModel(BaseOcrModel): options: EasyOcrOptions, accelerator_options: AcceleratorOptions, ): - super().__init__(enabled=enabled, options=options) + super().__init__( + enabled=enabled, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: EasyOcrOptions self.scale = 3 # multiplier for 72 dpi == 216 dpi. @@ -180,3 +186,7 @@ class EasyOcrModel(BaseOcrModel): self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects) yield page + + @classmethod + def get_options_type(cls) -> Type[OcrOptions]: + return EasyOcrOptions diff --git a/docling/models/factories/__init__.py b/docling/models/factories/__init__.py new file mode 100644 index 0000000..9a3308e --- /dev/null +++ b/docling/models/factories/__init__.py @@ -0,0 +1,27 @@ +import logging +from functools import lru_cache + +from docling.models.factories.ocr_factory import OcrFactory +from docling.models.factories.picture_description_factory import ( + PictureDescriptionFactory, +) + +logger = logging.getLogger(__name__) + + +@lru_cache() +def get_ocr_factory(allow_external_plugins: bool = False) -> OcrFactory: + factory = OcrFactory() + factory.load_from_plugins(allow_external_plugins=allow_external_plugins) + logger.info("Registered ocr engines: %r", factory.registered_kind) + return factory + + +@lru_cache() +def get_picture_description_factory( + allow_external_plugins: bool = False, +) -> PictureDescriptionFactory: + factory = PictureDescriptionFactory() + factory.load_from_plugins(allow_external_plugins=allow_external_plugins) + logger.info("Registered picture descriptions: %r", factory.registered_kind) + return factory diff --git a/docling/models/factories/base_factory.py b/docling/models/factories/base_factory.py new file mode 100644 index 0000000..542fc7e --- /dev/null +++ b/docling/models/factories/base_factory.py @@ -0,0 +1,122 @@ +import enum +import logging +from abc import ABCMeta +from typing import Generic, Optional, Type, TypeVar + +from pluggy import PluginManager +from pydantic import BaseModel + +from docling.datamodel.pipeline_options import BaseOptions +from docling.models.base_model import BaseModelWithOptions + +A = TypeVar("A", bound=BaseModelWithOptions) + + +logger = logging.getLogger(__name__) + + +class FactoryMeta(BaseModel): + kind: str + plugin_name: str + module: str + + +class BaseFactory(Generic[A], metaclass=ABCMeta): + default_plugin_name = "docling" + + def __init__(self, plugin_attr_name: str, plugin_name=default_plugin_name): + self.plugin_name = plugin_name + self.plugin_attr_name = plugin_attr_name + + self._classes: dict[Type[BaseOptions], Type[A]] = {} + self._meta: dict[Type[BaseOptions], FactoryMeta] = {} + + @property + def registered_kind(self) -> list[str]: + return list(opt.kind for opt in self._classes.keys()) + + def get_enum(self) -> enum.Enum: + return enum.Enum( + self.plugin_attr_name + "_enum", + names={kind: kind for kind in self.registered_kind}, + type=str, + module=__name__, + ) + + @property + def classes(self): + return self._classes + + @property + def registered_meta(self): + return self._meta + + def create_instance(self, options: BaseOptions, **kwargs) -> A: + try: + _cls = self._classes[type(options)] + return _cls(options=options, **kwargs) + except KeyError: + raise RuntimeError(self._err_msg_on_class_not_found(options.kind)) + + def create_options(self, kind: str, *args, **kwargs) -> BaseOptions: + for opt_cls, _ in self._classes.items(): + if opt_cls.kind == kind: + return opt_cls(*args, **kwargs) + raise RuntimeError(self._err_msg_on_class_not_found(kind)) + + def _err_msg_on_class_not_found(self, kind: str): + msg = [] + + for opt, cls in self._classes.items(): + msg.append(f"\t{opt.kind!r} => {cls!r}") + + msg_str = "\n".join(msg) + + return f"No class found with the name {kind!r}, known classes are:\n{msg_str}" + + def register(self, cls: Type[A], plugin_name: str, plugin_module_name: str): + opt_type = cls.get_options_type() + + if opt_type in self._classes: + raise ValueError( + f"{opt_type.kind!r} already registered to class {self._classes[opt_type]!r}" + ) + + self._classes[opt_type] = cls + self._meta[opt_type] = FactoryMeta( + kind=opt_type.kind, plugin_name=plugin_name, module=plugin_module_name + ) + + def load_from_plugins( + self, plugin_name: Optional[str] = None, allow_external_plugins: bool = False + ): + plugin_name = plugin_name or self.plugin_name + + plugin_manager = PluginManager(plugin_name) + plugin_manager.load_setuptools_entrypoints(plugin_name) + + for plugin_name, plugin_module in plugin_manager.list_name_plugin(): + plugin_module_name = str(plugin_module.__name__) # type: ignore + + if not allow_external_plugins and not plugin_module_name.startswith( + "docling." + ): + logger.warning( + f"The plugin {plugin_name} will not be loaded because Docling is being executed with allow_external_plugins=false." + ) + continue + + attr = getattr(plugin_module, self.plugin_attr_name, None) + + if callable(attr): + logger.info("Loading plugin %r", plugin_name) + + config = attr() + self.process_plugin(config, plugin_name, plugin_module_name) + + def process_plugin(self, config, plugin_name: str, plugin_module_name: str): + for item in config[self.plugin_attr_name]: + try: + self.register(item, plugin_name, plugin_module_name) + except ValueError: + logger.warning("%r already registered", item) diff --git a/docling/models/factories/ocr_factory.py b/docling/models/factories/ocr_factory.py new file mode 100644 index 0000000..34fc7c4 --- /dev/null +++ b/docling/models/factories/ocr_factory.py @@ -0,0 +1,11 @@ +import logging + +from docling.models.base_ocr_model import BaseOcrModel +from docling.models.factories.base_factory import BaseFactory + +logger = logging.getLogger(__name__) + + +class OcrFactory(BaseFactory[BaseOcrModel]): + def __init__(self, *args, **kwargs): + super().__init__("ocr_engines", *args, **kwargs) diff --git a/docling/models/factories/picture_description_factory.py b/docling/models/factories/picture_description_factory.py new file mode 100644 index 0000000..f66d132 --- /dev/null +++ b/docling/models/factories/picture_description_factory.py @@ -0,0 +1,11 @@ +import logging + +from docling.models.factories.base_factory import BaseFactory +from docling.models.picture_description_base_model import PictureDescriptionBaseModel + +logger = logging.getLogger(__name__) + + +class PictureDescriptionFactory(BaseFactory[PictureDescriptionBaseModel]): + def __init__(self, *args, **kwargs): + super().__init__("picture_description", *args, **kwargs) diff --git a/docling/models/ocr_mac_model.py b/docling/models/ocr_mac_model.py index 9d61828..98ca3f1 100644 --- a/docling/models/ocr_mac_model.py +++ b/docling/models/ocr_mac_model.py @@ -1,13 +1,19 @@ import logging +import sys import tempfile -from typing import Iterable, Optional, Tuple +from pathlib import Path +from typing import Iterable, Optional, Tuple, Type from docling_core.types.doc import BoundingBox, CoordOrigin from docling_core.types.doc.page import BoundingRectangle, TextCell from docling.datamodel.base_models import Page from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import OcrMacOptions +from docling.datamodel.pipeline_options import ( + AcceleratorOptions, + OcrMacOptions, + OcrOptions, +) from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel from docling.utils.profiling import TimeRecorder @@ -16,13 +22,26 @@ _log = logging.getLogger(__name__) class OcrMacModel(BaseOcrModel): - def __init__(self, enabled: bool, options: OcrMacOptions): - super().__init__(enabled=enabled, options=options) + def __init__( + self, + enabled: bool, + artifacts_path: Optional[Path], + options: OcrMacOptions, + accelerator_options: AcceleratorOptions, + ): + super().__init__( + enabled=enabled, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: OcrMacOptions self.scale = 3 # multiplier for 72 dpi == 216 dpi. if self.enabled: + if "darwin" != sys.platform: + raise RuntimeError(f"OcrMac is only supported on Mac.") install_errmsg = ( "ocrmac is not correctly installed. " "Please install it via `pip install ocrmac` to use this OCR engine. " @@ -121,3 +140,7 @@ class OcrMacModel(BaseOcrModel): self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects) yield page + + @classmethod + def get_options_type(cls) -> Type[OcrOptions]: + return OcrMacOptions diff --git a/docling/models/picture_description_api_model.py b/docling/models/picture_description_api_model.py index c64f1bf..6ef8a7f 100644 --- a/docling/models/picture_description_api_model.py +++ b/docling/models/picture_description_api_model.py @@ -1,13 +1,18 @@ import base64 import io import logging -from typing import Iterable, List, Optional +from pathlib import Path +from typing import Iterable, List, Optional, Type, Union import requests from PIL import Image from pydantic import BaseModel, ConfigDict -from docling.datamodel.pipeline_options import PictureDescriptionApiOptions +from docling.datamodel.pipeline_options import ( + AcceleratorOptions, + PictureDescriptionApiOptions, + PictureDescriptionBaseOptions, +) from docling.exceptions import OperationNotAllowed from docling.models.picture_description_base_model import PictureDescriptionBaseModel @@ -46,13 +51,25 @@ class ApiResponse(BaseModel): class PictureDescriptionApiModel(PictureDescriptionBaseModel): # elements_batch_size = 4 + @classmethod + def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]: + return PictureDescriptionApiOptions + def __init__( self, enabled: bool, enable_remote_services: bool, + artifacts_path: Optional[Union[Path, str]], options: PictureDescriptionApiOptions, + accelerator_options: AcceleratorOptions, ): - super().__init__(enabled=enabled, options=options) + super().__init__( + enabled=enabled, + enable_remote_services=enable_remote_services, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: PictureDescriptionApiOptions if self.enabled: diff --git a/docling/models/picture_description_base_model.py b/docling/models/picture_description_base_model.py index b653e0e..129387b 100644 --- a/docling/models/picture_description_base_model.py +++ b/docling/models/picture_description_base_model.py @@ -1,6 +1,7 @@ import logging +from abc import abstractmethod from pathlib import Path -from typing import Any, Iterable, List, Optional, Union +from typing import Any, Iterable, List, Optional, Type, Union from docling_core.types.doc import ( DoclingDocument, @@ -13,20 +14,30 @@ from docling_core.types.doc.document import ( # TODO: move import to docling_co ) from PIL import Image -from docling.datamodel.pipeline_options import PictureDescriptionBaseOptions +from docling.datamodel.pipeline_options import ( + AcceleratorOptions, + PictureDescriptionBaseOptions, +) from docling.models.base_model import ( BaseItemAndImageEnrichmentModel, + BaseModelWithOptions, ItemAndImageEnrichmentElement, ) -class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel): +class PictureDescriptionBaseModel( + BaseItemAndImageEnrichmentModel, BaseModelWithOptions +): images_scale: float = 2.0 def __init__( self, + *, enabled: bool, + enable_remote_services: bool, + artifacts_path: Optional[Union[Path, str]], options: PictureDescriptionBaseOptions, + accelerator_options: AcceleratorOptions, ): self.enabled = enabled self.options = options @@ -62,3 +73,8 @@ class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel): PictureDescriptionData(text=output, provenance=self.provenance) ) yield item + + @classmethod + @abstractmethod + def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]: + pass diff --git a/docling/models/picture_description_vlm_model.py b/docling/models/picture_description_vlm_model.py index 69d185b..fc5c51e 100644 --- a/docling/models/picture_description_vlm_model.py +++ b/docling/models/picture_description_vlm_model.py @@ -1,10 +1,11 @@ from pathlib import Path -from typing import Iterable, Optional, Union +from typing import Iterable, Optional, Type, Union from PIL import Image from docling.datamodel.pipeline_options import ( AcceleratorOptions, + PictureDescriptionBaseOptions, PictureDescriptionVlmOptions, ) from docling.models.picture_description_base_model import PictureDescriptionBaseModel @@ -13,14 +14,25 @@ from docling.utils.accelerator_utils import decide_device class PictureDescriptionVlmModel(PictureDescriptionBaseModel): + @classmethod + def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]: + return PictureDescriptionVlmOptions + def __init__( self, enabled: bool, + enable_remote_services: bool, artifacts_path: Optional[Union[Path, str]], options: PictureDescriptionVlmOptions, accelerator_options: AcceleratorOptions, ): - super().__init__(enabled=enabled, options=options) + super().__init__( + enabled=enabled, + enable_remote_services=enable_remote_services, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: PictureDescriptionVlmOptions if self.enabled: diff --git a/docling/models/plugins/__init__.py b/docling/models/plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docling/models/plugins/defaults.py b/docling/models/plugins/defaults.py new file mode 100644 index 0000000..0087357 --- /dev/null +++ b/docling/models/plugins/defaults.py @@ -0,0 +1,28 @@ +from docling.models.easyocr_model import EasyOcrModel +from docling.models.ocr_mac_model import OcrMacModel +from docling.models.picture_description_api_model import PictureDescriptionApiModel +from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel +from docling.models.rapid_ocr_model import RapidOcrModel +from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel +from docling.models.tesseract_ocr_model import TesseractOcrModel + + +def ocr_engines(): + return { + "ocr_engines": [ + EasyOcrModel, + OcrMacModel, + RapidOcrModel, + TesseractOcrModel, + TesseractOcrCliModel, + ] + } + + +def picture_description(): + return { + "picture_description": [ + PictureDescriptionVlmModel, + PictureDescriptionApiModel, + ] + } diff --git a/docling/models/rapid_ocr_model.py b/docling/models/rapid_ocr_model.py index d1e23b3..e21974d 100644 --- a/docling/models/rapid_ocr_model.py +++ b/docling/models/rapid_ocr_model.py @@ -1,5 +1,6 @@ import logging -from typing import Iterable +from pathlib import Path +from typing import Iterable, Optional, Type import numpy from docling_core.types.doc import BoundingBox, CoordOrigin @@ -10,6 +11,7 @@ from docling.datamodel.document import ConversionResult from docling.datamodel.pipeline_options import ( AcceleratorDevice, AcceleratorOptions, + OcrOptions, RapidOcrOptions, ) from docling.datamodel.settings import settings @@ -24,10 +26,16 @@ class RapidOcrModel(BaseOcrModel): def __init__( self, enabled: bool, + artifacts_path: Optional[Path], options: RapidOcrOptions, accelerator_options: AcceleratorOptions, ): - super().__init__(enabled=enabled, options=options) + super().__init__( + enabled=enabled, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: RapidOcrOptions self.scale = 3 # multiplier for 72 dpi == 216 dpi. @@ -135,3 +143,7 @@ class RapidOcrModel(BaseOcrModel): self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects) yield page + + @classmethod + def get_options_type(cls) -> Type[OcrOptions]: + return RapidOcrOptions diff --git a/docling/models/tesseract_ocr_cli_model.py b/docling/models/tesseract_ocr_cli_model.py index 587cecd..56968a2 100644 --- a/docling/models/tesseract_ocr_cli_model.py +++ b/docling/models/tesseract_ocr_cli_model.py @@ -3,8 +3,9 @@ import io import logging import os import tempfile +from pathlib import Path from subprocess import DEVNULL, PIPE, Popen -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Type import pandas as pd from docling_core.types.doc import BoundingBox, CoordOrigin @@ -12,7 +13,11 @@ from docling_core.types.doc.page import BoundingRectangle, TextCell from docling.datamodel.base_models import Page from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import TesseractCliOcrOptions +from docling.datamodel.pipeline_options import ( + AcceleratorOptions, + OcrOptions, + TesseractCliOcrOptions, +) from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel from docling.utils.ocr_utils import map_tesseract_script @@ -22,8 +27,19 @@ _log = logging.getLogger(__name__) class TesseractOcrCliModel(BaseOcrModel): - def __init__(self, enabled: bool, options: TesseractCliOcrOptions): - super().__init__(enabled=enabled, options=options) + def __init__( + self, + enabled: bool, + artifacts_path: Optional[Path], + options: TesseractCliOcrOptions, + accelerator_options: AcceleratorOptions, + ): + super().__init__( + enabled=enabled, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: TesseractCliOcrOptions self.scale = 3 # multiplier for 72 dpi == 216 dpi. @@ -257,3 +273,7 @@ class TesseractOcrCliModel(BaseOcrModel): self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects) yield page + + @classmethod + def get_options_type(cls) -> Type[OcrOptions]: + return TesseractCliOcrOptions diff --git a/docling/models/tesseract_ocr_model.py b/docling/models/tesseract_ocr_model.py index 0000863..84a02a3 100644 --- a/docling/models/tesseract_ocr_model.py +++ b/docling/models/tesseract_ocr_model.py @@ -1,12 +1,17 @@ import logging -from typing import Iterable +from pathlib import Path +from typing import Iterable, Optional, Type from docling_core.types.doc import BoundingBox, CoordOrigin from docling_core.types.doc.page import BoundingRectangle, TextCell from docling.datamodel.base_models import Page from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import TesseractOcrOptions +from docling.datamodel.pipeline_options import ( + AcceleratorOptions, + OcrOptions, + TesseractOcrOptions, +) from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel from docling.utils.ocr_utils import map_tesseract_script @@ -16,8 +21,19 @@ _log = logging.getLogger(__name__) class TesseractOcrModel(BaseOcrModel): - def __init__(self, enabled: bool, options: TesseractOcrOptions): - super().__init__(enabled=enabled, options=options) + def __init__( + self, + enabled: bool, + artifacts_path: Optional[Path], + options: TesseractOcrOptions, + accelerator_options: AcceleratorOptions, + ): + super().__init__( + enabled=enabled, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: TesseractOcrOptions self.scale = 3 # multiplier for 72 dpi == 216 dpi. @@ -200,3 +216,7 @@ class TesseractOcrModel(BaseOcrModel): self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects) yield page + + @classmethod + def get_options_type(cls) -> Type[OcrOptions]: + return TesseractOcrOptions diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index a56b84b..ecaa27c 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -10,16 +10,7 @@ from docling.backend.abstract_backend import AbstractDocumentBackend from docling.backend.pdf_backend import PdfDocumentBackend from docling.datamodel.base_models import AssembledUnit, Page from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import ( - EasyOcrOptions, - OcrMacOptions, - PdfPipelineOptions, - PictureDescriptionApiOptions, - PictureDescriptionVlmOptions, - RapidOcrOptions, - TesseractCliOcrOptions, - TesseractOcrOptions, -) +from docling.datamodel.pipeline_options import PdfPipelineOptions from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions @@ -27,22 +18,16 @@ from docling.models.document_picture_classifier import ( DocumentPictureClassifier, DocumentPictureClassifierOptions, ) -from docling.models.easyocr_model import EasyOcrModel +from docling.models.factories import get_ocr_factory, get_picture_description_factory from docling.models.layout_model import LayoutModel -from docling.models.ocr_mac_model import OcrMacModel from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions from docling.models.page_preprocessing_model import ( PagePreprocessingModel, PagePreprocessingOptions, ) -from docling.models.picture_description_api_model import PictureDescriptionApiModel from docling.models.picture_description_base_model import PictureDescriptionBaseModel -from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel -from docling.models.rapid_ocr_model import RapidOcrModel from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions from docling.models.table_structure_model import TableStructureModel -from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel -from docling.models.tesseract_ocr_model import TesseractOcrModel from docling.pipeline.base_pipeline import PaginatedPipeline from docling.utils.model_downloader import download_models from docling.utils.profiling import ProfilingScope, TimeRecorder @@ -78,10 +63,7 @@ class StandardPdfPipeline(PaginatedPipeline): self.glm_model = ReadingOrderModel(options=ReadingOrderOptions()) - if (ocr_model := self.get_ocr_model(artifacts_path=artifacts_path)) is None: - raise RuntimeError( - f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}." - ) + ocr_model = self.get_ocr_model(artifacts_path=artifacts_path) self.build_pipe = [ # Pre-processing @@ -164,66 +146,30 @@ class StandardPdfPipeline(PaginatedPipeline): output_dir = download_models(output_dir=local_dir, force=force, progress=False) return output_dir - def get_ocr_model( - self, artifacts_path: Optional[Path] = None - ) -> Optional[BaseOcrModel]: - if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions): - return EasyOcrModel( - enabled=self.pipeline_options.do_ocr, - artifacts_path=artifacts_path, - options=self.pipeline_options.ocr_options, - accelerator_options=self.pipeline_options.accelerator_options, - ) - elif isinstance(self.pipeline_options.ocr_options, TesseractCliOcrOptions): - return TesseractOcrCliModel( - enabled=self.pipeline_options.do_ocr, - options=self.pipeline_options.ocr_options, - ) - elif isinstance(self.pipeline_options.ocr_options, TesseractOcrOptions): - return TesseractOcrModel( - enabled=self.pipeline_options.do_ocr, - options=self.pipeline_options.ocr_options, - ) - elif isinstance(self.pipeline_options.ocr_options, RapidOcrOptions): - return RapidOcrModel( - enabled=self.pipeline_options.do_ocr, - options=self.pipeline_options.ocr_options, - accelerator_options=self.pipeline_options.accelerator_options, - ) - elif isinstance(self.pipeline_options.ocr_options, OcrMacOptions): - if "darwin" != sys.platform: - raise RuntimeError( - f"The specified OCR type is only supported on Mac: {self.pipeline_options.ocr_options.kind}." - ) - return OcrMacModel( - enabled=self.pipeline_options.do_ocr, - options=self.pipeline_options.ocr_options, - ) - return None + def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel: + factory = get_ocr_factory( + allow_external_plugins=self.pipeline_options.allow_external_plugins + ) + return factory.create_instance( + options=self.pipeline_options.ocr_options, + enabled=self.pipeline_options.do_ocr, + artifacts_path=artifacts_path, + accelerator_options=self.pipeline_options.accelerator_options, + ) def get_picture_description_model( self, artifacts_path: Optional[Path] = None ) -> Optional[PictureDescriptionBaseModel]: - if isinstance( - self.pipeline_options.picture_description_options, - PictureDescriptionApiOptions, - ): - return PictureDescriptionApiModel( - enabled=self.pipeline_options.do_picture_description, - enable_remote_services=self.pipeline_options.enable_remote_services, - options=self.pipeline_options.picture_description_options, - ) - elif isinstance( - self.pipeline_options.picture_description_options, - PictureDescriptionVlmOptions, - ): - return PictureDescriptionVlmModel( - enabled=self.pipeline_options.do_picture_description, - artifacts_path=artifacts_path, - options=self.pipeline_options.picture_description_options, - accelerator_options=self.pipeline_options.accelerator_options, - ) - return None + factory = get_picture_description_factory( + allow_external_plugins=self.pipeline_options.allow_external_plugins + ) + return factory.create_instance( + options=self.pipeline_options.picture_description_options, + enabled=self.pipeline_options.do_picture_description, + enable_remote_services=self.pipeline_options.enable_remote_services, + artifacts_path=artifacts_path, + accelerator_options=self.pipeline_options.accelerator_options, + ) def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page: with TimeRecorder(conv_res, "page_init"): diff --git a/poetry.lock b/poetry.lock index 8d803b2..f33576b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. [[package]] name = "accelerate" @@ -7838,4 +7838,4 @@ vlm = ["accelerate", "transformers", "transformers"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "6917ebe61625f5b719df46c3f1597c61241b2a3b81bae640d9167d20d0182dd8" +content-hash = "a9ace62bd5b629cb2f20186b750d7c63f383f37f2e3df04cfcc821fc83c877b8" diff --git a/pyproject.toml b/pyproject.toml index bdec5df..11c80b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ accelerate = [ ] pillow = ">=10.0.0,<12.0.0" tqdm = "^4.65.0" +pluggy = "^1.0.0" pylatexenc = "^2.10" [tool.poetry.group.dev.dependencies] @@ -156,6 +157,9 @@ rapidocr = ["rapidocr-onnxruntime", "onnxruntime"] docling = "docling.cli.main:app" docling-tools = "docling.cli.tools:app" +[tool.poetry.plugins."docling"] +"docling_defaults" = "docling.models.plugins.defaults" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api"