feat: OllamaVlmModel for Granite Vision 3.2 (#1337)

* build: Add ollama sdk dependency

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add option plumbing for OllamaVlmOptions in pipeline_options

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Full implementation of OllamaVlmModel

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Connect "granite_vision_ollama" pipeline option to CLI

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* Revert "build: Add ollama sdk dependency"

After consideration, we're going to use the generic OpenAI API instead
of the Ollama-specific API to avoid duplicate work.

This reverts commit bc6b366468cdd66b52540aac9c7d8b584ab48ad0.

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Move OpenAI API call logic into utils.utils

This will allow reuse of this logic in a generic VLM model

NOTE: There is a subtle change here in the ordering of the text prompt and
the image in the call to the OpenAI API. When run against Ollama, this
ordering makes a big difference. If the prompt comes before the image, the
result is terse and not usable whereas the prompt coming after the image
works as expected and matches the non-OpenAI chat API.

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Refactor from Ollama SDK to generic OpenAI API

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Linting, formatting, and bug fixes

The one bug fix was in the timeout arg to openai_image_request. Otherwise,
this is all style changes to get MyPy and black passing cleanly.

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* remove model from download enum

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* generalize input args for other API providers

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* rename and refactor

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add example

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* require flag for remote services

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* disable example from CI

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add examples to docs

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Co-authored-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Gabe Goodhart 2025-04-10 10:03:04 -06:00 committed by GitHub
parent 6b696b504a
commit c605edd8e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 344 additions and 95 deletions

View File

@ -37,7 +37,7 @@ jobs:
run: |
for file in docs/examples/*.py; do
# Skip batch_convert.py
if [[ "$(basename "$file")" =~ ^(batch_convert|minimal_vlm_pipeline|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api).py ]]; then
if [[ "$(basename "$file")" =~ ^(batch_convert|minimal_vlm_pipeline|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model).py ]]; then
echo "Skipping $file"
continue
fi

View File

@ -40,6 +40,7 @@ from docling.datamodel.pipeline_options import (
VlmModelType,
VlmPipelineOptions,
granite_vision_vlm_conversion_options,
granite_vision_vlm_ollama_conversion_options,
smoldocling_vlm_conversion_options,
smoldocling_vlm_mlx_conversion_options,
)
@ -531,10 +532,16 @@ def convert(
backend=backend, # pdf_backend
)
elif pipeline == PdfPipeline.VLM:
pipeline_options = VlmPipelineOptions()
pipeline_options = VlmPipelineOptions(
enable_remote_services=enable_remote_services,
)
if vlm_model == VlmModelType.GRANITE_VISION:
pipeline_options.vlm_options = granite_vision_vlm_conversion_options
elif vlm_model == VlmModelType.GRANITE_VISION_OLLAMA:
pipeline_options.vlm_options = (
granite_vision_vlm_ollama_conversion_options
)
elif vlm_model == VlmModelType.SMOLDOCLING:
pipeline_options.vlm_options = smoldocling_vlm_conversion_options
if sys.platform == "darwin":

View File

@ -262,3 +262,35 @@ class Page(BaseModel):
@property
def image(self) -> Optional[Image]:
return self.get_image(scale=self._default_image_scale)
## OpenAI API Request / Response Models ##
class OpenAiChatMessage(BaseModel):
role: str
content: str
class OpenAiResponseChoice(BaseModel):
index: int
message: OpenAiChatMessage
finish_reason: str
class OpenAiResponseUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class OpenAiApiResponse(BaseModel):
model_config = ConfigDict(
protected_namespaces=(),
)
id: str
model: Optional[str] = None # returned by openai
choices: List[OpenAiResponseChoice]
created: int
usage: OpenAiResponseUsage

View File

@ -266,6 +266,7 @@ class ResponseFormat(str, Enum):
class InferenceFramework(str, Enum):
MLX = "mlx"
TRANSFORMERS = "transformers"
OPENAI = "openai"
class HuggingFaceVlmOptions(BaseVlmOptions):
@ -284,6 +285,19 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
return self.repo_id.replace("/", "--")
class ApiVlmOptions(BaseVlmOptions):
kind: Literal["api_model_options"] = "api_model_options"
url: AnyUrl = AnyUrl(
"http://localhost:11434/v1/chat/completions"
) # Default to ollama
headers: Dict[str, str] = {}
params: Dict[str, Any] = {}
scale: float = 2.0
timeout: float = 60
response_format: ResponseFormat
smoldocling_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
prompt="Convert this page to docling.",
@ -307,10 +321,20 @@ granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
inference_framework=InferenceFramework.TRANSFORMERS,
)
granite_vision_vlm_ollama_conversion_options = ApiVlmOptions(
url=AnyUrl("http://localhost:11434/v1/chat/completions"),
params={"model": "granite3.2-vision:2b"},
prompt="OCR the full page to markdown.",
scale=1.0,
timeout=120,
response_format=ResponseFormat.MARKDOWN,
)
class VlmModelType(str, Enum):
SMOLDOCLING = "smoldocling"
GRANITE_VISION = "granite_vision"
GRANITE_VISION_OLLAMA = "granite_vision_ollama"
# Define an enum for the backend options
@ -362,7 +386,9 @@ class VlmPipelineOptions(PaginatedPipelineOptions):
False # (To be used with vlms, or other generative models)
)
# If True, text from backend will be used instead of generated text
vlm_options: Union[HuggingFaceVlmOptions] = smoldocling_vlm_conversion_options
vlm_options: Union[HuggingFaceVlmOptions, ApiVlmOptions] = (
smoldocling_vlm_conversion_options
)
class PdfPipelineOptions(PaginatedPipelineOptions):

View File

@ -0,0 +1,67 @@
from typing import Iterable
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import ApiVlmOptions
from docling.exceptions import OperationNotAllowed
from docling.models.base_model import BasePageModel
from docling.utils.api_image_request import api_image_request
from docling.utils.profiling import TimeRecorder
class ApiVlmModel(BasePageModel):
def __init__(
self,
enabled: bool,
enable_remote_services: bool,
vlm_options: ApiVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
if self.enabled:
if not enable_remote_services:
raise OperationNotAllowed(
"Connections to remote services is only allowed when set explicitly. "
"pipeline_options.enable_remote_services=True, or using the CLI "
"--enable-remote-services."
)
self.timeout = self.vlm_options.timeout
self.prompt_content = (
f"This is a page from a document.\n{self.vlm_options.prompt}"
)
self.params = {
**self.vlm_options.params,
"temperature": 0,
}
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
with TimeRecorder(conv_res, "vlm"):
assert page.size is not None
hi_res_image = page.get_image(scale=self.vlm_options.scale)
assert hi_res_image is not None
if hi_res_image:
if hi_res_image.mode != "RGB":
hi_res_image = hi_res_image.convert("RGB")
page_tags = api_image_request(
image=hi_res_image,
prompt=self.prompt_content,
url=self.vlm_options.url,
timeout=self.timeout,
headers=self.vlm_options.headers,
**self.params,
)
page.predictions.vlm_response = VlmPrediction(text=page_tags)
yield page

View File

@ -1,12 +1,7 @@
import base64
import io
import logging
from pathlib import Path
from typing import Iterable, List, Optional, Type, Union
from typing import Iterable, Optional, Type, Union
import requests
from PIL import Image
from pydantic import BaseModel, ConfigDict
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
@ -15,37 +10,7 @@ from docling.datamodel.pipeline_options import (
)
from docling.exceptions import OperationNotAllowed
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
_log = logging.getLogger(__name__)
class ChatMessage(BaseModel):
role: str
content: str
class ResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: str
class ResponseUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ApiResponse(BaseModel):
model_config = ConfigDict(
protected_namespaces=(),
)
id: str
model: Optional[str] = None # returned by openai
choices: List[ResponseChoice]
created: int
usage: ResponseUsage
from docling.utils.api_image_request import api_image_request
class PictureDescriptionApiModel(PictureDescriptionBaseModel):
@ -83,43 +48,11 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
# Note: technically we could make a batch request here,
# but not all APIs will allow for it. For example, vllm won't allow more than 1.
for image in images:
img_io = io.BytesIO()
image.save(img_io, "PNG")
image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": self.options.prompt,
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_base64}"
},
},
],
}
]
payload = {
"messages": messages,
**self.options.params,
}
r = requests.post(
str(self.options.url),
headers=self.options.headers,
json=payload,
yield api_image_request(
image=image,
prompt=self.options.prompt,
url=self.options.url,
timeout=self.options.timeout,
headers=self.options.headers,
**self.options.params,
)
if not r.ok:
_log.error(f"Error calling the API. Reponse was {r.text}")
r.raise_for_status()
api_resp = ApiResponse.model_validate_json(r.text)
generated_text = api_resp.choices[0].message.content.strip()
yield generated_text

View File

@ -15,11 +15,14 @@ from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import InputFormat, Page
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import (
ApiVlmOptions,
HuggingFaceVlmOptions,
InferenceFramework,
ResponseFormat,
VlmPipelineOptions,
)
from docling.datamodel.settings import settings
from docling.models.api_vlm_model import ApiVlmModel
from docling.models.hf_mlx_model import HuggingFaceMlxModel
from docling.models.hf_vlm_model import HuggingFaceVlmModel
from docling.pipeline.base_pipeline import PaginatedPipeline
@ -57,27 +60,34 @@ class VlmPipeline(PaginatedPipeline):
self.keep_images = self.pipeline_options.generate_page_images
if (
self.pipeline_options.vlm_options.inference_framework
== InferenceFramework.MLX
):
if isinstance(pipeline_options.vlm_options, ApiVlmOptions):
self.build_pipe = [
HuggingFaceMlxModel(
ApiVlmModel(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=self.pipeline_options.vlm_options,
),
]
else:
self.build_pipe = [
HuggingFaceVlmModel(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=self.pipeline_options.vlm_options,
enable_remote_services=self.pipeline_options.enable_remote_services,
vlm_options=cast(ApiVlmOptions, self.pipeline_options.vlm_options),
),
]
elif isinstance(self.pipeline_options.vlm_options, HuggingFaceVlmOptions):
vlm_options = cast(HuggingFaceVlmOptions, self.pipeline_options.vlm_options)
if vlm_options.inference_framework == InferenceFramework.MLX:
self.build_pipe = [
HuggingFaceMlxModel(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=vlm_options,
),
]
else:
self.build_pipe = [
HuggingFaceVlmModel(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=vlm_options,
),
]
self.enrichment_pipe = [
# Other models working on `NodeItem` elements in the DoclingDocument

View File

@ -0,0 +1,61 @@
import base64
import logging
from io import BytesIO
from typing import Dict, Optional
import requests
from PIL import Image
from pydantic import AnyUrl
from docling.datamodel.base_models import OpenAiApiResponse
_log = logging.getLogger(__name__)
def api_image_request(
image: Image.Image,
prompt: str,
url: AnyUrl,
timeout: float = 20,
headers: Optional[Dict[str, str]] = None,
**params,
) -> str:
img_io = BytesIO()
image.save(img_io, "PNG")
image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
},
{
"type": "text",
"text": prompt,
},
],
}
]
payload = {
"messages": messages,
**params,
}
headers = headers or {}
r = requests.post(
str(url),
headers=headers,
json=payload,
timeout=timeout,
)
if not r.ok:
_log.error(f"Error calling the API. Response was {r.text}")
r.raise_for_status()
api_resp = OpenAiApiResponse.model_validate_json(r.text)
generated_text = api_resp.choices[0].message.content.strip()
return generated_text

View File

@ -0,0 +1,111 @@
import logging
import os
from pathlib import Path
import requests
from dotenv import load_dotenv
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
ApiVlmOptions,
ResponseFormat,
VlmPipelineOptions,
granite_vision_vlm_ollama_conversion_options,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline
def ollama_vlm_options(model: str, prompt: str):
options = ApiVlmOptions(
url="http://localhost:11434/v1/chat/completions", # the default Ollama endpoint
params=dict(
model=model,
),
prompt=prompt,
timeout=90,
scale=1.0,
response_format=ResponseFormat.MARKDOWN,
)
return options
def watsonx_vlm_options(model: str, prompt: str):
load_dotenv()
api_key = os.environ.get("WX_API_KEY")
project_id = os.environ.get("WX_PROJECT_ID")
def _get_iam_access_token(api_key: str) -> str:
res = requests.post(
url="https://iam.cloud.ibm.com/identity/token",
headers={
"Content-Type": "application/x-www-form-urlencoded",
},
data=f"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey={api_key}",
)
res.raise_for_status()
api_out = res.json()
print(f"{api_out=}")
return api_out["access_token"]
options = ApiVlmOptions(
url="https://us-south.ml.cloud.ibm.com/ml/v1/text/chat?version=2023-05-29",
params=dict(
model_id=model,
project_id=project_id,
parameters=dict(
max_new_tokens=400,
),
),
headers={
"Authorization": "Bearer " + _get_iam_access_token(api_key=api_key),
},
prompt=prompt,
timeout=60,
response_format=ResponseFormat.MARKDOWN,
)
return options
def main():
logging.basicConfig(level=logging.INFO)
# input_doc_path = Path("./tests/data/pdf/2206.01062.pdf")
input_doc_path = Path("./tests/data/pdf/2305.03393v1-pg9.pdf")
pipeline_options = VlmPipelineOptions(
enable_remote_services=True # <-- this is required!
)
# The ApiVlmOptions() allows to interface with APIs supporting
# the multi-modal chat interface. Here follow a few example on how to configure those.
# One possibility is self-hosting model, e.g. via Ollama.
# Example using the Granite Vision model: (uncomment the following lines)
pipeline_options.vlm_options = ollama_vlm_options(
model="granite3.2-vision:2b",
prompt="OCR the full page to markdown.",
)
# Another possibility is using online services, e.g. watsonx.ai.
# Using requires setting the env variables WX_API_KEY and WX_PROJECT_ID.
# Uncomment the following line for this option:
# pipeline_options.vlm_options = watsonx_vlm_options(
# model="ibm/granite-vision-3-2-2b", prompt="OCR the full page to markdown."
# )
# Create the DocumentConverter and launch the conversion.
doc_converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_options=pipeline_options,
pipeline_cls=VlmPipeline,
)
}
)
result = doc_converter.convert(input_doc_path)
print(result.document.export_to_markdown())
if __name__ == "__main__":
main()

View File

@ -75,6 +75,8 @@ nav:
- "Custom conversion": examples/custom_convert.py
- "Batch conversion": examples/batch_convert.py
- "Multi-format conversion": examples/run_with_formats.py
- "VLM pipeline with SmolDocling": examples/minimal_vlm_pipeline.py
- "VLM pipeline with remote model": examples/vlm_pipeline_api_model.py
- "Figure export": examples/export_figures.py
- "Table export": examples/export_tables.py
- "Multimodal export": examples/export_multimodal.py