From c605edd8e91d988f6dca2bdfc67c54d6396fe903 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Thu, 10 Apr 2025 10:03:04 -0600
Subject: [PATCH] feat: OllamaVlmModel for Granite Vision 3.2 (#1337)

* build: Add ollama sdk dependency

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart

* feat: Add option plumbing for OllamaVlmOptions in pipeline_options

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart

* feat: Full implementation of OllamaVlmModel

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart

* feat: Connect "granite_vision_ollama" pipeline option to CLI

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart

* Revert "build: Add ollama sdk dependency"

After consideration, we're going to use the generic OpenAI API instead of
the Ollama-specific API to avoid duplicate work.

This reverts commit bc6b366468cdd66b52540aac9c7d8b584ab48ad0.

Signed-off-by: Gabe Goodhart

* refactor: Move OpenAI API call logic into utils.utils

This will allow reuse of this logic in a generic VLM model.

NOTE: There is a subtle change here in the ordering of the text prompt and
the image in the call to the OpenAI API. When run against Ollama, this
ordering makes a big difference. If the prompt comes before the image, the
result is terse and not usable, whereas the prompt coming after the image
works as expected and matches the non-OpenAI chat API.

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart

* refactor: Refactor from Ollama SDK to generic OpenAI API

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart

* fix: Linting, formatting, and bug fixes

The one bug fix was in the timeout arg to openai_image_request. Otherwise,
this is all style changes to get MyPy and black passing cleanly.

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart

* remove model from download enum

Signed-off-by: Michele Dolfi

* generalize input args for other API providers

Signed-off-by: Michele Dolfi

* rename and refactor

Signed-off-by: Michele Dolfi

* add example

Signed-off-by: Michele Dolfi

* require flag for remote services

Signed-off-by: Michele Dolfi

* disable example from CI

Signed-off-by: Michele Dolfi

* add examples to docs

Signed-off-by: Michele Dolfi

---------

Signed-off-by: Gabe Goodhart
Signed-off-by: Michele Dolfi
Co-authored-by: Michele Dolfi
---
 .github/workflows/checks.yml                |   2 +-
 docling/cli/main.py                         |   9 +-
 docling/datamodel/base_models.py            |  32 +++++
 docling/datamodel/pipeline_options.py       |  28 ++++-
 docling/models/api_vlm_model.py             |  67 +++++++++++
 .../models/picture_description_api_model.py |  83 ++-----------
 docling/pipeline/vlm_pipeline.py            |  44 ++++---
 docling/utils/api_image_request.py          |  61 ++++++++++
 docs/examples/vlm_pipeline_api_model.py     | 111 ++++++++++++++++++
 mkdocs.yml                                  |   2 +
 10 files changed, 344 insertions(+), 95 deletions(-)
 create mode 100644 docling/models/api_vlm_model.py
 create mode 100644 docling/utils/api_image_request.py
 create mode 100644 docs/examples/vlm_pipeline_api_model.py

diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index b2a295d..ee5ba79 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -37,7 +37,7 @@ jobs:
       run: |
         for file in docs/examples/*.py; do
           # Skip batch_convert.py
-          if [[ "$(basename "$file")" =~ ^(batch_convert|minimal_vlm_pipeline|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api).py ]]; then
+          if [[ "$(basename "$file")" =~ ^(batch_convert|minimal_vlm_pipeline|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model).py ]]; then
             echo "Skipping $file"
             continue
           fi
diff --git a/docling/cli/main.py b/docling/cli/main.py
index e0f0cbd..c87e311 100644
--- a/docling/cli/main.py
+++ b/docling/cli/main.py
@@ -40,6 +40,7 @@ from docling.datamodel.pipeline_options import (
     VlmModelType,
     VlmPipelineOptions,
     granite_vision_vlm_conversion_options,
+    granite_vision_vlm_ollama_conversion_options,
     smoldocling_vlm_conversion_options,
     smoldocling_vlm_mlx_conversion_options,
 )
@@ -531,10 +532,16 @@ def convert(
                 backend=backend,  # pdf_backend
             )
         elif pipeline == PdfPipeline.VLM:
-            pipeline_options = VlmPipelineOptions()
+            pipeline_options = VlmPipelineOptions(
+                enable_remote_services=enable_remote_services,
+            )
 
             if vlm_model == VlmModelType.GRANITE_VISION:
                 pipeline_options.vlm_options = granite_vision_vlm_conversion_options
+            elif vlm_model == VlmModelType.GRANITE_VISION_OLLAMA:
+                pipeline_options.vlm_options = (
+                    granite_vision_vlm_ollama_conversion_options
+                )
             elif vlm_model == VlmModelType.SMOLDOCLING:
                 pipeline_options.vlm_options = smoldocling_vlm_conversion_options
                 if sys.platform == "darwin":
diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py
index 76827a1..7dcf89c 100644
--- a/docling/datamodel/base_models.py
+++ b/docling/datamodel/base_models.py
@@ -262,3 +262,35 @@ class Page(BaseModel):
     @property
     def image(self) -> Optional[Image]:
         return self.get_image(scale=self._default_image_scale)
+
+
+## OpenAI API Request / Response Models ##
+
+
+class OpenAiChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class OpenAiResponseChoice(BaseModel):
+    index: int
+    message: OpenAiChatMessage
+    finish_reason: str
+
+
+class OpenAiResponseUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class OpenAiApiResponse(BaseModel):
+    model_config = ConfigDict(
+        protected_namespaces=(),
+    )
+
+    id: str
+    model: Optional[str] = None  # returned by openai
+    choices: List[OpenAiResponseChoice]
+    created: int
+    usage: OpenAiResponseUsage
diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index 654e04d..9791a25 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -266,6 +266,7 @@ class ResponseFormat(str, Enum):
 class InferenceFramework(str, Enum):
     MLX = "mlx"
     TRANSFORMERS = "transformers"
+    OPENAI = "openai"
 
 
 class HuggingFaceVlmOptions(BaseVlmOptions):
@@ -284,6 +285,19 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
         return self.repo_id.replace("/", "--")
 
 
+class ApiVlmOptions(BaseVlmOptions):
+    kind: Literal["api_model_options"] = "api_model_options"
+
+    url: AnyUrl = AnyUrl(
+        "http://localhost:11434/v1/chat/completions"
+    )  # Default to ollama
+    headers: Dict[str, str] = {}
+    params: Dict[str, Any] = {}
+    scale: float = 2.0
+    timeout: float = 60
+    response_format: ResponseFormat
+
+
 smoldocling_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
     repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
     prompt="Convert this page to docling.",
@@ -307,10 +321,20 @@ granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
     inference_framework=InferenceFramework.TRANSFORMERS,
 )
 
+granite_vision_vlm_ollama_conversion_options = ApiVlmOptions(
+    url=AnyUrl("http://localhost:11434/v1/chat/completions"),
+    params={"model": "granite3.2-vision:2b"},
+    prompt="OCR the full page to markdown.",
+    scale=1.0,
+    timeout=120,
+    response_format=ResponseFormat.MARKDOWN,
+)
+
 
 class VlmModelType(str, Enum):
     SMOLDOCLING = "smoldocling"
     GRANITE_VISION = "granite_vision"
+    GRANITE_VISION_OLLAMA = "granite_vision_ollama"
 
 
 # Define an enum for the backend options
@@ -362,7 +386,9 @@ class VlmPipelineOptions(PaginatedPipelineOptions):
         False  # (To be used with vlms, or other generative models)
     )
     # If True, text from backend will be used instead of generated text
-    vlm_options: Union[HuggingFaceVlmOptions] = smoldocling_vlm_conversion_options
+    vlm_options: Union[HuggingFaceVlmOptions, ApiVlmOptions] = (
+        smoldocling_vlm_conversion_options
+    )
 
 
 class PdfPipelineOptions(PaginatedPipelineOptions):
diff --git a/docling/models/api_vlm_model.py b/docling/models/api_vlm_model.py
new file mode 100644
index 0000000..9520122
--- /dev/null
+++ b/docling/models/api_vlm_model.py
@@ -0,0 +1,67 @@
+from typing import Iterable
+
+from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import ApiVlmOptions
+from docling.exceptions import OperationNotAllowed
+from docling.models.base_model import BasePageModel
+from docling.utils.api_image_request import api_image_request
+from docling.utils.profiling import TimeRecorder
+
+
+class ApiVlmModel(BasePageModel):
+
+    def __init__(
+        self,
+        enabled: bool,
+        enable_remote_services: bool,
+        vlm_options: ApiVlmOptions,
+    ):
+        self.enabled = enabled
+        self.vlm_options = vlm_options
+        if self.enabled:
+            if not enable_remote_services:
+                raise OperationNotAllowed(
+                    "Connections to remote services are only allowed when set explicitly. "
+                    "Set pipeline_options.enable_remote_services=True, or use the CLI "
+                    "--enable-remote-services."
+                )
+
+            self.timeout = self.vlm_options.timeout
+            self.prompt_content = (
+                f"This is a page from a document.\n{self.vlm_options.prompt}"
+            )
+            self.params = {
+                **self.vlm_options.params,
+                "temperature": 0,
+            }
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        for page in page_batch:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                yield page
+            else:
+                with TimeRecorder(conv_res, "vlm"):
+                    assert page.size is not None
+
+                    hi_res_image = page.get_image(scale=self.vlm_options.scale)
+                    assert hi_res_image is not None
+                    if hi_res_image:
+                        if hi_res_image.mode != "RGB":
+                            hi_res_image = hi_res_image.convert("RGB")
+
+                    page_tags = api_image_request(
+                        image=hi_res_image,
+                        prompt=self.prompt_content,
+                        url=self.vlm_options.url,
+                        timeout=self.timeout,
+                        headers=self.vlm_options.headers,
+                        **self.params,
+                    )
+
+                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
+
+                    yield page
diff --git a/docling/models/picture_description_api_model.py b/docling/models/picture_description_api_model.py
index 6ef8a7f..1aa7351 100644
--- a/docling/models/picture_description_api_model.py
+++ b/docling/models/picture_description_api_model.py
@@ -1,12 +1,7 @@
-import base64
-import io
-import logging
 from pathlib import Path
-from typing import Iterable, List, Optional, Type, Union
+from typing import Iterable, Optional, Type, Union
 
-import requests
 from PIL import Image
-from pydantic import BaseModel, ConfigDict
 
 from docling.datamodel.pipeline_options import (
     AcceleratorOptions,
@@ -15,37 +10,7 @@ from docling.datamodel.pipeline_options import (
 )
 from docling.exceptions import OperationNotAllowed
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
-
-_log = logging.getLogger(__name__)
-
-
-class ChatMessage(BaseModel):
-    role: str
-    content: str
-
-
-class ResponseChoice(BaseModel):
-    index: int
-    message: ChatMessage
-    finish_reason: str
-
-
-class ResponseUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-
-class ApiResponse(BaseModel):
-    model_config = ConfigDict(
-        protected_namespaces=(),
-    )
-
-    id: str
-    model: Optional[str] = None  # returned by openai
-    choices: List[ResponseChoice]
-    created: int
-    usage: ResponseUsage
+from docling.utils.api_image_request import api_image_request
 
 
 class PictureDescriptionApiModel(PictureDescriptionBaseModel):
@@ -83,43 +48,11 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
         # Note: technically we could make a batch request here,
         # but not all APIs will allow for it. For example, vllm won't allow more than 1.
         for image in images:
-            img_io = io.BytesIO()
-            image.save(img_io, "PNG")
-            image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
-
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": self.options.prompt,
-                        },
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/png;base64,{image_base64}"
-                            },
-                        },
-                    ],
-                }
-            ]
-
-            payload = {
-                "messages": messages,
-                **self.options.params,
-            }
-
-            r = requests.post(
-                str(self.options.url),
-                headers=self.options.headers,
-                json=payload,
+            yield api_image_request(
+                image=image,
+                prompt=self.options.prompt,
+                url=self.options.url,
                 timeout=self.options.timeout,
+                headers=self.options.headers,
+                **self.options.params,
             )
-            if not r.ok:
-                _log.error(f"Error calling the API. Reponse was {r.text}")
-                r.raise_for_status()
-
-            api_resp = ApiResponse.model_validate_json(r.text)
-            generated_text = api_resp.choices[0].message.content.strip()
-            yield generated_text
diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py
index d4defa8..79279fd 100644
--- a/docling/pipeline/vlm_pipeline.py
+++ b/docling/pipeline/vlm_pipeline.py
@@ -15,11 +15,14 @@ from docling.backend.pdf_backend import PdfDocumentBackend
 from docling.datamodel.base_models import InputFormat, Page
 from docling.datamodel.document import ConversionResult, InputDocument
 from docling.datamodel.pipeline_options import (
+    ApiVlmOptions,
+    HuggingFaceVlmOptions,
     InferenceFramework,
     ResponseFormat,
     VlmPipelineOptions,
 )
 from docling.datamodel.settings import settings
+from docling.models.api_vlm_model import ApiVlmModel
 from docling.models.hf_mlx_model import HuggingFaceMlxModel
 from docling.models.hf_vlm_model import HuggingFaceVlmModel
 from docling.pipeline.base_pipeline import PaginatedPipeline
@@ -57,27 +60,34 @@ class VlmPipeline(PaginatedPipeline):
 
         self.keep_images = self.pipeline_options.generate_page_images
 
-        if (
-            self.pipeline_options.vlm_options.inference_framework
-            == InferenceFramework.MLX
-        ):
+        if isinstance(pipeline_options.vlm_options, ApiVlmOptions):
             self.build_pipe = [
-                HuggingFaceMlxModel(
+                ApiVlmModel(
                     enabled=True,  # must be always enabled for this pipeline to make sense.
-                    artifacts_path=artifacts_path,
-                    accelerator_options=pipeline_options.accelerator_options,
-                    vlm_options=self.pipeline_options.vlm_options,
-                ),
-            ]
-        else:
-            self.build_pipe = [
-                HuggingFaceVlmModel(
-                    enabled=True,  # must be always enabled for this pipeline to make sense.
-                    artifacts_path=artifacts_path,
-                    accelerator_options=pipeline_options.accelerator_options,
-                    vlm_options=self.pipeline_options.vlm_options,
+                    enable_remote_services=self.pipeline_options.enable_remote_services,
+                    vlm_options=cast(ApiVlmOptions, self.pipeline_options.vlm_options),
                 ),
             ]
+        elif isinstance(self.pipeline_options.vlm_options, HuggingFaceVlmOptions):
+            vlm_options = cast(HuggingFaceVlmOptions, self.pipeline_options.vlm_options)
+            if vlm_options.inference_framework == InferenceFramework.MLX:
+                self.build_pipe = [
+                    HuggingFaceMlxModel(
+                        enabled=True,  # must be always enabled for this pipeline to make sense.
+                        artifacts_path=artifacts_path,
+                        accelerator_options=pipeline_options.accelerator_options,
+                        vlm_options=vlm_options,
+                    ),
+                ]
+            else:
+                self.build_pipe = [
+                    HuggingFaceVlmModel(
+                        enabled=True,  # must be always enabled for this pipeline to make sense.
+                        artifacts_path=artifacts_path,
+                        accelerator_options=pipeline_options.accelerator_options,
+                        vlm_options=vlm_options,
+                    ),
+                ]
 
         self.enrichment_pipe = [
             # Other models working on `NodeItem` elements in the DoclingDocument
diff --git a/docling/utils/api_image_request.py b/docling/utils/api_image_request.py
new file mode 100644
index 0000000..9227389
--- /dev/null
+++ b/docling/utils/api_image_request.py
@@ -0,0 +1,61 @@
+import base64
+import logging
+from io import BytesIO
+from typing import Dict, Optional
+
+import requests
+from PIL import Image
+from pydantic import AnyUrl
+
+from docling.datamodel.base_models import OpenAiApiResponse
+
+_log = logging.getLogger(__name__)
+
+
+def api_image_request(
+    image: Image.Image,
+    prompt: str,
+    url: AnyUrl,
+    timeout: float = 20,
+    headers: Optional[Dict[str, str]] = None,
+    **params,
+) -> str:
+    img_io = BytesIO()
+    image.save(img_io, "PNG")
+    image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+                },
+                {
+                    "type": "text",
+                    "text": prompt,
+                },
+            ],
+        }
+    ]
+
+    payload = {
+        "messages": messages,
+        **params,
+    }
+
+    headers = headers or {}
+
+    r = requests.post(
+        str(url),
+        headers=headers,
+        json=payload,
+        timeout=timeout,
+    )
+    if not r.ok:
+        _log.error(f"Error calling the API. Response was {r.text}")
+        r.raise_for_status()
+
+    api_resp = OpenAiApiResponse.model_validate_json(r.text)
+    generated_text = api_resp.choices[0].message.content.strip()
+    return generated_text
diff --git a/docs/examples/vlm_pipeline_api_model.py b/docs/examples/vlm_pipeline_api_model.py
new file mode 100644
index 0000000..33fb72a
--- /dev/null
+++ b/docs/examples/vlm_pipeline_api_model.py
@@ -0,0 +1,111 @@
+import logging
+import os
+from pathlib import Path
+
+import requests
+from dotenv import load_dotenv
+
+from docling.datamodel.base_models import InputFormat
+from docling.datamodel.pipeline_options import (
+    ApiVlmOptions,
+    ResponseFormat,
+    VlmPipelineOptions,
+    granite_vision_vlm_ollama_conversion_options,
+)
+from docling.document_converter import DocumentConverter, PdfFormatOption
+from docling.pipeline.vlm_pipeline import VlmPipeline
+
+
+def ollama_vlm_options(model: str, prompt: str):
+    options = ApiVlmOptions(
+        url="http://localhost:11434/v1/chat/completions",  # the default Ollama endpoint
+        params=dict(
+            model=model,
+        ),
+        prompt=prompt,
+        timeout=90,
+        scale=1.0,
+        response_format=ResponseFormat.MARKDOWN,
+    )
+    return options
+
+
+def watsonx_vlm_options(model: str, prompt: str):
+    load_dotenv()
+    api_key = os.environ.get("WX_API_KEY")
+    project_id = os.environ.get("WX_PROJECT_ID")
+
+    def _get_iam_access_token(api_key: str) -> str:
+        res = requests.post(
+            url="https://iam.cloud.ibm.com/identity/token",
+            headers={
+                "Content-Type": "application/x-www-form-urlencoded",
+            },
+            data=f"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey={api_key}",
+        )
+        res.raise_for_status()
+        api_out = res.json()
+        print(f"{api_out=}")
+        return api_out["access_token"]
+
+    options = ApiVlmOptions(
+        url="https://us-south.ml.cloud.ibm.com/ml/v1/text/chat?version=2023-05-29",
+        params=dict(
+            model_id=model,
+            project_id=project_id,
+            parameters=dict(
+                max_new_tokens=400,
+            ),
+        ),
+        headers={
+            "Authorization": "Bearer " + _get_iam_access_token(api_key=api_key),
+        },
+        prompt=prompt,
+        timeout=60,
+        response_format=ResponseFormat.MARKDOWN,
+    )
+    return options
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+
+    # input_doc_path = Path("./tests/data/pdf/2206.01062.pdf")
+    input_doc_path = Path("./tests/data/pdf/2305.03393v1-pg9.pdf")
+
+    pipeline_options = VlmPipelineOptions(
+        enable_remote_services=True  # <-- this is required!
+    )
+
+    # The ApiVlmOptions() allows interfacing with APIs that support
+    # the multi-modal chat interface. Below are a few examples of how to configure them.
+
+    # One possibility is self-hosting a model, e.g. via Ollama.
+    # Example using the Granite Vision model with Ollama:
+    pipeline_options.vlm_options = ollama_vlm_options(
+        model="granite3.2-vision:2b",
+        prompt="OCR the full page to markdown.",
+    )
+
+    # Another possibility is using online services, e.g. watsonx.ai.
+    # Using it requires setting the env variables WX_API_KEY and WX_PROJECT_ID.
+    # Uncomment the following lines for this option:
+    # pipeline_options.vlm_options = watsonx_vlm_options(
+    #     model="ibm/granite-vision-3-2-2b", prompt="OCR the full page to markdown."
+    # )
+
+    # Create the DocumentConverter and launch the conversion.
+    doc_converter = DocumentConverter(
+        format_options={
+            InputFormat.PDF: PdfFormatOption(
+                pipeline_options=pipeline_options,
+                pipeline_cls=VlmPipeline,
+            )
+        }
+    )
+    result = doc_converter.convert(input_doc_path)
+    print(result.document.export_to_markdown())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/mkdocs.yml b/mkdocs.yml
index 0fc7f5f..dd842d6 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -75,6 +75,8 @@ nav:
       - "Custom conversion": examples/custom_convert.py
       - "Batch conversion": examples/batch_convert.py
      - "Multi-format conversion": examples/run_with_formats.py
+      - "VLM pipeline with SmolDocling": examples/minimal_vlm_pipeline.py
+      - "VLM pipeline with remote model": examples/vlm_pipeline_api_model.py
       - "Figure export": examples/export_figures.py
       - "Table export": examples/export_tables.py
       - "Multimodal export": examples/export_multimodal.py
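
Editor's note: beyond the bundled Ollama and watsonx.ai helpers, the ApiVlmOptions introduced in this patch can point at any other OpenAI-compatible /v1/chat/completions server (the commit history mentions generalizing the input args for other API providers). The snippet below is an illustrative sketch only, not part of the patch; the endpoint URL, model name, and input file are placeholders for whatever server and document you run locally.

    # Hypothetical usage sketch (not part of this patch); URL, model name, and
    # input path are placeholders for a locally running OpenAI-compatible server.
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import (
        ApiVlmOptions,
        ResponseFormat,
        VlmPipelineOptions,
    )
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.pipeline.vlm_pipeline import VlmPipeline

    # Remote calls must be opted into explicitly, as enforced by ApiVlmModel.
    pipeline_options = VlmPipelineOptions(enable_remote_services=True)
    pipeline_options.vlm_options = ApiVlmOptions(
        url="http://localhost:8000/v1/chat/completions",  # placeholder endpoint
        params={"model": "ibm-granite/granite-vision-3.2-2b"},  # placeholder model name
        prompt="OCR the full page to markdown.",
        timeout=90,
        scale=1.0,
        response_format=ResponseFormat.MARKDOWN,
    )

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_options=pipeline_options,
                pipeline_cls=VlmPipeline,
            )
        }
    )
    result = converter.convert("page.pdf")  # placeholder input document
    print(result.document.export_to_markdown())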