Docling/docling/models/factories/base_factory.py
Michele Dolfi 5458a88464
ci: add coverage and ruff (#1383)
* add coverage calculation and push

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* new codecov version and usage of token

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* enable ruff formatter instead of black and isort

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* apply ruff lint fixes

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* apply ruff unsafe fixes

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add removed imports

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* round 1 of linter fixes

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* finalize linter fixes

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* Update pyproject.toml

Co-authored-by: Cesar Berrospi Ramis <75900930+ceberam@users.noreply.github.com>
Signed-off-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>

---------

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Signed-off-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>
Co-authored-by: Cesar Berrospi Ramis <75900930+ceberam@users.noreply.github.com>
2025-04-14 18:01:26 +02:00

123 lines
3.9 KiB
Python

import enum
import logging
from abc import ABCMeta
from typing import Generic, Optional, Type, TypeVar
from pluggy import PluginManager
from pydantic import BaseModel
from docling.datamodel.pipeline_options import BaseOptions
from docling.models.base_model import BaseModelWithOptions
# Type of the model classes a factory registers and instantiates; bound to
# BaseModelWithOptions so every registered class exposes get_options_type().
A = TypeVar("A", bound=BaseModelWithOptions)
# Module-level logger for plugin loading and registration messages.
logger = logging.getLogger(__name__)
class FactoryMeta(BaseModel):
    """Provenance record stored by a factory for each registered options type."""

    # The ``kind`` attribute of the registered options type.
    kind: str
    # Name of the plugin that supplied the registered class.
    plugin_name: str
    # ``__name__`` of the module the registered class was loaded from.
    module: str
class BaseFactory(Generic[A], metaclass=ABCMeta):
    """Registry mapping options types to the model classes that consume them.

    Classes are added either directly through :meth:`register` or by
    discovering setuptools entry points through :meth:`load_from_plugins`.
    Each class reports its options type via ``get_options_type()``. That
    options type is the registry key, and its ``kind`` attribute is the
    user-facing name.
    """

    default_plugin_name = "docling"

    def __init__(self, plugin_attr_name: str, plugin_name=default_plugin_name):
        """Create an empty factory.

        Args:
            plugin_attr_name: Attribute looked up on each plugin module. It
                must be a callable whose result is indexed by this same name
                to obtain the classes to register.
            plugin_name: Default pluggy project / entry-point group used by
                :meth:`load_from_plugins`.
        """
        self.plugin_name = plugin_name
        self.plugin_attr_name = plugin_attr_name
        # options type -> model class that accepts it
        self._classes: dict[Type[BaseOptions], Type[A]] = {}
        # options type -> provenance of its registration
        self._meta: dict[Type[BaseOptions], FactoryMeta] = {}

    @property
    def registered_kind(self) -> list[str]:
        """The ``kind`` of every registered options type, in registration order."""
        return [opt.kind for opt in self._classes]

    def get_enum(self) -> type[enum.Enum]:
        """Build a ``str``-valued Enum class whose members are the registered kinds.

        The class is named ``<plugin_attr_name>_enum``. Each member's name and
        value are the kind string.
        """
        return enum.Enum(
            self.plugin_attr_name + "_enum",
            names={kind: kind for kind in self.registered_kind},
            type=str,
            module=__name__,
        )

    @property
    def classes(self):
        """The live options-type -> model-class mapping (not a copy)."""
        return self._classes

    @property
    def registered_meta(self):
        """The live options-type -> :class:`FactoryMeta` mapping (not a copy)."""
        return self._meta

    def create_instance(self, options: BaseOptions, **kwargs) -> A:
        """Instantiate the model class registered for ``type(options)``.

        The lookup uses the exact type of ``options``; subclasses of a
        registered options type are not matched.

        Args:
            options: Options object; passed to the constructor as ``options=``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Raises:
            RuntimeError: If no class is registered for ``type(options)``.
        """
        # Look up before constructing, so a KeyError raised by the model's own
        # constructor propagates unchanged instead of being misreported as
        # "no class found".
        _cls = self._classes.get(type(options))
        if _cls is None:
            raise RuntimeError(self._err_msg_on_class_not_found(options.kind))
        return _cls(options=options, **kwargs)

    def create_options(self, kind: str, *args, **kwargs) -> BaseOptions:
        """Instantiate the first registered options type whose ``kind`` matches.

        Raises:
            RuntimeError: If no registered options type has this ``kind``.
        """
        for opt_cls in self._classes:
            if opt_cls.kind == kind:
                return opt_cls(*args, **kwargs)
        raise RuntimeError(self._err_msg_on_class_not_found(kind))

    def _err_msg_on_class_not_found(self, kind: str) -> str:
        """Build an error message that lists every known ``kind`` and its class."""
        msg_str = "\n".join(
            f"\t{opt.kind!r} => {cls!r}" for opt, cls in self._classes.items()
        )
        return f"No class found with the name {kind!r}, known classes are:\n{msg_str}"

    def register(self, cls: Type[A], plugin_name: str, plugin_module_name: str):
        """Register ``cls`` under the options type it reports.

        Args:
            cls: Model class; its ``get_options_type()`` result is the key.
            plugin_name: Plugin that supplied the class (recorded in metadata).
            plugin_module_name: Module the class came from (recorded in metadata).

        Raises:
            ValueError: If that options type is already registered.
        """
        opt_type = cls.get_options_type()
        if opt_type in self._classes:
            raise ValueError(
                f"{opt_type.kind!r} already registered to class {self._classes[opt_type]!r}"
            )
        self._classes[opt_type] = cls
        self._meta[opt_type] = FactoryMeta(
            kind=opt_type.kind, plugin_name=plugin_name, module=plugin_module_name
        )

    def load_from_plugins(
        self, plugin_name: Optional[str] = None, allow_external_plugins: bool = False
    ):
        """Discover entry-point plugins and register the classes they expose.

        Args:
            plugin_name: pluggy project / entry-point group to load; defaults
                to ``self.plugin_name``.
            allow_external_plugins: When False, plugins whose module name does
                not start with ``"docling."`` are skipped with a warning.
        """
        group = plugin_name or self.plugin_name
        plugin_manager = PluginManager(group)
        plugin_manager.load_setuptools_entrypoints(group)

        # Use a distinct loop variable so the ``plugin_name`` argument is not
        # silently rebound for each discovered plugin.
        for name, plugin_module in plugin_manager.list_name_plugin():
            plugin_module_name = str(plugin_module.__name__)  # type: ignore
            if not allow_external_plugins and not plugin_module_name.startswith(
                "docling."
            ):
                logger.warning(
                    "The plugin %s will not be loaded because Docling is being executed with allow_external_plugins=false.",
                    name,
                )
                continue

            attr = getattr(plugin_module, self.plugin_attr_name, None)
            if callable(attr):
                logger.info("Loading plugin %r", name)
                config = attr()
                self.process_plugin(config, name, plugin_module_name)

    def process_plugin(self, config, plugin_name: str, plugin_module_name: str):
        """Register every class listed under ``config[self.plugin_attr_name]``.

        Duplicate registrations (``ValueError`` from :meth:`register`) are
        logged as warnings and skipped rather than aborting the plugin.
        """
        for item in config[self.plugin_attr_name]:
            try:
                self.register(item, plugin_name, plugin_module_name)
            except ValueError:
                logger.warning("%r already registered", item)