From 1c26769785bcd17c0b8b621c5182ad81134d3915 Mon Sep 17 00:00:00 2001
From: Maxim Lysak <101627549+maxmnemonic@users.noreply.github.com>
Date: Wed, 19 Mar 2025 15:38:54 +0100
Subject: [PATCH] feat(SmolDocling): Support MLX acceleration in VLM pipeline
 (#1199)

* Initial implementation to support MLX for VLM pipeline and SmolDocling
Signed-off-by: Maksym Lysak

* mlx_model unit
Signed-off-by: Maksym Lysak

* Add CLI choices for VLM pipeline and model
Signed-off-by: Christoph Auer

* Initial implementation to support MLX for VLM pipeline and SmolDocling
Signed-off-by: Maksym Lysak

* mlx_model unit
Signed-off-by: Maksym Lysak

* Add CLI choices for VLM pipeline and model
Signed-off-by: Christoph Auer

* Updated minimal vlm pipeline example
Signed-off-by: Maksym Lysak

* make vlm_pipeline python3.9 compatible
Signed-off-by: Maksym Lysak

* Fixed extract_text_from_backend definition
Signed-off-by: Maksym Lysak

* Updated README
Signed-off-by: Maksym Lysak

* Updated example
Signed-off-by: Maksym Lysak

* Updated documentation
Signed-off-by: Maksym Lysak

* corrections in the documentation
Signed-off-by: Maksym Lysak

* Cosmetic changes
Signed-off-by: Christoph Auer

---------

Signed-off-by: Maksym Lysak
Signed-off-by: Christoph Auer
Co-authored-by: Maksym Lysak
Co-authored-by: Christoph Auer
---
 README.md                             |  20 +++-
 docling/cli/main.py                   | 119 +++++++++++++++-------
 docling/datamodel/pipeline_options.py |  30 +++++-
 docling/models/hf_mlx_model.py        | 137 ++++++++++++++++++++++++++
 docling/pipeline/vlm_pipeline.py      |  46 ++++++---
 docs/examples/minimal_vlm_pipeline.py |  23 +++--
 docs/index.md                         |   2 +-
 docs/usage/index.md                   |   7 +-
 pyproject.toml                        |   1 +
 9 files changed, 319 insertions(+), 66 deletions(-)
 create mode 100644 docling/models/hf_mlx_model.py

diff --git a/README.md b/README.md
index 208de0d..19048b7 100644
--- a/README.md
+++ b/README.md
@@ -35,7 +35,7 @@ Docling simplifies document processing, parsing diverse formats — including ad
 * 🔒 Local execution capabilities for sensitive data and air-gapped environments
 * 🤖 Plug-and-play [integrations][integrations] incl. LangChain, LlamaIndex, Crew AI & Haystack for agentic AI
 * 🔍 Extensive OCR support for scanned PDFs and images
-* 🥚 Support of Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview))
+* 🥚 Support of Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview)) 🆕
 * 💻 Simple and convenient CLI

 ### Coming soon
@@ -57,7 +57,7 @@ More [detailed installation instructions](https://docling-project.github.io/docl

 ## Getting started

-To convert individual documents, use `convert()`, for example:
+To convert individual documents with Python, use `convert()`, for example:

 ```python
 from docling.document_converter import DocumentConverter
@@ -71,6 +71,22 @@ print(result.document.export_to_markdown())  # output: "## Docling Technical Rep

 More [advanced usage options](https://docling-project.github.io/docling/usage/) are available in the docs.

+## CLI
+
+Docling has a built-in CLI to run conversions.
+
+```bash
+docling https://arxiv.org/pdf/2206.01062
+```
+
+You can also use 🥚[SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview) and other VLMs via the Docling CLI:
+```bash
+docling --pipeline vlm --vlm-model smoldocling https://arxiv.org/pdf/2206.01062
+```
+This will use MLX acceleration on supported Apple Silicon hardware.
+
+Read more [here](https://docling-project.github.io/docling/usage/).
+
 ## Documentation

 Check out Docling's [documentation](https://docling-project.github.io/docling/), for details on
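For readers who prefer the Python API over the new CLI flags, the following sketch shows the programmatic equivalent, built only from the pieces this patch introduces (it mirrors `docs/examples/minimal_vlm_pipeline.py` further down; the source URL is just an example):

```python
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    VlmPipelineOptions,
    smoldocling_vlm_conversion_options,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Programmatic equivalent of `docling --pipeline vlm --vlm-model smoldocling <source>`
pipeline_options = VlmPipelineOptions()
pipeline_options.vlm_options = smoldocling_vlm_conversion_options

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
        )
    }
)

result = converter.convert("https://arxiv.org/pdf/2206.01062")
print(result.document.export_to_markdown())
```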
diff --git a/docling/cli/main.py b/docling/cli/main.py
index 7f0f20b..c85a04f 100644
--- a/docling/cli/main.py
+++ b/docling/cli/main.py
@@ -32,13 +32,21 @@ from docling.datamodel.pipeline_options import (
     AcceleratorOptions,
     EasyOcrOptions,
     OcrOptions,
+    PaginatedPipelineOptions,
     PdfBackend,
+    PdfPipeline,
     PdfPipelineOptions,
     TableFormerMode,
+    VlmModelType,
+    VlmPipelineOptions,
+    granite_vision_vlm_conversion_options,
+    smoldocling_vlm_conversion_options,
+    smoldocling_vlm_mlx_conversion_options,
 )
 from docling.datamodel.settings import settings
 from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption
 from docling.models.factories import get_ocr_factory
+from docling.pipeline.vlm_pipeline import VlmPipeline

 warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch")
 warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr")
@@ -200,6 +208,14 @@ def convert(
             help="Image export mode for the document (only in case of JSON, Markdown or HTML). With `placeholder`, only the position of the image is marked in the output. In `embedded` mode, the image is embedded as base64 encoded string. In `referenced` mode, the image is exported in PNG format and referenced from the main exported document.",
         ),
     ] = ImageRefMode.EMBEDDED,
+    pipeline: Annotated[
+        PdfPipeline,
+        typer.Option(..., help="Choose the pipeline to process PDF or image files."),
+    ] = PdfPipeline.STANDARD,
+    vlm_model: Annotated[
+        VlmModelType,
+        typer.Option(..., help="Choose the VLM model to use with PDF or image files."),
+    ] = VlmModelType.SMOLDOCLING,
     ocr: Annotated[
         bool,
         typer.Option(
@@ -420,50 +436,77 @@ def convert(
         ocr_options.lang = ocr_lang_list

     accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
-    pipeline_options = PdfPipelineOptions(
-        allow_external_plugins=allow_external_plugins,
-        enable_remote_services=enable_remote_services,
-        accelerator_options=accelerator_options,
-        do_ocr=ocr,
-        ocr_options=ocr_options,
-        do_table_structure=True,
-        do_code_enrichment=enrich_code,
-        do_formula_enrichment=enrich_formula,
-        do_picture_description=enrich_picture_description,
-        do_picture_classification=enrich_picture_classes,
-        document_timeout=document_timeout,
-    )
-    pipeline_options.table_structure_options.do_cell_matching = (
-        True  # do_cell_matching
-    )
-    pipeline_options.table_structure_options.mode = table_mode
+    pipeline_options: PaginatedPipelineOptions

-    if image_export_mode != ImageRefMode.PLACEHOLDER:
-        pipeline_options.generate_page_images = True
-        pipeline_options.generate_picture_images = (
-            True  # FIXME: to be deprecated in verson 3
+    if pipeline == PdfPipeline.STANDARD:
+        pipeline_options = PdfPipelineOptions(
+            allow_external_plugins=allow_external_plugins,
+            enable_remote_services=enable_remote_services,
+            accelerator_options=accelerator_options,
+            do_ocr=ocr,
+            ocr_options=ocr_options,
+            do_table_structure=True,
+            do_code_enrichment=enrich_code,
+            do_formula_enrichment=enrich_formula,
+            do_picture_description=enrich_picture_description,
+            do_picture_classification=enrich_picture_classes,
+            document_timeout=document_timeout,
+        )
+        pipeline_options.table_structure_options.do_cell_matching = (
+            True  # do_cell_matching
+        )
+        pipeline_options.table_structure_options.mode = table_mode
+
+        if image_export_mode != ImageRefMode.PLACEHOLDER:
+            pipeline_options.generate_page_images = True
+            pipeline_options.generate_picture_images = (
+                True  # FIXME: to be deprecated in verson 3
+            )
+            pipeline_options.images_scale = 2
+
+        backend: Type[PdfDocumentBackend]
+        if pdf_backend == PdfBackend.DLPARSE_V1:
+            backend = DoclingParseDocumentBackend
+        elif pdf_backend == PdfBackend.DLPARSE_V2:
+            backend = DoclingParseV2DocumentBackend
+        elif pdf_backend == PdfBackend.DLPARSE_V4:
+            backend = DoclingParseV4DocumentBackend  # type: ignore
+        elif pdf_backend == PdfBackend.PYPDFIUM2:
+            backend = PyPdfiumDocumentBackend  # type: ignore
+        else:
+            raise RuntimeError(f"Unexpected PDF backend type {pdf_backend}")
+
+        pdf_format_option = PdfFormatOption(
+            pipeline_options=pipeline_options,
+            backend=backend,  # pdf_backend
+        )
+    elif pipeline == PdfPipeline.VLM:
+        pipeline_options = VlmPipelineOptions()
+
+        if vlm_model == VlmModelType.GRANITE_VISION:
+            pipeline_options.vlm_options = granite_vision_vlm_conversion_options
+        elif vlm_model == VlmModelType.SMOLDOCLING:
+            pipeline_options.vlm_options = smoldocling_vlm_conversion_options
+            if sys.platform == "darwin":
+                try:
+                    import mlx_vlm
+
+                    pipeline_options.vlm_options = (
+                        smoldocling_vlm_mlx_conversion_options
+                    )
+                except ImportError:
+                    _log.warning(
+                        "To run SmolDocling faster, please install mlx-vlm:\n"
+                        "pip install mlx-vlm"
+                    )
+
+        pdf_format_option = PdfFormatOption(
+            pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
         )
-        pipeline_options.images_scale = 2

     if artifacts_path is not None:
         pipeline_options.artifacts_path = artifacts_path

-    backend: Type[PdfDocumentBackend]
-    if pdf_backend == PdfBackend.DLPARSE_V1:
-        backend = DoclingParseDocumentBackend
-    elif pdf_backend == PdfBackend.DLPARSE_V2:
-        backend = DoclingParseV2DocumentBackend
-    elif pdf_backend == PdfBackend.DLPARSE_V4:
-        backend = DoclingParseV4DocumentBackend  # type: ignore
-    elif pdf_backend == PdfBackend.PYPDFIUM2:
-        backend = PyPdfiumDocumentBackend  # type: ignore
-    else:
-        raise RuntimeError(f"Unexpected PDF backend type {pdf_backend}")
-
-    pdf_format_option = PdfFormatOption(
-        pipeline_options=pipeline_options,
-        backend=backend,  # pdf_backend
-    )
     format_options: Dict[InputFormat, FormatOption] = {
         InputFormat.PDF: pdf_format_option,
         InputFormat.IMAGE: pdf_format_option,
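The CLI above only switches to the MLX variant of SmolDocling when it runs on macOS and `mlx-vlm` is importable, and falls back to the Transformers weights otherwise. A downstream script that wires up the VLM pipeline itself can apply the same check; this is a sketch of that selection logic, not a new API:

```python
import importlib.util
import sys

from docling.datamodel.pipeline_options import (
    smoldocling_vlm_conversion_options,
    smoldocling_vlm_mlx_conversion_options,
)

# Default to the Transformers-based weights...
vlm_options = smoldocling_vlm_conversion_options
# ...and opt into the MLX-converted weights on Apple Silicon when mlx-vlm is installed.
if sys.platform == "darwin" and importlib.util.find_spec("mlx_vlm") is not None:
    vlm_options = smoldocling_vlm_mlx_conversion_options
```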
diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index d28b582..654e04d 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -263,6 +263,11 @@ class ResponseFormat(str, Enum):
     MARKDOWN = "markdown"


+class InferenceFramework(str, Enum):
+    MLX = "mlx"
+    TRANSFORMERS = "transformers"
+
+
 class HuggingFaceVlmOptions(BaseVlmOptions):
     kind: Literal["hf_model_options"] = "hf_model_options"

@@ -271,6 +276,7 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
     llm_int8_threshold: float = 6.0
     quantized: bool = False

+    inference_framework: InferenceFramework
     response_format: ResponseFormat

     @property
@@ -278,10 +284,19 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
         return self.repo_id.replace("/", "--")


+smoldocling_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
+    repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
+    prompt="Convert this page to docling.",
+    response_format=ResponseFormat.DOCTAGS,
+    inference_framework=InferenceFramework.MLX,
+)
+
+
 smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
     repo_id="ds4sd/SmolDocling-256M-preview",
     prompt="Convert this page to docling.",
     response_format=ResponseFormat.DOCTAGS,
+    inference_framework=InferenceFramework.TRANSFORMERS,
 )

 granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
@@ -289,9 +304,15 @@ granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
     # prompt="OCR the full page to markdown.",
     prompt="OCR this image.",
     response_format=ResponseFormat.MARKDOWN,
+    inference_framework=InferenceFramework.TRANSFORMERS,
 )


+class VlmModelType(str, Enum):
+    SMOLDOCLING = "smoldocling"
+    GRANITE_VISION = "granite_vision"
+
+
 # Define an enum for the backend options
 class PdfBackend(str, Enum):
     """Enum of valid PDF backends."""
@@ -327,13 +348,14 @@ class PipelineOptions(BaseModel):


 class PaginatedPipelineOptions(PipelineOptions):
+    artifacts_path: Optional[Union[Path, str]] = None
+
     images_scale: float = 1.0
     generate_page_images: bool = False
     generate_picture_images: bool = False


 class VlmPipelineOptions(PaginatedPipelineOptions):
-    artifacts_path: Optional[Union[Path, str]] = None

     generate_page_images: bool = True
     force_backend_text: bool = (
@@ -346,7 +368,6 @@ class VlmPipelineOptions(PaginatedPipelineOptions):
 class PdfPipelineOptions(PaginatedPipelineOptions):
     """Options for the PDF pipeline."""

-    artifacts_path: Optional[Union[Path, str]] = None
     do_table_structure: bool = True  # True: perform table structure extraction
     do_ocr: bool = True  # True: perform OCR, replace programmatic PDF text
     do_code_enrichment: bool = False  # True: perform code OCR
@@ -377,3 +398,8 @@ class PdfPipelineOptions(PaginatedPipelineOptions):
     )

     generate_parsed_pages: bool = False
+
+
+class PdfPipeline(str, Enum):
+    STANDARD = "standard"
+    VLM = "vlm"
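The new `inference_framework` field is what routes a `HuggingFaceVlmOptions` instance to the MLX runner instead of the Transformers one. In principle this also allows pointing the pipeline at other MLX-converted checkpoints; the repo id below is a placeholder for illustration, not a model shipped or tested with this patch:

```python
from docling.datamodel.pipeline_options import (
    HuggingFaceVlmOptions,
    InferenceFramework,
    ResponseFormat,
)

# Hypothetical example: any MLX-converted VLM repo could be described this way.
my_mlx_vlm_options = HuggingFaceVlmOptions(
    repo_id="your-org/your-mlx-converted-vlm",  # placeholder repo id
    prompt="Convert this page to docling.",
    response_format=ResponseFormat.DOCTAGS,
    inference_framework=InferenceFramework.MLX,
)
```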
diff --git a/docling/models/hf_mlx_model.py b/docling/models/hf_mlx_model.py
new file mode 100644
index 0000000..762a655
--- /dev/null
+++ b/docling/models/hf_mlx_model.py
@@ -0,0 +1,137 @@
+import logging
+import time
+from pathlib import Path
+from typing import Iterable, List, Optional
+
+from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import (
+    AcceleratorDevice,
+    AcceleratorOptions,
+    HuggingFaceVlmOptions,
+)
+from docling.datamodel.settings import settings
+from docling.models.base_model import BasePageModel
+from docling.utils.accelerator_utils import decide_device
+from docling.utils.profiling import TimeRecorder
+
+_log = logging.getLogger(__name__)
+
+
+class HuggingFaceMlxModel(BasePageModel):
+
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        vlm_options: HuggingFaceVlmOptions,
+    ):
+        self.enabled = enabled
+
+        self.vlm_options = vlm_options
+
+        if self.enabled:
+
+            try:
+                from mlx_vlm import generate, load  # type: ignore
+                from mlx_vlm.prompt_utils import apply_chat_template  # type: ignore
+                from mlx_vlm.utils import load_config, stream_generate  # type: ignore
+            except ImportError:
+                raise ImportError(
+                    "mlx-vlm is not installed. Please install it via `pip install mlx-vlm` to use MLX VLM models."
+                )
+
+            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+            self.apply_chat_template = apply_chat_template
+            self.stream_generate = stream_generate
+
+            # PARAMETERS:
+            if artifacts_path is None:
+                artifacts_path = self.download_models(self.vlm_options.repo_id)
+            elif (artifacts_path / repo_cache_folder).exists():
+                artifacts_path = artifacts_path / repo_cache_folder
+
+            self.param_question = vlm_options.prompt  # "Perform Layout Analysis."
+
+            ## Load the model
+            self.vlm_model, self.processor = load(artifacts_path)
+            self.config = load_config(artifacts_path)
+
+    @staticmethod
+    def download_models(
+        repo_id: str,
+        local_dir: Optional[Path] = None,
+        force: bool = False,
+        progress: bool = False,
+    ) -> Path:
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import disable_progress_bars
+
+        if not progress:
+            disable_progress_bars()
+        download_path = snapshot_download(
+            repo_id=repo_id,
+            force_download=force,
+            local_dir=local_dir,
+            # revision="v0.0.1",
+        )
+
+        return Path(download_path)
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        for page in page_batch:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                yield page
+            else:
+                with TimeRecorder(conv_res, "vlm"):
+                    assert page.size is not None
+
+                    hi_res_image = page.get_image(scale=2.0)  # 144dpi
+                    # hi_res_image = page.get_image(scale=1.0)  # 72dpi
+
+                    if hi_res_image is not None:
+                        im_width, im_height = hi_res_image.size
+
+                    # populate page_tags with predicted doc tags
+                    page_tags = ""
+
+                    if hi_res_image:
+                        if hi_res_image.mode != "RGB":
+                            hi_res_image = hi_res_image.convert("RGB")
+
+                    prompt = self.apply_chat_template(
+                        self.processor, self.config, self.param_question, num_images=1
+                    )
+
+                    start_time = time.time()
+                    # Call model to generate:
+                    output = ""
+                    for token in self.stream_generate(
+                        self.vlm_model,
+                        self.processor,
+                        prompt,
+                        [hi_res_image],
+                        max_tokens=4096,
+                        verbose=False,
+                    ):
+                        output += token.text
+                        if "</doctag>" in token.text:
+                            break
+
+                    generation_time = time.time() - start_time
+                    page_tags = output
+
+                    # inference_time = time.time() - start_time
+                    # tokens_per_second = num_tokens / generation_time
+                    # print("")
+                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
+                    # print(f"Total tokens on page: {num_tokens:.2f}")
+                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
+                    # print("")
+
+                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
+
+                    yield page
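`HuggingFaceMlxModel.download_models()` fetches the MLX weights from the Hugging Face Hub on first use. For offline or air-gapped runs they can be pre-fetched into a local artifacts folder; the sketch below assumes the pipeline resolves `artifacts_path / <repo_id with "/" replaced by "--">` exactly as the constructor above does:

```python
from pathlib import Path

from docling.datamodel.pipeline_options import (
    VlmPipelineOptions,
    smoldocling_vlm_mlx_conversion_options,
)
from docling.models.hf_mlx_model import HuggingFaceMlxModel

artifacts = Path("./models")  # local cache directory (example path)
repo_id = smoldocling_vlm_mlx_conversion_options.repo_id

# Download once, into the folder layout the model constructor expects.
HuggingFaceMlxModel.download_models(
    repo_id=repo_id,
    local_dir=artifacts / repo_id.replace("/", "--"),
)

pipeline_options = VlmPipelineOptions(artifacts_path=artifacts)
pipeline_options.vlm_options = smoldocling_vlm_mlx_conversion_options
```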
diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py
index 4afb918..d4defa8 100644
--- a/docling/pipeline/vlm_pipeline.py
+++ b/docling/pipeline/vlm_pipeline.py
@@ -14,8 +14,13 @@ from docling.backend.md_backend import MarkdownDocumentBackend
 from docling.backend.pdf_backend import PdfDocumentBackend
 from docling.datamodel.base_models import InputFormat, Page
 from docling.datamodel.document import ConversionResult, InputDocument
-from docling.datamodel.pipeline_options import ResponseFormat, VlmPipelineOptions
+from docling.datamodel.pipeline_options import (
+    InferenceFramework,
+    ResponseFormat,
+    VlmPipelineOptions,
+)
 from docling.datamodel.settings import settings
+from docling.models.hf_mlx_model import HuggingFaceMlxModel
 from docling.models.hf_vlm_model import HuggingFaceVlmModel
 from docling.pipeline.base_pipeline import PaginatedPipeline
 from docling.utils.profiling import ProfilingScope, TimeRecorder
@@ -29,12 +34,6 @@ class VlmPipeline(PaginatedPipeline):
         super().__init__(pipeline_options)
         self.keep_backend = True

-        warnings.warn(
-            "The VlmPipeline is currently experimental and may change in upcoming versions without notice.",
-            category=UserWarning,
-            stacklevel=2,
-        )
-
         self.pipeline_options: VlmPipelineOptions

         artifacts_path: Optional[Path] = None
@@ -58,14 +57,27 @@ class VlmPipeline(PaginatedPipeline):

         self.keep_images = self.pipeline_options.generate_page_images

-        self.build_pipe = [
-            HuggingFaceVlmModel(
-                enabled=True,  # must be always enabled for this pipeline to make sense.
-                artifacts_path=artifacts_path,
-                accelerator_options=pipeline_options.accelerator_options,
-                vlm_options=self.pipeline_options.vlm_options,
-            ),
-        ]
+        if (
+            self.pipeline_options.vlm_options.inference_framework
+            == InferenceFramework.MLX
+        ):
+            self.build_pipe = [
+                HuggingFaceMlxModel(
+                    enabled=True,  # must be always enabled for this pipeline to make sense.
+                    artifacts_path=artifacts_path,
+                    accelerator_options=pipeline_options.accelerator_options,
+                    vlm_options=self.pipeline_options.vlm_options,
+                ),
+            ]
+        else:
+            self.build_pipe = [
+                HuggingFaceVlmModel(
+                    enabled=True,  # must be always enabled for this pipeline to make sense.
+                    artifacts_path=artifacts_path,
+                    accelerator_options=pipeline_options.accelerator_options,
+                    vlm_options=self.pipeline_options.vlm_options,
+                ),
+            ]

         self.enrichment_pipe = [
             # Other models working on `NodeItem` elements in the DoclingDocument
@@ -79,7 +91,9 @@ class VlmPipeline(PaginatedPipeline):

         return page

-    def extract_text_from_backend(self, page: Page, bbox: BoundingBox | None) -> str:
+    def extract_text_from_backend(
+        self, page: Page, bbox: Union[BoundingBox, None]
+    ) -> str:
         # Convert bounding box normalized to 0-100 into page coordinates for cropping
         text = ""
         if bbox:
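Independent of which model class the pipeline selects, each page keeps the raw model output (DocTags for SmolDocling, Markdown for Granite Vision) in `page.predictions.vlm_response`, which the pipeline then assembles into a `DoclingDocument`. Assuming the page objects are retained on the conversion result, the raw output can be inspected for debugging, reusing the converter from the earlier sketch:

```python
result = converter.convert("https://arxiv.org/pdf/2206.01062")

for page in result.pages:
    # vlm_response is set by HuggingFaceMlxModel / HuggingFaceVlmModel above.
    if page.predictions.vlm_response is not None:
        print(f"--- page {page.page_no} ---")
        print(page.predictions.vlm_response.text[:500])
```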
diff --git a/docs/examples/minimal_vlm_pipeline.py b/docs/examples/minimal_vlm_pipeline.py
index 948ecc6..6a15fe4 100644
--- a/docs/examples/minimal_vlm_pipeline.py
+++ b/docs/examples/minimal_vlm_pipeline.py
@@ -10,13 +10,15 @@ from docling.datamodel.pipeline_options import (
     VlmPipelineOptions,
     granite_vision_vlm_conversion_options,
     smoldocling_vlm_conversion_options,
+    smoldocling_vlm_mlx_conversion_options,
 )
 from docling.datamodel.settings import settings
 from docling.document_converter import DocumentConverter, PdfFormatOption
 from docling.pipeline.vlm_pipeline import VlmPipeline

 sources = [
-    "tests/data/2305.03393v1-pg9-img.png",
+    # "tests/data/2305.03393v1-pg9-img.png",
+    "tests/data/pdf/2305.03393v1-pg9.pdf",
 ]

 ## Use experimental VlmPipeline
@@ -29,7 +31,10 @@ pipeline_options.force_backend_text = False
 # pipeline_options.accelerator_options.cuda_use_flash_attention2 = True

 ## Pick a VLM model. We choose SmolDocling-256M by default
-pipeline_options.vlm_options = smoldocling_vlm_conversion_options
+# pipeline_options.vlm_options = smoldocling_vlm_conversion_options
+
+## Pick a VLM model. Fast Apple Silicon friendly implementation for SmolDocling-256M via MLX
+pipeline_options.vlm_options = smoldocling_vlm_mlx_conversion_options

 ## Alternative VLM models:
 # pipeline_options.vlm_options = granite_vision_vlm_conversion_options
@@ -63,9 +68,6 @@ for source in sources:

     res = converter.convert(source)

-    print("------------------------------------------------")
-    print("MD:")
-    print("------------------------------------------------")
     print("")
     print(res.document.export_to_markdown())

@@ -83,8 +85,17 @@ for source in sources:
     with (out_path / f"{res.input.file.stem}.json").open("w") as fp:
         fp.write(json.dumps(res.document.export_to_dict()))

-    pg_num = res.document.num_pages()
+    res.document.save_as_json(
+        out_path / f"{res.input.file.stem}.json",
+        image_mode=ImageRefMode.PLACEHOLDER,
+    )
+    res.document.save_as_markdown(
+        out_path / f"{res.input.file.stem}.md",
+        image_mode=ImageRefMode.PLACEHOLDER,
+    )
+
+    pg_num = res.document.num_pages()
     print("")
     inference_time = time.time() - start_time
     print(
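A small variation on the example above switches to the Granite Vision model, which returns Markdown instead of DocTags, and keeps the text extracted by the PDF backend where it is available. Both switches are already exposed by `VlmPipelineOptions`; this is only a configuration sketch:

```python
from docling.datamodel.pipeline_options import (
    VlmPipelineOptions,
    granite_vision_vlm_conversion_options,
)

pipeline_options = VlmPipelineOptions()
pipeline_options.vlm_options = granite_vision_vlm_conversion_options
# Prefer text extracted by the PDF backend over model-generated text.
pipeline_options.force_backend_text = True
```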
diff --git a/docs/index.md b/docs/index.md
index 789dae8..acc9933 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -26,7 +26,7 @@ Docling simplifies document processing, parsing diverse formats — including ad
 * 🔒 Local execution capabilities for sensitive data and air-gapped environments
 * 🤖 Plug-and-play [integrations][integrations] incl. LangChain, LlamaIndex, Crew AI & Haystack for agentic AI
 * 🔍 Extensive OCR support for scanned PDFs and images
-* 🥚 Support of Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview))
+* 🥚 Support of Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview)) 🆕🔥
 * 💻 Simple and convenient CLI

 ### Coming soon
diff --git a/docs/usage/index.md b/docs/usage/index.md
index 1ab7842..acf3397 100644
--- a/docs/usage/index.md
+++ b/docs/usage/index.md
@@ -17,10 +17,15 @@ print(result.document.export_to_markdown())  # output: "### Docling Technical Re

 You can also use Docling directly from your command line to convert individual files —be it local or by URL— or whole directories.

-A simple example would look like this:
 ```console
 docling https://arxiv.org/pdf/2206.01062
 ```

+You can also use 🥚[SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview) and other VLMs via the Docling CLI:
+```bash
+docling --pipeline vlm --vlm-model smoldocling https://arxiv.org/pdf/2206.01062
+```
+This will use MLX acceleration on supported Apple Silicon hardware.
+
 To see all available options (export formats etc.) run `docling --help`. More details in the [CLI reference page](../reference/cli.md).
diff --git a/pyproject.toml b/pyproject.toml
index 0f85915..8d121d3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -192,6 +192,7 @@ module = [
     "docling_ibm_models.*",
     "easyocr.*",
     "ocrmac.*",
+    "mlx_vlm.*",
     "lxml.*",
     "huggingface_hub.*",
     "transformers.*",
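The example above saves its outputs with `ImageRefMode.PLACEHOLDER`. For a self-contained Markdown file with the page images inlined as base64, the embedded mode can be used instead; this sketch reuses `res` and `out_path` from the example, and the file name is illustrative:

```python
from docling_core.types.doc import ImageRefMode

res.document.save_as_markdown(
    out_path / f"{res.input.file.stem}_embedded.md",
    image_mode=ImageRefMode.EMBEDDED,
)
```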