diff --git a/docling/cli/main.py b/docling/cli/main.py index 7d31221..19f77e4 100644 --- a/docling/cli/main.py +++ b/docling/cli/main.py @@ -219,6 +219,13 @@ def convert( bool, typer.Option(..., help="Enable the formula enrichment model in the pipeline."), ] = False, + enrich_picture_classes: Annotated[ + bool, + typer.Option( + ..., + help="Enable the picture classification enrichment model in the pipeline.", + ), + ] = False, artifacts_path: Annotated[ Optional[Path], typer.Option(..., help="If provided, the location of the model artifacts."), @@ -375,6 +382,7 @@ def convert( do_table_structure=True, do_code_enrichment=enrich_code, do_formula_enrichment=enrich_formula, + do_picture_classification=enrich_picture_classes, document_timeout=document_timeout, ) pipeline_options.table_structure_options.do_cell_matching = ( diff --git a/docling/models/base_model.py b/docling/models/base_model.py index 08d728c..a2bc776 100644 --- a/docling/models/base_model.py +++ b/docling/models/base_model.py @@ -6,6 +6,7 @@ from typing_extensions import TypeVar from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page from docling.datamodel.document import ConversionResult +from docling.datamodel.settings import settings class BasePageModel(ABC): @@ -21,6 +22,8 @@ EnrichElementT = TypeVar("EnrichElementT", default=NodeItem) class GenericEnrichmentModel(ABC, Generic[EnrichElementT]): + elements_batch_size: int = settings.perf.elements_batch_size + @abstractmethod def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool: pass diff --git a/docling/models/code_formula_model.py b/docling/models/code_formula_model.py index e4d5694..6648f46 100644 --- a/docling/models/code_formula_model.py +++ b/docling/models/code_formula_model.py @@ -61,6 +61,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel): Processes the given batch of elements and enriches them with predictions. """ + elements_batch_size = 5 images_scale = 1.66 # = 120 dpi, aligned with training data resolution expansion_factor = 0.03 diff --git a/docling/pipeline/base_pipeline.py b/docling/pipeline/base_pipeline.py index 89aedf8..1bf48ef 100644 --- a/docling/pipeline/base_pipeline.py +++ b/docling/pipeline/base_pipeline.py @@ -79,7 +79,7 @@ class BasePipeline(ABC): for model in self.enrichment_pipe: for element_batch in chunkify( _prepare_elements(conv_res, model), - settings.perf.elements_batch_size, + model.elements_batch_size, ): for element in model( doc=conv_res.document, element_batch=element_batch