feat(SmolDocling): Support MLX acceleration in VLM pipeline (#1199)

* Initial implementation to support MLX for VLM pipeline and SmolDocling

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* mlx_model unit

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* Add CLI choices for VLM pipeline and model

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Updated minimal vlm pipeline example

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* make vlm_pipeline python3.9 compatible

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* Fixed extract_text_from_backend definition

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* Updated README

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* Updated example

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* Updated documentation

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* corrections in the documentation

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* Cosmetic changes

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Co-authored-by: Maksym Lysak <mly@zurich.ibm.com>
Co-authored-by: Christoph Auer <cau@zurich.ibm.com>
Maxim Lysak 2025-03-19 15:38:54 +01:00 committed by GitHub
parent b454aa1551
commit 1c26769785
9 changed files with 319 additions and 66 deletions

@ -35,7 +35,7 @@ Docling simplifies document processing, parsing diverse formats — including ad
* 🔒 Local execution capabilities for sensitive data and air-gapped environments
* 🤖 Plug-and-play [integrations][integrations] incl. LangChain, LlamaIndex, Crew AI & Haystack for agentic AI
* 🔍 Extensive OCR support for scanned PDFs and images
* 🥚 Support of Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview))
* 🥚 Support of Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview)) 🆕
* 💻 Simple and convenient CLI
### Coming soon
@ -57,7 +57,7 @@ More [detailed installation instructions](https://docling-project.github.io/docl
## Getting started
To convert individual documents, use `convert()`, for example:
To convert individual documents with Python, use `convert()`, for example:
```python
from docling.document_converter import DocumentConverter
@ -71,6 +71,22 @@ print(result.document.export_to_markdown()) # output: "## Docling Technical Rep
More [advanced usage options](https://docling-project.github.io/docling/usage/) are available in
the docs.
## CLI
Docling has a built-in CLI to run conversions.
```bash
docling https://arxiv.org/pdf/2206.01062
```
You can also use 🥚[SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview) and other VLMs via Docling CLI:
```bash
docling --pipeline vlm --vlm-model smoldocling https://arxiv.org/pdf/2206.01062
```
This will use MLX acceleration on supported Apple Silicon hardware.
Read more [here](https://docling-project.github.io/docling/usage/)
## Documentation
Check out Docling's [documentation](https://docling-project.github.io/docling/), for details on

@ -32,13 +32,21 @@ from docling.datamodel.pipeline_options import (
AcceleratorOptions,
EasyOcrOptions,
OcrOptions,
PaginatedPipelineOptions,
PdfBackend,
PdfPipeline,
PdfPipelineOptions,
TableFormerMode,
VlmModelType,
VlmPipelineOptions,
granite_vision_vlm_conversion_options,
smoldocling_vlm_conversion_options,
smoldocling_vlm_mlx_conversion_options,
)
from docling.datamodel.settings import settings
from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption
from docling.models.factories import get_ocr_factory
from docling.pipeline.vlm_pipeline import VlmPipeline
warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr")
@ -200,6 +208,14 @@ def convert(
help="Image export mode for the document (only in case of JSON, Markdown or HTML). With `placeholder`, only the position of the image is marked in the output. In `embedded` mode, the image is embedded as base64 encoded string. In `referenced` mode, the image is exported in PNG format and referenced from the main exported document.",
),
] = ImageRefMode.EMBEDDED,
pipeline: Annotated[
PdfPipeline,
typer.Option(..., help="Choose the pipeline to process PDF or image files."),
] = PdfPipeline.STANDARD,
vlm_model: Annotated[
VlmModelType,
typer.Option(..., help="Choose the VLM model to use with PDF or image files."),
] = VlmModelType.SMOLDOCLING,
ocr: Annotated[
bool,
typer.Option(
@ -420,50 +436,77 @@ def convert(
ocr_options.lang = ocr_lang_list
accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
pipeline_options = PdfPipelineOptions(
allow_external_plugins=allow_external_plugins,
enable_remote_services=enable_remote_services,
accelerator_options=accelerator_options,
do_ocr=ocr,
ocr_options=ocr_options,
do_table_structure=True,
do_code_enrichment=enrich_code,
do_formula_enrichment=enrich_formula,
do_picture_description=enrich_picture_description,
do_picture_classification=enrich_picture_classes,
document_timeout=document_timeout,
)
pipeline_options.table_structure_options.do_cell_matching = (
True # do_cell_matching
)
pipeline_options.table_structure_options.mode = table_mode
pipeline_options: PaginatedPipelineOptions
if image_export_mode != ImageRefMode.PLACEHOLDER:
pipeline_options.generate_page_images = True
pipeline_options.generate_picture_images = (
True # FIXME: to be deprecated in version 3
if pipeline == PdfPipeline.STANDARD:
pipeline_options = PdfPipelineOptions(
allow_external_plugins=allow_external_plugins,
enable_remote_services=enable_remote_services,
accelerator_options=accelerator_options,
do_ocr=ocr,
ocr_options=ocr_options,
do_table_structure=True,
do_code_enrichment=enrich_code,
do_formula_enrichment=enrich_formula,
do_picture_description=enrich_picture_description,
do_picture_classification=enrich_picture_classes,
document_timeout=document_timeout,
)
pipeline_options.table_structure_options.do_cell_matching = (
True # do_cell_matching
)
pipeline_options.table_structure_options.mode = table_mode
if image_export_mode != ImageRefMode.PLACEHOLDER:
pipeline_options.generate_page_images = True
pipeline_options.generate_picture_images = (
True # FIXME: to be deprecated in version 3
)
pipeline_options.images_scale = 2
backend: Type[PdfDocumentBackend]
if pdf_backend == PdfBackend.DLPARSE_V1:
backend = DoclingParseDocumentBackend
elif pdf_backend == PdfBackend.DLPARSE_V2:
backend = DoclingParseV2DocumentBackend
elif pdf_backend == PdfBackend.DLPARSE_V4:
backend = DoclingParseV4DocumentBackend # type: ignore
elif pdf_backend == PdfBackend.PYPDFIUM2:
backend = PyPdfiumDocumentBackend # type: ignore
else:
raise RuntimeError(f"Unexpected PDF backend type {pdf_backend}")
pdf_format_option = PdfFormatOption(
pipeline_options=pipeline_options,
backend=backend, # pdf_backend
)
elif pipeline == PdfPipeline.VLM:
pipeline_options = VlmPipelineOptions()
if vlm_model == VlmModelType.GRANITE_VISION:
pipeline_options.vlm_options = granite_vision_vlm_conversion_options
elif vlm_model == VlmModelType.SMOLDOCLING:
pipeline_options.vlm_options = smoldocling_vlm_conversion_options
if sys.platform == "darwin":
try:
import mlx_vlm
pipeline_options.vlm_options = (
smoldocling_vlm_mlx_conversion_options
)
except ImportError:
_log.warning(
"To run SmolDocling faster, please install mlx-vlm:\n"
"pip install mlx-vlm"
)
pdf_format_option = PdfFormatOption(
pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
)
pipeline_options.images_scale = 2
if artifacts_path is not None:
pipeline_options.artifacts_path = artifacts_path
backend: Type[PdfDocumentBackend]
if pdf_backend == PdfBackend.DLPARSE_V1:
backend = DoclingParseDocumentBackend
elif pdf_backend == PdfBackend.DLPARSE_V2:
backend = DoclingParseV2DocumentBackend
elif pdf_backend == PdfBackend.DLPARSE_V4:
backend = DoclingParseV4DocumentBackend # type: ignore
elif pdf_backend == PdfBackend.PYPDFIUM2:
backend = PyPdfiumDocumentBackend # type: ignore
else:
raise RuntimeError(f"Unexpected PDF backend type {pdf_backend}")
pdf_format_option = PdfFormatOption(
pipeline_options=pipeline_options,
backend=backend, # pdf_backend
)
format_options: Dict[InputFormat, FormatOption] = {
InputFormat.PDF: pdf_format_option,
InputFormat.IMAGE: pdf_format_option,
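The backend choice made for SmolDocling in the hunk above can be sketched in isolation. This is a minimal sketch under stated assumptions, not the CLI's code: `choose_smoldocling_backend` and both of its parameters are hypothetical names, and `importlib.util.find_spec` stands in for the `try: import mlx_vlm` probe so that the check does not actually import the package.

```python
import importlib.util
import sys
from typing import Optional


def choose_smoldocling_backend(
    platform: Optional[str] = None, mlx_available: Optional[bool] = None
) -> str:
    """Return "mlx" on macOS when mlx-vlm is importable, else "transformers".

    Both parameters are hypothetical test hooks; when left as None the
    function inspects the running interpreter.
    """
    if platform is None:
        platform = sys.platform
    if platform != "darwin":
        # Same gate as the diff: MLX is only considered on macOS.
        return "transformers"
    if mlx_available is None:
        # find_spec returns None when the top-level module cannot be found.
        mlx_available = importlib.util.find_spec("mlx_vlm") is not None
    return "mlx" if mlx_available else "transformers"


print(choose_smoldocling_backend())
```

As in the diff, a missing `mlx-vlm` install falls back to the Transformers options instead of failing; the CLI additionally logs a hint to run `pip install mlx-vlm`.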

@ -263,6 +263,11 @@ class ResponseFormat(str, Enum):
MARKDOWN = "markdown"
class InferenceFramework(str, Enum):
MLX = "mlx"
TRANSFORMERS = "transformers"
class HuggingFaceVlmOptions(BaseVlmOptions):
kind: Literal["hf_model_options"] = "hf_model_options"
@ -271,6 +276,7 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
llm_int8_threshold: float = 6.0
quantized: bool = False
inference_framework: InferenceFramework
response_format: ResponseFormat
@property
@ -278,10 +284,19 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
return self.repo_id.replace("/", "--")
smoldocling_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.MLX,
)
smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview",
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.TRANSFORMERS,
)
granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
@ -289,9 +304,15 @@ granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
# prompt="OCR the full page to markdown.",
prompt="OCR this image.",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.TRANSFORMERS,
)
class VlmModelType(str, Enum):
SMOLDOCLING = "smoldocling"
GRANITE_VISION = "granite_vision"
# Define an enum for the backend options
class PdfBackend(str, Enum):
"""Enum of valid PDF backends."""
@ -327,13 +348,14 @@ class PipelineOptions(BaseModel):
class PaginatedPipelineOptions(PipelineOptions):
artifacts_path: Optional[Union[Path, str]] = None
images_scale: float = 1.0
generate_page_images: bool = False
generate_picture_images: bool = False
class VlmPipelineOptions(PaginatedPipelineOptions):
artifacts_path: Optional[Union[Path, str]] = None
generate_page_images: bool = True
force_backend_text: bool = (
@ -346,7 +368,6 @@ class VlmPipelineOptions(PaginatedPipelineOptions):
class PdfPipelineOptions(PaginatedPipelineOptions):
"""Options for the PDF pipeline."""
artifacts_path: Optional[Union[Path, str]] = None
do_table_structure: bool = True # True: perform table structure extraction
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
do_code_enrichment: bool = False # True: perform code OCR
@ -377,3 +398,8 @@ class PdfPipelineOptions(PaginatedPipelineOptions):
)
generate_parsed_pages: bool = False
class PdfPipeline(str, Enum):
STANDARD = "standard"
VLM = "vlm"
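The options pattern introduced in this file (a `str`-based enum plus a field that records which framework should run the model) can be tried out without docling installed. The sketch below is illustrative only: `VlmOptionsSketch` is a hypothetical dataclass standing in for the pydantic-based `HuggingFaceVlmOptions`, and only the fields needed for the example are reproduced.

```python
from dataclasses import dataclass
from enum import Enum


class InferenceFramework(str, Enum):
    MLX = "mlx"
    TRANSFORMERS = "transformers"


@dataclass(frozen=True)
class VlmOptionsSketch:  # hypothetical stand-in for HuggingFaceVlmOptions
    repo_id: str
    inference_framework: InferenceFramework

    @property
    def repo_cache_folder(self) -> str:
        # Same transformation as in the diff: "/" becomes "--".
        return self.repo_id.replace("/", "--")


mlx_options = VlmOptionsSketch(
    repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
    inference_framework=InferenceFramework.MLX,
)

# A str-based Enum member can be looked up from its plain string value.
assert InferenceFramework("mlx") is InferenceFramework.MLX
print(mlx_options.repo_cache_folder)
```

Because the enum derives from `str`, its members compare equal to their string values, which is also what lets the CLI accept `--vlm-model smoldocling` and `--pipeline vlm` as plain strings for `VlmModelType` and `PdfPipeline`.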

@ -0,0 +1,137 @@
import logging
import time
from pathlib import Path
from typing import Iterable, List, Optional
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
HuggingFaceVlmOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceMlxModel(BasePageModel):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
vlm_options: HuggingFaceVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
if self.enabled:
try:
from mlx_vlm import generate, load # type: ignore
from mlx_vlm.prompt_utils import apply_chat_template # type: ignore
from mlx_vlm.utils import load_config, stream_generate # type: ignore
except ImportError:
raise ImportError(
"mlx-vlm is not installed. Please install it via `pip install mlx-vlm` to use MLX VLM models."
)
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
self.apply_chat_template = apply_chat_template
self.stream_generate = stream_generate
# PARAMETERS:
if artifacts_path is None:
artifacts_path = self.download_models(self.vlm_options.repo_id)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
self.param_question = vlm_options.prompt # "Perform Layout Analysis."
## Load the model
self.vlm_model, self.processor = load(artifacts_path)
self.config = load_config(artifacts_path)
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
# revision="v0.0.1",
)
return Path(download_path)
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
with TimeRecorder(conv_res, "vlm"):
assert page.size is not None
hi_res_image = page.get_image(scale=2.0) # 144dpi
# hi_res_image = page.get_image(scale=1.0) # 72dpi
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
# populate page_tags with predicted doc tags
page_tags = ""
if hi_res_image:
if hi_res_image.mode != "RGB":
hi_res_image = hi_res_image.convert("RGB")
prompt = self.apply_chat_template(
self.processor, self.config, self.param_question, num_images=1
)
start_time = time.time()
# Call model to generate:
output = ""
for token in self.stream_generate(
self.vlm_model,
self.processor,
prompt,
[hi_res_image],
max_tokens=4096,
verbose=False,
):
output += token.text
if "</doctag>" in token.text:
break
generation_time = time.time() - start_time
page_tags = output
# inference_time = time.time() - start_time
# tokens_per_second = num_tokens / generation_time
# print("")
# print(f"Page Inference Time: {inference_time:.2f} seconds")
# print(f"Total tokens on page: {num_tokens:.2f}")
# print(f"Tokens/sec: {tokens_per_second:.2f}")
# print("")
page.predictions.vlm_response = VlmPrediction(text=page_tags)
yield page
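The generation loop in `__call__` accumulates streamed text and stops once a chunk contains `</doctag>`. The sketch below reproduces only that control flow so it can be run without `mlx-vlm`; `FakeToken`, `fake_stream` and `collect_until` are hypothetical names, and the fake stream is an assumption standing in for whatever `stream_generate` yields (objects with a `.text` attribute, as used in the diff).

```python
from dataclasses import dataclass
from typing import Iterable, Iterator


@dataclass
class FakeToken:  # hypothetical stand-in for the objects yielded by stream_generate
    text: str


def fake_stream(chunks: Iterable[str]) -> Iterator[FakeToken]:
    """Yield one FakeToken per chunk, mimicking a streaming generator."""
    for chunk in chunks:
        yield FakeToken(text=chunk)


def collect_until(stream: Iterable[FakeToken], stop: str = "</doctag>") -> str:
    """Concatenate token text and stop after the first chunk containing `stop`."""
    output = ""
    for token in stream:
        output += token.text
        if stop in token.text:
            break
    return output


page_tags = collect_until(
    fake_stream(["<doctag>", "<text>Hi</text>", "</doctag>", "ignored"])
)
print(page_tags)
```

One property of this check worth noting: it inspects each chunk on its own, so a stop marker that arrives split across two chunks would not end the loop early, and generation would then run until `max_tokens`. Testing the accumulated `output` instead would cover that case, if it can occur with the tokenizer in use.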

@ -14,8 +14,13 @@ from docling.backend.md_backend import MarkdownDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import InputFormat, Page
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import ResponseFormat, VlmPipelineOptions
from docling.datamodel.pipeline_options import (
InferenceFramework,
ResponseFormat,
VlmPipelineOptions,
)
from docling.datamodel.settings import settings
from docling.models.hf_mlx_model import HuggingFaceMlxModel
from docling.models.hf_vlm_model import HuggingFaceVlmModel
from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder
@ -29,12 +34,6 @@ class VlmPipeline(PaginatedPipeline):
super().__init__(pipeline_options)
self.keep_backend = True
warnings.warn(
"The VlmPipeline is currently experimental and may change in upcoming versions without notice.",
category=UserWarning,
stacklevel=2,
)
self.pipeline_options: VlmPipelineOptions
artifacts_path: Optional[Path] = None
@ -58,14 +57,27 @@ class VlmPipeline(PaginatedPipeline):
self.keep_images = self.pipeline_options.generate_page_images
self.build_pipe = [
HuggingFaceVlmModel(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=self.pipeline_options.vlm_options,
),
]
if (
self.pipeline_options.vlm_options.inference_framework
== InferenceFramework.MLX
):
self.build_pipe = [
HuggingFaceMlxModel(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=self.pipeline_options.vlm_options,
),
]
else:
self.build_pipe = [
HuggingFaceVlmModel(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=self.pipeline_options.vlm_options,
),
]
self.enrichment_pipe = [
# Other models working on `NodeItem` elements in the DoclingDocument
@ -79,7 +91,9 @@ class VlmPipeline(PaginatedPipeline):
return page
def extract_text_from_backend(self, page: Page, bbox: BoundingBox | None) -> str:
def extract_text_from_backend(
self, page: Page, bbox: Union[BoundingBox, None]
) -> str:
# Convert bounding box normalized to 0-100 into page coordinates for cropping
text = ""
if bbox:
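The signature change above swaps `BoundingBox | None` for `Union[BoundingBox, None]`. Annotations are evaluated when the function is defined (unless `from __future__ import annotations` is in effect), and the `X | Y` operator on types only exists from Python 3.10, so the older spelling is what keeps the module importable on Python 3.9. A minimal sketch, with `Box` and `describe_box` as hypothetical stand-ins for `BoundingBox` and the real method:

```python
from typing import Optional, Tuple, Union

# Hypothetical stand-in for BoundingBox: (left, top, right, bottom).
Box = Tuple[float, float, float, float]


def describe_box(bbox: Union[Box, None]) -> str:
    """Return "width x height" for a box, or an empty string when bbox is None."""
    if bbox:
        left, top, right, bottom = bbox
        return f"{right - left:.1f} x {bottom - top:.1f}"
    return ""


# Optional[X] is shorthand for Union[X, None]; either works on Python 3.9.
assert Optional[Box] == Union[Box, None]
print(describe_box((0.0, 0.0, 10.0, 5.0)))
```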

@ -10,13 +10,15 @@ from docling.datamodel.pipeline_options import (
VlmPipelineOptions,
granite_vision_vlm_conversion_options,
smoldocling_vlm_conversion_options,
smoldocling_vlm_mlx_conversion_options,
)
from docling.datamodel.settings import settings
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline
sources = [
"tests/data/2305.03393v1-pg9-img.png",
# "tests/data/2305.03393v1-pg9-img.png",
"tests/data/pdf/2305.03393v1-pg9.pdf",
]
## Use experimental VlmPipeline
@ -29,7 +31,10 @@ pipeline_options.force_backend_text = False
# pipeline_options.accelerator_options.cuda_use_flash_attention2 = True
## Pick a VLM model. We choose SmolDocling-256M by default
pipeline_options.vlm_options = smoldocling_vlm_conversion_options
# pipeline_options.vlm_options = smoldocling_vlm_conversion_options
## Pick a VLM model. Fast Apple Silicon friendly implementation for SmolDocling-256M via MLX
pipeline_options.vlm_options = smoldocling_vlm_mlx_conversion_options
## Alternative VLM models:
# pipeline_options.vlm_options = granite_vision_vlm_conversion_options
@ -63,9 +68,6 @@ for source in sources:
res = converter.convert(source)
print("------------------------------------------------")
print("MD:")
print("------------------------------------------------")
print("")
print(res.document.export_to_markdown())
@ -83,8 +85,17 @@ for source in sources:
with (out_path / f"{res.input.file.stem}.json").open("w") as fp:
fp.write(json.dumps(res.document.export_to_dict()))
pg_num = res.document.num_pages()
res.document.save_as_json(
out_path / f"{res.input.file.stem}.json",
image_mode=ImageRefMode.PLACEHOLDER,
)
res.document.save_as_markdown(
out_path / f"{res.input.file.stem}.md",
image_mode=ImageRefMode.PLACEHOLDER,
)
pg_num = res.document.num_pages()
print("")
inference_time = time.time() - start_time
print(

@ -26,7 +26,7 @@ Docling simplifies document processing, parsing diverse formats — including ad
* 🔒 Local execution capabilities for sensitive data and air-gapped environments
* 🤖 Plug-and-play [integrations][integrations] incl. LangChain, LlamaIndex, Crew AI & Haystack for agentic AI
* 🔍 Extensive OCR support for scanned PDFs and images
* 🥚 Support of Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview))
* 🥚 Support of Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview)) 🆕🔥
* 💻 Simple and convenient CLI
### Coming soon

@ -17,10 +17,15 @@ print(result.document.export_to_markdown()) # output: "### Docling Technical Re
You can also use Docling directly from your command line to convert individual files —be it local or by URL— or whole directories.
A simple example would look like this:
```console
docling https://arxiv.org/pdf/2206.01062
```
You can also use 🥚[SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview) and other VLMs via Docling CLI:
```bash
docling --pipeline vlm --vlm-model smoldocling https://arxiv.org/pdf/2206.01062
```
This will use MLX acceleration on supported Apple Silicon hardware.
To see all available options (export formats etc.) run `docling --help`. More details in the [CLI reference page](../reference/cli.md).

@ -192,6 +192,7 @@ module = [
"docling_ibm_models.*",
"easyocr.*",
"ocrmac.*",
"mlx_vlm.*",
"lxml.*",
"huggingface_hub.*",
"transformers.*",