Docling/docling/models/picture_description_api_model.py
Vinay R Damodaran 3a04f2a367
feat: Improve parallelization for remote services API calls (#1548)
* Provide the option to make remote services call concurrent

Signed-off-by: Vinay Damodaran <vrdn@hey.com>

* Use yield from correctly?

Signed-off-by: Vinay Damodaran <vrdn@hey.com>

* not do amateur hour stuff

Signed-off-by: Vinay Damodaran <vrdn@hey.com>

---------

Signed-off-by: Vinay Damodaran <vrdn@hey.com>
2025-05-14 15:47:55 +02:00

65 lines
2.3 KiB
Python

from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Optional, Type, Union
from PIL import Image
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
PictureDescriptionApiOptions,
PictureDescriptionBaseOptions,
)
from docling.exceptions import OperationNotAllowed
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
from docling.utils.api_image_request import api_image_request
class PictureDescriptionApiModel(PictureDescriptionBaseModel):
# elements_batch_size = 4
@classmethod
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
return PictureDescriptionApiOptions
def __init__(
self,
enabled: bool,
enable_remote_services: bool,
artifacts_path: Optional[Union[Path, str]],
options: PictureDescriptionApiOptions,
accelerator_options: AcceleratorOptions,
):
super().__init__(
enabled=enabled,
enable_remote_services=enable_remote_services,
artifacts_path=artifacts_path,
options=options,
accelerator_options=accelerator_options,
)
self.options: PictureDescriptionApiOptions
self.concurrency = self.options.concurrency
if self.enabled:
if not enable_remote_services:
raise OperationNotAllowed(
"Connections to remote services is only allowed when set explicitly. "
"pipeline_options.enable_remote_services=True."
)
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
# Note: technically we could make a batch request here,
# but not all APIs will allow for it. For example, vllm won't allow more than 1.
def _api_request(image):
return api_image_request(
image=image,
prompt=self.options.prompt,
url=self.options.url,
timeout=self.options.timeout,
headers=self.options.headers,
**self.options.params,
)
with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
yield from executor.map(_api_request, images)