diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 1cd08f2..531ec8d 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -10,7 +10,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Install tesseract - run: sudo apt-get update && sudo apt-get install -y tesseract-ocr tesseract-ocr-eng tesseract-ocr-fra tesseract-ocr-deu tesseract-ocr-spa libleptonica-dev libtesseract-dev pkg-config + run: sudo apt-get update && sudo apt-get install -y tesseract-ocr tesseract-ocr-eng tesseract-ocr-fra tesseract-ocr-deu tesseract-ocr-spa tesseract-ocr-script-latn libleptonica-dev libtesseract-dev pkg-config - name: Set TESSDATA_PREFIX run: | echo "TESSDATA_PREFIX=$(dpkg -L tesseract-ocr-eng | grep tessdata$)" >> "$GITHUB_ENV" diff --git a/docling/models/tesseract_ocr_model.py b/docling/models/tesseract_ocr_model.py index b2bd358..6a1b60e 100644 --- a/docling/models/tesseract_ocr_model.py +++ b/docling/models/tesseract_ocr_model.py @@ -54,43 +54,56 @@ class TesseractOcrModel(BaseOcrModel): # Initialize the tesseractAPI _log.debug("Initializing TesserOCR: %s", tesseract_version) lang = "+".join(self.options.lang) + + self.script_readers: dict[str, tesserocr.PyTessBaseAPI] = {} + + if any([l.startswith("script/") for l in tesserocr_languages]): + self.script_prefix = "script/" + else: + self.script_prefix = "" + + tesserocr_kwargs = { + "psm": tesserocr.PSM.AUTO, + "init": True, + "oem": tesserocr.OEM.DEFAULT, + } + if self.options.path is not None: + tesserocr_kwargs["path"] = self.options.path + + if lang == "auto": self.reader = tesserocr.PyTessBaseAPI( - path=self.options.path, - lang=lang, - psm=tesserocr.PSM.AUTO, - init=True, - oem=tesserocr.OEM.DEFAULT, + **{"lang": "osd", "psm": tesserocr.PSM.OSD_ONLY} | tesserocr_kwargs ) else: self.reader = tesserocr.PyTessBaseAPI( - lang=lang, - psm=tesserocr.PSM.AUTO, - init=True, - oem=tesserocr.OEM.DEFAULT, + **{"lang": lang} | tesserocr_kwargs, ) + self.reader_RIL = tesserocr.RIL def __del__(self): if self.reader is not None: # Finalize the tesseractAPI self.reader.End() + for script in self.script_readers: + self.script_readers[script].End() def __call__( self, conv_res: ConversionResult, page_batch: Iterable[Page] ) -> Iterable[Page]: - if not self.enabled: yield from page_batch return + import tesserocr + for page in page_batch: assert page._backend is not None if not page._backend.is_valid(): yield page else: with TimeRecorder(conv_res, "ocr"): - assert self.reader is not None ocr_rects = self.get_ocr_rects(page) @@ -106,20 +119,55 @@ class TesseractOcrModel(BaseOcrModel): # Retrieve text snippets with their bounding boxes self.reader.SetImage(high_res_image) - boxes = self.reader.GetComponentImages( + + if self.options.lang == ["auto"]: + osd = self.reader.DetectOrientationScript() + + # No text, probably + if osd is None: + continue + + script = osd["script_name"] + + if script == "Katakana" or script == "Hiragana": + script = "Japanese" + elif script == "Han": + script = "HanS" + elif script == "Korean": + script = "Hangul" + + _log.debug( + f'Using model for the detected script "{script}"' + ) + + if script not in self.script_readers: + self.script_readers[script] = tesserocr.PyTessBaseAPI( + path=self.reader.GetDatapath(), + lang=f"{self.script_prefix}{script}", + psm=tesserocr.PSM.AUTO, + init=True, + oem=tesserocr.OEM.DEFAULT, + ) + + local_reader = self.script_readers[script] + local_reader.SetImage(high_res_image) + else: + local_reader = self.reader + + boxes = local_reader.GetComponentImages( self.reader_RIL.TEXTLINE, True ) cells = [] for ix, (im, box, _, _) in enumerate(boxes): # Set the area of interest. Tesseract uses Bottom-Left for the origin - self.reader.SetRectangle( + local_reader.SetRectangle( box["x"], box["y"], box["w"], box["h"] ) # Extract text within the bounding box - text = self.reader.GetUTF8Text().strip() - confidence = self.reader.MeanTextConf() + text = local_reader.GetUTF8Text().strip() + confidence = local_reader.MeanTextConf() left = box["x"] / self.scale bottom = box["y"] / self.scale right = (box["x"] + box["w"]) / self.scale diff --git a/tests/test_e2e_ocr_conversion.py b/tests/test_e2e_ocr_conversion.py index 73a943a..b3cdd31 100644 --- a/tests/test_e2e_ocr_conversion.py +++ b/tests/test_e2e_ocr_conversion.py @@ -60,6 +60,7 @@ def test_e2e_conversions(): RapidOcrOptions(), EasyOcrOptions(force_full_page_ocr=True), TesseractOcrOptions(force_full_page_ocr=True), + TesseractOcrOptions(force_full_page_ocr=True, lang=["auto"]), TesseractCliOcrOptions(force_full_page_ocr=True), RapidOcrOptions(force_full_page_ocr=True), ] @@ -70,7 +71,9 @@ def test_e2e_conversions(): engines.append(OcrMacOptions(force_full_page_ocr=True)) for ocr_options in engines: - print(f"Converting with ocr_engine: {ocr_options.kind}") + print( + f"Converting with ocr_engine: {ocr_options.kind}, language: {ocr_options.lang}" + ) converter = get_converter(ocr_options=ocr_options) for pdf_path in pdf_paths: print(f"converting {pdf_path}")