diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index 75ea597..89bcfd7 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -28,7 +28,7 @@ jobs:
         run: |
           for file in docs/examples/*.py; do
             # Skip batch_convert.py
-            if [[ "$(basename "$file")" =~ ^(batch_convert|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert).py ]]; then
+            if [[ "$(basename "$file")" =~ ^(batch_convert|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api).py ]]; then
               echo "Skipping $file"
               continue
             fi
diff --git a/docling/cli/main.py b/docling/cli/main.py
index 19f77e4..e2bc0dd 100644
--- a/docling/cli/main.py
+++ b/docling/cli/main.py
@@ -226,6 +226,10 @@ def convert(
             help="Enable the picture classification enrichment model in the pipeline.",
         ),
     ] = False,
+    enrich_picture_description: Annotated[
+        bool,
+        typer.Option(..., help="Enable the picture description model in the pipeline."),
+    ] = False,
     artifacts_path: Annotated[
         Optional[Path],
         typer.Option(..., help="If provided, the location of the model artifacts."),
     ] = None,
@@ -382,6 +386,7 @@ def convert(
             do_table_structure=True,
             do_code_enrichment=enrich_code,
             do_formula_enrichment=enrich_formula,
+            do_picture_description=enrich_picture_description,
             do_picture_classification=enrich_picture_classes,
             document_timeout=document_timeout,
         )
diff --git a/docling/cli/models.py b/docling/cli/models.py
index aea498c..3b62ad6 100644
--- a/docling/cli/models.py
+++ b/docling/cli/models.py
@@ -31,6 +31,7 @@ class _AvailableModels(str, Enum):
    TABLEFORMER = "tableformer"
    CODE_FORMULA = "code_formula"
    PICTURE_CLASSIFIER = "picture_classifier"
+   SMOLVLM = "smolvlm"
    EASYOCR = "easyocr"
 
 
@@ -81,6 +82,7 @@ def download(
        with_tableformer=_AvailableModels.TABLEFORMER in to_download,
        with_code_formula=_AvailableModels.CODE_FORMULA in to_download,
        with_picture_classifier=_AvailableModels.PICTURE_CLASSIFIER in to_download,
+       with_smolvlm=_AvailableModels.SMOLVLM in to_download,
        with_easyocr=_AvailableModels.EASYOCR in to_download,
    )
diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index 14ca75b..3b6401b 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -2,9 +2,9 @@ import logging
 import os
 from enum import Enum
 from pathlib import Path
-from typing import Any, List, Literal, Optional, Union
+from typing import Annotated, Any, Dict, List, Literal, Optional, Union
 
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic import AnyUrl, BaseModel, ConfigDict, Field, model_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 _log = logging.getLogger(__name__)
@@ -184,6 +184,51 @@ class OcrMacOptions(OcrOptions):
     )
 
 
+class PictureDescriptionBaseOptions(BaseModel):
+    kind: str
+    batch_size: int = 8
+    scale: float = 2
+
+    bitmap_area_threshold: float = (
+        0.2  # percentage of the area for a bitmap to be processed with the models
+    )
+
+
+class PictureDescriptionApiOptions(PictureDescriptionBaseOptions):
+    kind: Literal["api"] = "api"
+
+    url: AnyUrl = AnyUrl("http://localhost:8000/v1/chat/completions")
+    headers: Dict[str, str] = {}
+    params: Dict[str, Any] = {}
+    timeout: float = 20
+
+    prompt: str = "Describe this image in a few sentences."
+    provenance: str = ""
+
+
+class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions):
+    kind: Literal["vlm"] = "vlm"
+
+    repo_id: str
+    prompt: str = "Describe this image in a few sentences."
+    # Config from here https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig
+    generation_config: Dict[str, Any] = dict(max_new_tokens=200, do_sample=False)
+
+    @property
+    def repo_cache_folder(self) -> str:
+        return self.repo_id.replace("/", "--")
+
+
+smolvlm_picture_description = PictureDescriptionVlmOptions(
+    repo_id="HuggingFaceTB/SmolVLM-256M-Instruct"
+)
+# phi_picture_description = PictureDescriptionVlmOptions(repo_id="microsoft/Phi-3-vision-128k-instruct")
+granite_picture_description = PictureDescriptionVlmOptions(
+    repo_id="ibm-granite/granite-vision-3.1-2b-preview",
+    prompt="What is shown in this image?",
+)
+
+
 # Define an enum for the backend options
 class PdfBackend(str, Enum):
     """Enum of valid PDF backends."""
@@ -223,6 +268,7 @@ class PdfPipelineOptions(PipelineOptions):
     do_code_enrichment: bool = False  # True: perform code OCR
     do_formula_enrichment: bool = False  # True: perform formula OCR, return Latex code
     do_picture_classification: bool = False  # True: classify pictures in documents
+    do_picture_description: bool = False  # True: describe pictures in documents
     table_structure_options: TableStructureOptions = TableStructureOptions()
 
     ocr_options: Union[
@@ -232,6 +278,10 @@ class PdfPipelineOptions(PipelineOptions):
         OcrMacOptions,
         RapidOcrOptions,
     ] = Field(EasyOcrOptions(), discriminator="kind")
+    picture_description_options: Annotated[
+        Union[PictureDescriptionApiOptions, PictureDescriptionVlmOptions],
+        Field(discriminator="kind"),
+    ] = smolvlm_picture_description
 
     images_scale: float = 1.0
     generate_page_images: bool = False
diff --git a/docling/models/base_model.py b/docling/models/base_model.py
index a2bc776..9cdc0ec 100644
--- a/docling/models/base_model.py
+++ b/docling/models/base_model.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Generic, Iterable, Optional
 
-from docling_core.types.doc import BoundingBox, DoclingDocument, NodeItem, TextItem
+from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
 from typing_extensions import TypeVar
 
 from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
@@ -64,7 +64,7 @@ class BaseItemAndImageEnrichmentModel(
         if not self.is_processable(doc=conv_res.document, element=element):
             return None
 
-        assert isinstance(element, TextItem)
+        assert isinstance(element, DocItem)
         element_prov = element.prov[0]
 
         bbox = element_prov.bbox
diff --git a/docling/models/picture_description_api_model.py b/docling/models/picture_description_api_model.py
new file mode 100644
index 0000000..6c7e02f
--- /dev/null
+++ b/docling/models/picture_description_api_model.py
@@ -0,0 +1,105 @@
+import base64
+import io
+import logging
+from typing import Iterable, List, Optional
+
+import httpx
+from docling_core.types.doc import PictureItem
+from docling_core.types.doc.document import (  # TODO: move import to docling_core.types.doc
+    PictureDescriptionData,
+)
+from PIL import Image
+from pydantic import BaseModel, ConfigDict
+
+from docling.datamodel.pipeline_options import PictureDescriptionApiOptions
+from docling.models.picture_description_base_model import PictureDescriptionBaseModel
+
+_log = logging.getLogger(__name__)
+
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class ResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: str
+
+
+class ResponseUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class ApiResponse(BaseModel):
+    model_config = ConfigDict(
+        protected_namespaces=(),
+    )
+
+    id: str
+    model: Optional[str] = None  # returned by openai
+    choices: List[ResponseChoice]
+    created: int
+    usage: ResponseUsage
+
+
+class PictureDescriptionApiModel(PictureDescriptionBaseModel):
+    # elements_batch_size = 4
+
+    def __init__(self, enabled: bool, options: PictureDescriptionApiOptions):
+        super().__init__(enabled=enabled, options=options)
+        self.options: PictureDescriptionApiOptions
+
+        if self.enabled:
+            if options.url.host != "localhost":
+                raise NotImplementedError(
+                    "The options try to connect to remote APIs which are not yet allowed."
+                )
+
+    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
+        # Note: technically we could make a batch request here,
+        # but not all APIs will allow for it. For example, vllm won't allow more than 1.
+        for image in images:
+            img_io = io.BytesIO()
+            image.save(img_io, "PNG")
+            image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
+
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": self.options.prompt,
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/png;base64,{image_base64}"
+                            },
+                        },
+                    ],
+                }
+            ]
+
+            payload = {
+                "messages": messages,
+                **self.options.params,
+            }
+
+            r = httpx.post(
+                str(self.options.url),
+                headers=self.options.headers,
+                json=payload,
+                timeout=self.options.timeout,
+            )
+            if not r.is_success:
+                _log.error(f"Error calling the API. Response was {r.text}")
+            r.raise_for_status()
+
+            api_resp = ApiResponse.model_validate_json(r.text)
+            generated_text = api_resp.choices[0].message.content.strip()
+            yield generated_text
diff --git a/docling/models/picture_description_base_model.py b/docling/models/picture_description_base_model.py
new file mode 100644
index 0000000..b653e0e
--- /dev/null
+++ b/docling/models/picture_description_base_model.py
@@ -0,0 +1,64 @@
+import logging
+from pathlib import Path
+from typing import Any, Iterable, List, Optional, Union
+
+from docling_core.types.doc import (
+    DoclingDocument,
+    NodeItem,
+    PictureClassificationClass,
+    PictureItem,
+)
+from docling_core.types.doc.document import (  # TODO: move import to docling_core.types.doc
+    PictureDescriptionData,
+)
+from PIL import Image
+
+from docling.datamodel.pipeline_options import PictureDescriptionBaseOptions
+from docling.models.base_model import (
+    BaseItemAndImageEnrichmentModel,
+    ItemAndImageEnrichmentElement,
+)
+
+
+class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel):
+    images_scale: float = 2.0
+
+    def __init__(
+        self,
+        enabled: bool,
+        options: PictureDescriptionBaseOptions,
+    ):
+        self.enabled = enabled
+        self.options = options
+        self.provenance = "not-implemented"
+
+    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
+        return self.enabled and isinstance(element, PictureItem)
+
+    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
+        raise NotImplementedError
+
+    def __call__(
+        self,
+        doc: DoclingDocument,
+        element_batch: Iterable[ItemAndImageEnrichmentElement],
+    ) -> Iterable[NodeItem]:
+        if not self.enabled:
+            for element in element_batch:
+                yield element.item
+            return
+
+        images: List[Image.Image] = []
+        elements: List[PictureItem] = []
+        for el in element_batch:
+            assert isinstance(el.item, PictureItem)
+            elements.append(el.item)
+            images.append(el.image)
+
+        outputs = self._annotate_images(images)
+
+        for item, output in zip(elements, outputs):
+            item.annotations.append(
+                PictureDescriptionData(text=output, provenance=self.provenance)
+            )
+            yield item
diff --git a/docling/models/picture_description_vlm_model.py b/docling/models/picture_description_vlm_model.py
new file mode 100644
index 0000000..9fa4826
--- /dev/null
+++ b/docling/models/picture_description_vlm_model.py
@@ -0,0 +1,109 @@
+from pathlib import Path
+from typing import Iterable, Optional, Union
+
+from PIL import Image
+
+from docling.datamodel.pipeline_options import (
+    AcceleratorOptions,
+    PictureDescriptionVlmOptions,
+)
+from docling.models.picture_description_base_model import PictureDescriptionBaseModel
+from docling.utils.accelerator_utils import decide_device
+
+
+class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
+
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Union[Path, str]],
+        options: PictureDescriptionVlmOptions,
+        accelerator_options: AcceleratorOptions,
+    ):
+        super().__init__(enabled=enabled, options=options)
+        self.options: PictureDescriptionVlmOptions
+
+        if self.enabled:
+
+            if artifacts_path is None:
+                artifacts_path = self.download_models(repo_id=self.options.repo_id)
+            else:
+                artifacts_path = Path(artifacts_path) / self.options.repo_cache_folder
+
+            self.device = decide_device(accelerator_options.device)
+
+            try:
+                import torch
+                from transformers import AutoModelForVision2Seq, AutoProcessor
+            except ImportError:
+                raise ImportError(
+                    "transformers >=4.46 is not installed. Please install Docling with the required extras `pip install docling[vlm]`."
+                )
+
+            # Initialize processor and model from the resolved artifacts path
+            self.processor = AutoProcessor.from_pretrained(artifacts_path)
+            self.model = AutoModelForVision2Seq.from_pretrained(
+                artifacts_path,
+                torch_dtype=torch.bfloat16,
+                _attn_implementation=(
+                    "flash_attention_2" if self.device.startswith("cuda") else "eager"
+                ),
+            ).to(self.device)
+
+            self.provenance = f"{self.options.repo_id}"
+
+    @staticmethod
+    def download_models(
+        repo_id: str,
+        local_dir: Optional[Path] = None,
+        force: bool = False,
+        progress: bool = False,
+    ) -> Path:
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import disable_progress_bars
+
+        if not progress:
+            disable_progress_bars()
+        download_path = snapshot_download(
+            repo_id=repo_id,
+            force_download=force,
+            local_dir=local_dir,
+        )
+
+        return Path(download_path)
+
+    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
+        from transformers import GenerationConfig
+
+        # Create input messages
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": self.options.prompt},
+                ],
+            },
+        ]
+
+        # TODO: do batch generation
+
+        for image in images:
+            # Prepare inputs
+            prompt = self.processor.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
+            inputs = inputs.to(self.device)
+
+            # Generate outputs
+            generated_ids = self.model.generate(
+                **inputs,
+                generation_config=GenerationConfig(**self.options.generation_config),
+            )
+            generated_texts = self.processor.batch_decode(
+                generated_ids[:, inputs["input_ids"].shape[1] :],
+                skip_special_tokens=True,
+            )
+
+            yield generated_texts[0].strip()
diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py
index 4e66415..13e435f 100644
--- a/docling/pipeline/standard_pdf_pipeline.py
+++ b/docling/pipeline/standard_pdf_pipeline.py
@@ -14,6 +14,8 @@ from docling.datamodel.pipeline_options import (
     EasyOcrOptions,
     OcrMacOptions,
     PdfPipelineOptions,
+    PictureDescriptionApiOptions,
+    PictureDescriptionVlmOptions,
     RapidOcrOptions,
     TesseractCliOcrOptions,
     TesseractOcrOptions,
@@ -34,6 +36,9 @@ from docling.models.page_preprocessing_model import (
     PagePreprocessingModel,
     PagePreprocessingOptions,
 )
+from docling.models.picture_description_api_model import PictureDescriptionApiModel
+from docling.models.picture_description_base_model import PictureDescriptionBaseModel
+from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel
 from docling.models.rapid_ocr_model import RapidOcrModel
 from docling.models.table_structure_model import TableStructureModel
 from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
@@ -95,8 +100,17 @@ class StandardPdfPipeline(PaginatedPipeline):
             PageAssembleModel(options=PageAssembleOptions()),
         ]
 
+        # Picture description model
+        if (
+            picture_description_model := self.get_picture_description_model(
+                artifacts_path=artifacts_path
+            )
+        ) is None:
+            raise RuntimeError(
+                f"The specified picture description kind is not supported: {pipeline_options.picture_description_options.kind}."
+            )
+
         self.enrichment_pipe = [
-            # Other models working on `NodeItem` elements in the DoclingDocument
             # Code Formula Enrichment Model
             CodeFormulaModel(
                 enabled=pipeline_options.do_code_enrichment
@@ -115,11 +129,14 @@ class StandardPdfPipeline(PaginatedPipeline):
                 options=DocumentPictureClassifierOptions(),
                 accelerator_options=pipeline_options.accelerator_options,
             ),
+            # Document Picture description
+            picture_description_model,
         ]
 
         if (
             self.pipeline_options.do_formula_enrichment
             or self.pipeline_options.do_code_enrichment
+            or self.pipeline_options.do_picture_description
         ):
             self.keep_backend = True
 
@@ -175,6 +192,29 @@ class StandardPdfPipeline(PaginatedPipeline):
             )
         return None
 
+    def get_picture_description_model(
+        self, artifacts_path: Optional[Path] = None
+    ) -> Optional[PictureDescriptionBaseModel]:
+        if isinstance(
+            self.pipeline_options.picture_description_options,
+            PictureDescriptionApiOptions,
+        ):
+            return PictureDescriptionApiModel(
+                enabled=self.pipeline_options.do_picture_description,
+                options=self.pipeline_options.picture_description_options,
+            )
+        elif isinstance(
+            self.pipeline_options.picture_description_options,
+            PictureDescriptionVlmOptions,
+        ):
+            return PictureDescriptionVlmModel(
+                enabled=self.pipeline_options.do_picture_description,
+                artifacts_path=artifacts_path,
+                options=self.pipeline_options.picture_description_options,
+                accelerator_options=self.pipeline_options.accelerator_options,
+            )
+        return None
+
     def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
         with TimeRecorder(conv_res, "page_init"):
             page._backend = conv_res.input._backend.load_page(page.page_no)  # type: ignore
diff --git a/docling/utils/model_downloader.py b/docling/utils/model_downloader.py
index 504618e..7d22b77 100644
--- a/docling/utils/model_downloader.py
+++ b/docling/utils/model_downloader.py
@@ -2,11 +2,13 @@ import logging
 from pathlib import Path
 from typing import Optional
 
+from docling.datamodel.pipeline_options import smolvlm_picture_description
 from docling.datamodel.settings import settings
 from docling.models.code_formula_model import CodeFormulaModel
 from docling.models.document_picture_classifier import DocumentPictureClassifier
 from docling.models.easyocr_model import EasyOcrModel
 from docling.models.layout_model import LayoutModel
+from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel
 from docling.models.table_structure_model import TableStructureModel
 
 _log = logging.getLogger(__name__)
@@ -21,6 +23,7 @@ def download_models(
     with_tableformer: bool = True,
     with_code_formula: bool = True,
     with_picture_classifier: bool = True,
+    with_smolvlm: bool = True,
     with_easyocr: bool = True,
 ):
     if output_dir is None:
@@ -61,6 +64,15 @@ def download_models(
             progress=progress,
         )
 
+    if with_smolvlm:
+        _log.info(f"Downloading SmolVLM model...")
+        PictureDescriptionVlmModel.download_models(
+            repo_id=smolvlm_picture_description.repo_id,
+            local_dir=output_dir / smolvlm_picture_description.repo_cache_folder,
+            force=force,
+            progress=progress,
+        )
+
     if with_easyocr:
         _log.info(f"Downloading easyocr models...")
         EasyOcrModel.download_models(
diff --git a/docs/examples/pictures_description.py b/docs/examples/pictures_description.py
new file mode 100644
index 0000000..f60ac29
--- /dev/null
+++ b/docs/examples/pictures_description.py
@@ -0,0 +1,48 @@
+import logging
+from pathlib import Path
+
+from docling_core.types.doc import PictureItem
+
+from docling.datamodel.base_models import InputFormat
+from docling.datamodel.pipeline_options import (
+    PdfPipelineOptions,
+    granite_picture_description,
+    smolvlm_picture_description,
+)
+from docling.document_converter import DocumentConverter, PdfFormatOption
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+
+    input_doc_path = Path("./tests/data/pdf/2206.01062.pdf")
+
+    pipeline_options = PdfPipelineOptions()
+    pipeline_options.do_picture_description = True
+    pipeline_options.picture_description_options = smolvlm_picture_description
+    # pipeline_options.picture_description_options = granite_picture_description
+
+    pipeline_options.picture_description_options.prompt = (
+        "Describe the image in three sentences. Be concise and accurate."
+    )
+
+    doc_converter = DocumentConverter(
+        format_options={
+            InputFormat.PDF: PdfFormatOption(
+                pipeline_options=pipeline_options,
+            )
+        }
+    )
+    result = doc_converter.convert(input_doc_path)
+
+    for element, _level in result.document.iterate_items():
+        if isinstance(element, PictureItem):
+            print(
+                f"Picture {element.self_ref}\n"
+                f"Caption: {element.caption_text(doc=result.document)}\n"
+                f"Annotations: {element.annotations}"
+            )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/docs/examples/pictures_description_api.py b/docs/examples/pictures_description_api.py
new file mode 100644
index 0000000..3da37ed
--- /dev/null
+++ b/docs/examples/pictures_description_api.py
@@ -0,0 +1,55 @@
+import logging
+from pathlib import Path
+
+from docling_core.types.doc import PictureItem
+
+from docling.datamodel.base_models import InputFormat
+from docling.datamodel.pipeline_options import (
+    PdfPipelineOptions,
+    PictureDescriptionApiOptions,
+)
+from docling.document_converter import DocumentConverter, PdfFormatOption
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+
+    input_doc_path = Path("./tests/data/pdf/2206.01062.pdf")
+
+    # This example uses a local API server for picture description.
+    # For example, you can launch it locally with:
+    # $ vllm serve "HuggingFaceTB/SmolVLM-256M-Instruct"
+
+    pipeline_options = PdfPipelineOptions()
+    pipeline_options.do_picture_description = True
+    pipeline_options.picture_description_options = PictureDescriptionApiOptions(
+        url="http://localhost:8000/v1/chat/completions",
+        params=dict(
+            model="HuggingFaceTB/SmolVLM-256M-Instruct",
+            seed=42,
+            max_completion_tokens=200,
+        ),
+        prompt="Describe the image in three sentences. Be concise and accurate.",
+        timeout=90,
+    )
+
+    doc_converter = DocumentConverter(
+        format_options={
+            InputFormat.PDF: PdfFormatOption(
+                pipeline_options=pipeline_options,
+            )
+        }
+    )
+    result = doc_converter.convert(input_doc_path)
+
+    for element, _level in result.document.iterate_items():
+        if isinstance(element, PictureItem):
+            print(
+                f"Picture {element.self_ref}\n"
+                f"Caption: {element.caption_text(doc=result.document)}\n"
+                f"Annotations: {element.annotations}"
+            )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/poetry.lock b/poetry.lock
index b261db4..691dd84 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2727,13 +2727,13 @@ pygments = ">2.12.0"
 
 [[package]]
 name = "mkdocs-material"
-version = "9.6.2"
+version = "9.6.3"
 description = "Documentation that simply works"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "mkdocs_material-9.6.2-py3-none-any.whl", hash = "sha256:71d90dbd63b393ad11a4d90151dfe3dcbfcd802c0f29ce80bebd9bbac6abc753"},
-    {file = "mkdocs_material-9.6.2.tar.gz", hash = "sha256:a3de1c5d4c745f10afa78b1a02f917b9dce0808fb206adc0f5bb48b58c1ca21f"},
+    {file = "mkdocs_material-9.6.3-py3-none-any.whl", hash = "sha256:1125622067e26940806701219303b27c0933e04533560725d97ec26fd16a39cf"},
+    {file = "mkdocs_material-9.6.3.tar.gz", hash = "sha256:c87f7d1c39ce6326da5e10e232aed51bae46252e646755900f4b0fc9192fa832"},
 ]
 
 [package.dependencies]
@@ -7846,8 +7846,9 @@ type = ["pytest-mypy"]
 ocrmac = ["ocrmac"]
 rapidocr = ["onnxruntime", "onnxruntime", "rapidocr-onnxruntime"]
 tesserocr = ["tesserocr"]
+vlm = ["transformers", "transformers"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "ca0464df452664834ae9bccc59f89240e2f5e8f3b179761de615548c799680e7"
+content-hash = "86d266adc6272f3db65ab07f5cce35cbe9626368dc0e09ab374c861f0809f693"
diff --git a/pyproject.toml b/pyproject.toml
index 3bc88b0..9b1b5e9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,6 +59,10 @@ onnxruntime = [
     { version = ">=1.7.0,<1.20.0", optional = true, markers = "python_version < '3.10'" },
     { version = "^1.7.0", optional = true, markers = "python_version >= '3.10'" }
 ]
+transformers = [
+    {markers = "sys_platform != 'darwin' or platform_machine != 'x86_64'", version = "^4.46.0", optional = true },
+    {markers = "sys_platform == 'darwin' and platform_machine == 'x86_64'", version = "~4.42.0", optional = true }
+]
 pillow = "^10.0.0"
 tqdm = "^4.65.0"
@@ -121,6 +125,7 @@ torchvision = [
 ]
 
 [tool.poetry.extras]
 tesserocr = ["tesserocr"]
 ocrmac = ["ocrmac"]
+vlm = ["transformers"]
 rapidocr = ["rapidocr-onnxruntime", "onnxruntime"]
 
 [tool.poetry.scripts]
@@ -162,7 +167,8 @@ module = [
     "deepsearch_glm.*",
     "lxml.*",
     "bs4.*",
-    "huggingface_hub.*"
+    "huggingface_hub.*",
+    "transformers.*",
 ]
 ignore_missing_imports = true
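For reference, the two bundled presets (`smolvlm_picture_description`, `granite_picture_description`) are not the only way to use the VLM path: `PictureDescriptionVlmOptions` accepts any Hugging Face repo id loadable via `AutoModelForVision2Seq`, with `prompt` inserted into the model's chat template and `generation_config` forwarded to transformers' `GenerationConfig`. A minimal sketch, assuming the `docling[vlm]` extra is installed (the repo id below is the SmolVLM default from this patch; any compatible checkpoint could be substituted):

```python
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
    PictureDescriptionVlmOptions,
)
from docling.document_converter import DocumentConverter, PdfFormatOption

# Enable the picture description enrichment and point it at a VLM checkpoint.
pipeline_options = PdfPipelineOptions()
pipeline_options.do_picture_description = True
pipeline_options.picture_description_options = PictureDescriptionVlmOptions(
    repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",  # any Vision2Seq-compatible repo id
    prompt="Summarize this figure in one sentence.",
    generation_config=dict(max_new_tokens=120, do_sample=False),
)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
    }
)
result = converter.convert("./tests/data/pdf/2206.01062.pdf")  # sample used by the examples above

# Each described picture carries a PictureDescriptionData annotation,
# with the repo id recorded as its provenance.
```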