Docling/docling/models/api_vlm_model.py
Vinay R Damodaran 3a04f2a367
feat: Improve parallelization for remote services API calls (#1548)
* Provide the option to make remote services call concurrent

Signed-off-by: Vinay Damodaran <vrdn@hey.com>

* Use yield from correctly?

Signed-off-by: Vinay Damodaran <vrdn@hey.com>

* not do amateur hour stuff

Signed-off-by: Vinay Damodaran <vrdn@hey.com>

---------

Signed-off-by: Vinay Damodaran <vrdn@hey.com>
2025-05-14 15:47:55 +02:00

72 lines
2.7 KiB
Python

from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import ApiVlmOptions
from docling.exceptions import OperationNotAllowed
from docling.models.base_model import BasePageModel
from docling.utils.api_image_request import api_image_request
from docling.utils.profiling import TimeRecorder
class ApiVlmModel(BasePageModel):
    """Page model that transcribes page images through a remote VLM API.

    Each page is rendered to an image, sent to ``vlm_options.url`` via
    ``api_image_request``, and the returned text is stored on
    ``page.predictions.vlm_response``. Requests for one batch are issued from
    a thread pool so that up to ``vlm_options.concurrency`` calls run at once.
    """

    def __init__(
        self,
        enabled: bool,
        enable_remote_services: bool,
        vlm_options: ApiVlmOptions,
    ):
        """Configure the model.

        Args:
            enabled: Whether the model should call the remote service. When
                False, ``__call__`` yields pages through unchanged.
            enable_remote_services: Explicit opt-in for network calls; must be
                True whenever ``enabled`` is True.
            vlm_options: Endpoint URL, headers, prompt, extra request params,
                timeout, concurrency and render scale for the remote VLM.

        Raises:
            OperationNotAllowed: If ``enabled`` is True but
                ``enable_remote_services`` is False.
        """
        self.enabled = enabled
        self.vlm_options = vlm_options
        if self.enabled:
            if not enable_remote_services:
                raise OperationNotAllowed(
                    "Connections to remote services is only allowed when set explicitly. "
                    "pipeline_options.enable_remote_services=True, or using the CLI "
                    "--enable-remote-services."
                )

            self.timeout = self.vlm_options.timeout
            self.concurrency = self.vlm_options.concurrency
            self.prompt_content = (
                f"This is a page from a document.\n{self.vlm_options.prompt}"
            )
            # "temperature" comes after the unpacked user params, so it always
            # wins: every request is sent with temperature 0.
            self.params = {
                **self.vlm_options.params,
                "temperature": 0,
            }

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Run the remote VLM on every page of ``page_batch``.

        Pages are yielded in the same order they arrive (``Executor.map``
        preserves input order), regardless of which request finishes first.
        ``Executor.map`` collects the whole ``page_batch`` immediately and
        submits one request per page; at most ``self.concurrency`` of them run
        at the same time.

        If the model is disabled, pages are yielded unchanged and no remote
        call is made.

        Raises:
            Any exception raised while processing a page (e.g. by
            ``api_image_request``) is re-raised when that page's result is
            reached in the output.
        """
        if not self.enabled:
            # timeout/concurrency/params are only initialised when enabled,
            # so a disabled model must never reach the thread pool below.
            yield from page_batch
            return

        with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
            yield from executor.map(
                lambda page: self._process_page(conv_res, page), page_batch
            )

    def _process_page(self, conv_res: ConversionResult, page: Page) -> Page:
        """Send one page image to the VLM endpoint and attach the response.

        Pages whose backend is not valid are returned untouched.
        """
        assert page._backend is not None
        if not page._backend.is_valid():
            return page

        # NOTE(review): this runs on worker threads, so TimeRecorder is entered
        # concurrently for the same conv_res — confirm that is safe.
        with TimeRecorder(conv_res, "vlm"):
            assert page.size is not None

            hi_res_image = page.get_image(scale=self.vlm_options.scale)
            assert hi_res_image is not None
            # Normalise to RGB before sending (presumably required by the image
            # encoding inside api_image_request — confirm).
            if hi_res_image.mode != "RGB":
                hi_res_image = hi_res_image.convert("RGB")

            page_tags = api_image_request(
                image=hi_res_image,
                prompt=self.prompt_content,
                url=self.vlm_options.url,
                timeout=self.timeout,
                headers=self.vlm_options.headers,
                **self.params,
            )

            page.predictions.vlm_response = VlmPrediction(text=page_tags)

        return page