feat: Improve parallelization for remote services API calls (#1548)

* Provide the option to make remote service calls concurrent

Signed-off-by: Vinay Damodaran <vrdn@hey.com>

* Use yield from correctly?

Signed-off-by: Vinay Damodaran <vrdn@hey.com>

* not do amateur hour stuff

Signed-off-by: Vinay Damodaran <vrdn@hey.com>

---------

Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Vinay R Damodaran 2025-05-14 06:47:55 -07:00 committed by GitHub
parent 9f8b479f17
commit 3a04f2a367
3 changed files with 17 additions and 5 deletions

View File

@@ -225,6 +225,7 @@ class PictureDescriptionApiOptions(PictureDescriptionBaseOptions):
     headers: Dict[str, str] = {}
     params: Dict[str, Any] = {}
     timeout: float = 20
+    concurrency: int = 1
 
     prompt: str = "Describe this image in a few sentences."
     provenance: str = ""
@@ -295,6 +296,7 @@ class ApiVlmOptions(BaseVlmOptions):
     params: Dict[str, Any] = {}
     scale: float = 2.0
     timeout: float = 60
+    concurrency: int = 1
     response_format: ResponseFormat

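Both options classes gain the same knob with a default of 1, so existing configurations keep the previous sequential behavior. As a usage illustration, here is a minimal sketch of opting in for picture description; the PdfPipelineOptions wiring is assumed from the surrounding codebase, and the endpoint URL is a placeholder:

from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
    PictureDescriptionApiOptions,
)

pipeline_options = PdfPipelineOptions(
    do_picture_description=True,
    enable_remote_services=True,  # remote calls are opt-in
    picture_description_options=PictureDescriptionApiOptions(
        url="http://localhost:8000/v1/chat/completions",  # placeholder endpoint
        prompt="Describe this image in a few sentences.",
        timeout=20,
        concurrency=4,  # new: allow up to 4 requests in flight
    ),
)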
View File

@@ -1,4 +1,5 @@
 from collections.abc import Iterable
+from concurrent.futures import ThreadPoolExecutor
 
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
@@ -27,6 +28,7 @@ class ApiVlmModel(BasePageModel):
         )
         self.timeout = self.vlm_options.timeout
+        self.concurrency = self.vlm_options.concurrency
         self.prompt_content = (
             f"This is a page from a document.\n{self.vlm_options.prompt}"
         )
@@ -38,10 +40,10 @@ class ApiVlmModel(BasePageModel):
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
-        for page in page_batch:
+        def _vlm_request(page):
             assert page._backend is not None
             if not page._backend.is_valid():
-                yield page
+                return page
             else:
                 with TimeRecorder(conv_res, "vlm"):
                     assert page.size is not None
@@ -63,4 +65,7 @@ class ApiVlmModel(BasePageModel):
                     page.predictions.vlm_response = VlmPrediction(text=page_tags)
 
-            yield page
+            return page
+
+        with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
+            yield from executor.map(_vlm_request, page_batch)

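The rewrite above follows a general pattern: hoist the body of a sequential for/yield loop into a helper that returns its item, then fan items out with ThreadPoolExecutor.map, which yields results in input order so downstream consumers see the same sequence as before. A self-contained sketch of the pattern (names are illustrative, not from the codebase):

from collections.abc import Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor


def fetch(item: int) -> int:
    # Stand-in for a blocking remote call (e.g. an HTTP request).
    return item * 2


def process(items: Iterable[int], concurrency: int = 4) -> Iterator[int]:
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        # Up to `concurrency` calls run at once, but results come back
        # in input order, preserving the contract of the original loop.
        yield from executor.map(fetch, items)


print(list(process(range(5))))  # prints [0, 2, 4, 6, 8]

One behavioral nuance: executor.map submits a task for every element up front, so page_batch is consumed eagerly rather than one page at a time, which is presumably acceptable for the bounded batches this model receives.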
View File

@@ -1,4 +1,5 @@
 from collections.abc import Iterable
+from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from typing import Optional, Type, Union
@@ -37,6 +38,7 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
             accelerator_options=accelerator_options,
         )
         self.options: PictureDescriptionApiOptions
+        self.concurrency = self.options.concurrency
 
         if self.enabled:
             if not enable_remote_services:
@@ -48,8 +50,8 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
     def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
         # Note: technically we could make a batch request here,
         # but not all APIs will allow for it. For example, vllm won't allow more than 1.
-        for image in images:
-            yield api_image_request(
+        def _api_request(image):
+            return api_image_request(
                 image=image,
                 prompt=self.options.prompt,
                 url=self.options.url,
@@ -57,3 +59,6 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
                 headers=self.options.headers,
                 **self.options.params,
             )
+
+        with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
+            yield from executor.map(_api_request, images)
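The same pattern as in ApiVlmModel applies here: _api_request wraps what was the loop body, executor.map keeps results in the order of the input images, and an exception raised inside a worker is re-raised when its result is consumed, so callers retain the iteration contract of the sequential version. With the default concurrency of 1, the pool degenerates to one request at a time, matching the old behavior.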