[init] initial commit

This commit is contained in:
fenghao.2019 2025-05-26 23:20:51 +08:00
commit 49f51871c6
31 changed files with 2757 additions and 0 deletions

154
.gitignore vendored Normal file

@@ -0,0 +1,154 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
coverage.xml
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
.idea/
*.iml
# VS Code
.vscode/
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
# macOS
.DS_Store
# Windows
Thumbs.db
ehthumbs.db
Desktop.ini
fusion_result.json
kernel_meta/

34
.pre-commit-config.yaml Normal file

@@ -0,0 +1,34 @@
repos:
# 1. isort - automatically sort Python imports
- repo: https://github.com/pycqa/isort
rev: 6.0.1 # pin to a fixed version
hooks:
- id: isort
name: isort (python)
args: [--profile=black] # profile compatible with Black
language: python
# 2. Black - automatically format Python code
- repo: https://github.com/psf/black
rev: 25.1.0 # pin to a fixed version
hooks:
- id: black
language: python
# 3. flake8 - Python static analysis
- repo: https://github.com/pycqa/flake8
rev: 7.2.0
hooks:
- id: flake8
args: [--max-line-length=120, --ignore=E203] # set max line length to 120
additional_dependencies: [flake8-bugbear==24.12.12] # optional: extra checks
# 4. pre-commit-hooks - general-purpose Git hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace # remove trailing whitespace
- id: end-of-file-fixer # ensure files end with a newline
- id: check-yaml # validate YAML file syntax
- id: check-added-large-files # block large files from being committed
args: ["--maxkb=512"]

9
LICENSE Normal file

@@ -0,0 +1,9 @@
MIT License
Copyright 2025 ByteDance Ltd. and/or its affiliates
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

188
README.md Normal file

@@ -0,0 +1,188 @@
<div align="center">
<img src="./assets/dolphin.png" width="300">
</div>
<div align="center">
<a href="https://arxiv.org/abs/2505.14059">
<img src="https://img.shields.io/badge/Paper-arXiv-red">
</a>
<a href="https://huggingface.co/ByteDance/Dolphin">
<img src="https://img.shields.io/badge/HuggingFace-Dolphin-yellow">
</a>
<a href="http://115.190.42.15:8888/dolphin/">
<img src="https://img.shields.io/badge/Demo-Dolphin-blue">
</a>
<a href="https://github.com/bytedance/Dolphin">
<img src="https://img.shields.io/badge/Code-Github-green">
</a>
<a href="https://opensource.org/licenses/MIT">
<img src="https://img.shields.io/badge/License-MIT-lightgray">
</a>
<br>
</div>
<br>
<div align="center">
<img src="./assets/demo.gif" width="800">
</div>
# Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting
Dolphin (**Do**cument Image **P**arsing via **H**eterogeneous Anchor Prompt**in**g) is a novel multimodal document image parsing model following an analyze-then-parse paradigm. This repository contains the demo code and pre-trained models for Dolphin.
## 📑 Overview
Document image parsing is challenging due to its complexly intertwined elements such as text paragraphs, figures, formulas, and tables. Dolphin addresses these challenges through a two-stage approach:
1. **🔍 Stage 1**: Comprehensive page-level layout analysis by generating element sequence in natural reading order
2. **🧩 Stage 2**: Efficient parallel parsing of document elements using heterogeneous anchors and task-specific prompts
<div align="center">
<img src="./assets/framework.png" width="680">
</div>
Dolphin achieves promising performance across diverse page-level and element-level parsing tasks while ensuring superior efficiency through its lightweight architecture and parallel parsing mechanism.
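As a rough, illustrative sketch of this two-stage flow, the Python snippet below mirrors the Hugging Face demo script (`demo_page_hf.py`) in this repository. It is a simplified outline, not the official API: the paths `./hf_model` and `./demo/page_imgs/page_1.jpeg` follow the instructions further down, and the Stage-2 step of cropping each predicted layout box before re-prompting is elided for brevity.
```python
import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel

model_path = "./hf_model"  # downloaded as described in the Installation section
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(model_path)
model = VisionEncoderDecoderModel.from_pretrained(model_path).eval().to(device)
if device == "cuda":
    model = model.half()  # the demo scripts run in fp16 on GPU
tokenizer = processor.tokenizer

def generate(prompt: str, image: Image.Image) -> str:
    """Run a single prompt against a single image and return the decoded answer."""
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(model.dtype).to(device)
    prompt_ids = tokenizer(f"<s>{prompt} <Answer/>", add_special_tokens=False,
                           return_tensors="pt").input_ids.to(device)
    output_ids = model.generate(
        pixel_values=pixel_values,
        decoder_input_ids=prompt_ids,
        decoder_attention_mask=torch.ones_like(prompt_ids),
        max_length=4096,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        do_sample=False,
        num_beams=1,
    )
    text = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0]
    for token in (f"<s>{prompt} <Answer/>", "<pad>", "</s>"):
        text = text.replace(token, "")
    return text.strip()

page = Image.open("./demo/page_imgs/page_1.jpeg").convert("RGB")

# Stage 1: page-level layout analysis in natural reading order
layout = generate("Parse the reading order of this document.", page)

# Stage 2: parse each element; the demo scripts crop each predicted box and
# batch the crops, here we simply re-prompt on the full page for illustration
text_content = generate("Read text in the image.", page)
print(layout)
print(text_content)
```
In the actual demo scripts, Stage 2 crops every layout box and decodes the crops in parallel batches, which is where the efficiency gain of the heterogeneous-anchor design comes from.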
## 🚀 Demo
Try our demo on [Demo-Dolphin](http://115.190.42.15:8888/dolphin/).
## 📅 Changelog
- 🔥 **2025.05.21** Our demo is released at [link](http://115.190.42.15:8888/dolphin/). Check it out!
- 🔥 **2025.05.20** The pretrained model and inference code of Dolphin are released.
- 🔥 **2025.05.16** Our paper has been accepted by ACL 2025. Paper link: [arXiv](https://arxiv.org/abs/2505.14059).
## 🛠️ Installation
1. Clone the repository:
```bash
git clone https://github.com/ByteDance/Dolphin.git
cd Dolphin
```
2. Install the dependencies:
```bash
pip install -r requirements.txt
```
3. Download the pre-trained models using one of the following options:
**Option A: Original Model Format (config-based)**
Download from [Baidu Yun](https://pan.baidu.com/s/15zcARoX0CTOHKbW8bFZovQ?pwd=9rpx) or [Google Drive](https://drive.google.com/drive/folders/1PQJ3UutepXvunizZEw-uGaQ0BCzf-mie?usp=sharing) and put them in the `./checkpoints` folder.
**Option B: Hugging Face Model Format**
Visit our Hugging Face [model card](https://huggingface.co/ByteDance/Dolphin), or download the model with:
```bash
# Download the model from Hugging Face Hub
git lfs install
git clone https://huggingface.co/ByteDance/Dolphin ./hf_model
# Or use the Hugging Face CLI
huggingface-cli download ByteDance/Dolphin --local-dir ./hf_model
```
## ⚡ Inference
Dolphin provides two inference frameworks with support for two parsing granularities:
- **Page-level Parsing**: Parse the entire document image into a structured JSON and Markdown format
- **Element-level Parsing**: Parse individual document elements (text, table, formula)
### 📄 Page-level Parsing
#### Using Original Framework (config-based)
```bash
# Process a single document image
python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results
# Process all document images in a directory
python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results
# Process with custom batch size for parallel element decoding
python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 8
```
#### Using Hugging Face Framework
```bash
# Process a single document image
python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results
# Process all document images in a directory
python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results
# Process with custom batch size for parallel element decoding
python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 16
```
### 🧩 Element-level Parsing
#### Using Original Framework (config-based)
```bash
# Process a single table image
python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/table_1.jpeg --element_type table
# Process a single formula image
python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula
# Process a single text paragraph image
python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/para_1.jpg --element_type text
```
#### Using Hugging Face Framework
```bash
# Process a single table image
python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/table_1.jpeg --element_type table
# Process a single formula image
python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula
# Process a single text paragraph image
python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/para_1.jpg --element_type text
```
## 🌟 Key Features
- 🔄 Two-stage analyze-then-parse approach based on a single VLM
- 📊 Promising performance on document parsing tasks
- 🔍 Natural reading order element sequence generation
- 🧩 Heterogeneous anchor prompting for different document elements
- ⏱️ Efficient parallel parsing mechanism
- 🤗 Support for Hugging Face Transformers for easier integration
## 📮 Notice
**Call for Bad Cases:** If you encounter cases where the model performs poorly, we would greatly appreciate it if you could share them in an issue. We are continuously working to optimize and improve the model.
## 💖 Acknowledgement
We would like to acknowledge the following open-source projects that provided inspiration and reference for this work:
- [Donut](https://github.com/clovaai/donut/)
- [Nougat](https://github.com/facebookresearch/nougat)
- [GOT](https://github.com/Ucas-HaoranWei/GOT-OCR2.0)
- [MinerU](https://github.com/opendatalab/MinerU/tree/master)
- [Swin](https://github.com/microsoft/Swin-Transformer)
- [Hugging Face Transformers](https://github.com/huggingface/transformers)
## 📝 Citation
If you find this code useful for your research, please use the following BibTeX entry.
```bibtex
@inproceedings{dolphin2025,
title={Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting},
author={Feng, Hao and Wei, Shu and Fei, Xiang and Shi, Wei and Han, Yingdong and Liao, Lei and Lu, Jinghui and Wu, Binghong and Liu, Qi and Lin, Chunhui and Tang, Jingqun and Liu, Hao and Huang, Can},
year={2025},
booktitle={Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (ACL)}
}
```
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=bytedance/Dolphin&type=Date)](https://www.star-history.com/#bytedance/Dolphin&Date)

BIN
assets/demo.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 MiB

BIN
assets/dolphin.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 81 KiB

BIN
assets/framework.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 MiB

197
chat.py Normal file

@@ -0,0 +1,197 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import os
import warnings
from collections import OrderedDict
from omegaconf import ListConfig
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")
import torch
from PIL import Image
from transformers import PreTrainedTokenizerFast
from utils.model import DonutConfig, DonutModel, SwinEncoder
from utils.processor import DolphinProcessor
def try_rename_lagacy_weights(ckpt, output_path=""):
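"""Remap legacy checkpoint keys so they load into DonutModel: strip a leading "model." prefix,
rename "encoder.*" -> "vpm.*" and "decoder.*" -> "llm.*", and optionally save the converted
state dict to output_path."""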
if "state_dict" in ckpt.keys():
ckpt = ckpt["state_dict"]
if "module" in ckpt.keys():
ckpt = ckpt["module"]
new_ckpt = OrderedDict()
for k, v in ckpt.items():
if k.startswith("model."):
k = k[len("model.") :]
if k.startswith("encoder"):
new_ckpt["vpm" + k[len("encoder") :]] = v
elif k.startswith("decoder"):
new_ckpt["llm" + k[len("encoder") :]] = v
else:
new_ckpt[k] = v
if output_path:
torch.save(new_ckpt, output_path)
return new_ckpt
def convert_listconfig_to_list(config):
new_config = {}
for k, v in config.items():
if isinstance(v, ListConfig):
new_config[k] = list(v)
else:
new_config[k] = v
return new_config
class DOLPHIN:
def __init__(self, config, ckpt_path="") -> None:
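"""Build the Dolphin model from a config: construct the Swin vision encoder and tokenizer,
assemble the Donut-style encoder-decoder, load the checkpoint weights, and set up the processor."""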
self.model_args = config.model
self.swin_args = config.model.pop("swin_args")
self.swin_args = convert_listconfig_to_list(self.swin_args)
vision_tower = SwinEncoder(
input_size=self.swin_args["img_size"],
patch_size=self.swin_args["patch_size"],
embed_dim=self.swin_args["embed_dim"],
window_size=self.swin_args["window_size"],
encoder_layer=self.swin_args["encoder_layer"],
num_heads=self.swin_args["num_heads"],
align_long_axis=self.swin_args["align_long_axis"],
)
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.model_args.tokenizer_path)
self.tokenizer.pad_token = "<pad>"
self.tokenizer.bos_token = "<s>"
self.tokenizer.eos_token = "</s>"
self.tokenizer.unk_token = "<unk>"
if self.model_args.get("extra_answer_tokens", False):
# print("Allowing multitask training: adding <Answer/> to the tokenizer.")
prompt_end_token = " <Answer/>"
self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set([prompt_end_token]))})
self.tokenizer._prompt_end_token = prompt_end_token
self.tokenizer._prompt_end_token_id = self.tokenizer.convert_tokens_to_ids(prompt_end_token)
donut_config = DonutConfig(
decoder_layer=self.model_args.decoder_layer,
max_length=self.model_args.max_length,
max_position_embeddings=self.model_args.max_position_embeddings,
hidden_dimension=self.model_args.hidden_dimension,
)
self.model = DonutModel(config=donut_config, vision_tower=vision_tower, tokenizer=self.tokenizer)
if self.model_args.model_name_or_path:
ckpt = torch.load(self.model_args.model_name_or_path)
ckpt = try_rename_lagacy_weights(ckpt)
self.model.load_state_dict(ckpt, strict=True)
self.model.to("cuda")
self.model.eval()
transform_args = {
"input_size": self.swin_args["img_size"],
"max_length": self.model_args.max_length,
}
self.processor = DolphinProcessor({}, self.tokenizer, transform_args=transform_args)
def chat(
self,
question,
image,
return_raw=False,
return_score=False,
return_img_size=False,
only_return_img_size=False,
max_batch_size=16,
):
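"""Run inference on a single (question, image) pair or on lists of them.
question: prompt string or list of prompt strings
image: PIL image or image path, or a list of them (decoded in chunks of max_batch_size)
Returns the post-processed text; return_raw returns the raw output dict, return_score adds
confidence scores, return_img_size adds the original image size, and only_return_img_size
returns just the size."""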
def _preprocess_image(image):
if isinstance(image, str):
image = Image.open(image).convert("RGB")
if return_img_size or only_return_img_size:
image_tensor, ori_size = self.processor.process_image_for_inference(image, return_img_size=True)
else:
image_tensor = self.processor.process_image_for_inference(image, return_img_size=False)
ori_size = None
return image_tensor, ori_size
def _preprocess_prompt(question):
if self.model_args.get("extra_answer_tokens", False):
if self.tokenizer._prompt_end_token not in question:
question = question + self.tokenizer._prompt_end_token
prompt_ids = self.processor.process_prompt_for_inference(question)
return prompt_ids
def _preprocess_prompt_batch(question):
if self.model_args.get("extra_answer_tokens", False):
for i in range(len(question)):
if self.tokenizer._prompt_end_token not in question[i]:
question[i] = question[i] + self.tokenizer._prompt_end_token
if not question[i].startswith("<s>"):
question[i] = "<s>" + question[i]
return question
def _postprocess(output, question):
output = output.replace("<s>", "").replace(question, "").replace("</s>", "").replace("<pad>", "")
if self.model_args.get("extra_answer_tokens", False):
output = output.split(self.tokenizer._prompt_end_token)[-1]
return output
if isinstance(question, list):
image_tensor_list = []
for i in image:
image_tensor, ori_size = _preprocess_image(i)
image_tensor_list.append(image_tensor)
image_tensor = torch.cat(image_tensor_list, dim=0)
question = _preprocess_prompt_batch(question)
self.processor.tokenizer.padding_side = "left"
prompt_ids = self.processor.tokenizer(
question, add_special_tokens=False, return_tensors="pt", padding=True
).input_ids
else:
image_tensor, ori_size = _preprocess_image(image)
prompt_ids = _preprocess_prompt(question)
if only_return_img_size:
return ori_size
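# Run the decoder in chunks of max_batch_size and merge the per-chunk outputs into one dict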
model_output_batch = []
for i in range(0, image_tensor.shape[0], max_batch_size):
image_tensor_batch = image_tensor[i : i + max_batch_size]
prompt_ids_batch = prompt_ids[i : i + max_batch_size]
model_output = self.model.inference(image_tensors=image_tensor_batch, prompt_ids=prompt_ids_batch)
model_output_batch.append(model_output)
model_output = {}
for k, v in model_output_batch[0].items():
if isinstance(v, torch.Tensor):
model_output[k] = sum(
[v_batch[k].cpu().numpy().tolist() for v_batch in model_output_batch],
[],
)
else:
model_output[k] = sum([v_batch[k] for v_batch in model_output_batch], [])
if return_raw:
if return_img_size:
return model_output, ori_size
return model_output
else:
if isinstance(question, list):
output = [_postprocess(model_output["repetitions"][i], question[i]) for i in range(len(question))]
score = model_output["scores"]
else:
output = _postprocess(model_output["repetitions"][0], question)
score = model_output["scores"][0]
if return_score:
return output, score
if return_img_size:
return output, ori_size
return output

17
config/Dolphin.yaml Normal file

@@ -0,0 +1,17 @@
model:
model_name_or_path: "./checkpoints/dolphin_model.bin"
tokenizer_path: "./checkpoints/dolphin_tokenizer.json"
extra_answer_tokens: True # add <Answer/> token
max_length: 4096
decoder_layer: 10
max_position_embeddings: 4096
hidden_dimension: 1024
swin_args:
name: 'swin'
img_size: [896, 896]
patch_size: 4
embed_dim: 128
align_long_axis: False
window_size: 7
encoder_layer: [2, 2, 14, 2]
num_heads: [4, 8, 16, 32]

7 binary files not shown (90 KiB, 54 KiB, 18 KiB, 68 KiB, 82 KiB, 179 KiB, 396 KiB)

BIN
demo/page_imgs/page_1.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

BIN
demo/page_imgs/page_2.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

BIN
demo/page_imgs/page_3.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 439 KiB

BIN
demo/page_imgs/page_4.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 363 KiB

BIN
demo/page_imgs/page_5.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 626 KiB

129
demo_element.py Normal file

@@ -0,0 +1,129 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import argparse
import glob
import os
from omegaconf import OmegaConf
from PIL import Image
from chat import DOLPHIN
from utils.utils import *
def process_element(image_path, model, element_type, save_dir=None):
"""Process a single element image (text, table, formula)
Args:
image_path: Path to the element image
model: DOLPHIN model instance
element_type: Type of element ('text', 'table', 'formula')
save_dir: Directory to save results (default: same as input directory)
Returns:
Parsed content of the element and recognition results
"""
# Load and prepare image
pil_image = Image.open(image_path).convert("RGB")
pil_image = crop_margin(pil_image)
# Select appropriate prompt based on element type
if element_type == "table":
prompt = "Parse the table in the image."
label = "tab"
elif element_type == "formula":
prompt = "Read text in the image."
label = "formula"
else: # Default to text
prompt = "Read text in the image."
label = "text"
# Process the element
result = model.chat(prompt, pil_image)
# Create recognition result in the same format as the document parser
recognition_result = [
{
"label": label,
"text": result.strip(),
}
]
# Save results if save_dir is provided
if save_dir:
save_outputs(recognition_result, image_path, save_dir)
print(f"Results saved to {save_dir}")
return result, recognition_result
def main():
parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model")
parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file")
parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images")
parser.add_argument(
"--element_type",
type=str,
choices=["text", "table", "formula"],
default="text",
help="Type of element to process (text, table, formula)",
)
parser.add_argument(
"--save_dir",
type=str,
default=None,
help="Directory to save parsing results (default: same as input directory)",
)
parser.add_argument("--print_results", action="store_true", help="Print recognition results to console")
args = parser.parse_args()
# Load Model
config = OmegaConf.load(args.config)
model = DOLPHIN(config)
# Set save directory
save_dir = args.save_dir or (
args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
)
setup_output_dirs(save_dir)
# Collect Images
if os.path.isdir(args.input_path):
image_files = []
for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
image_files = sorted(image_files)
else:
if not os.path.exists(args.input_path):
raise FileNotFoundError(f"Input path {args.input_path} does not exist")
image_files = [args.input_path]
total_samples = len(image_files)
print(f"\nTotal samples to process: {total_samples}")
# Process images one by one
for image_path in image_files:
print(f"\nProcessing {image_path}")
try:
result, recognition_result = process_element(
image_path=image_path,
model=model,
element_type=args.element_type,
save_dir=save_dir,
)
if args.print_results:
print("\nRecognition result:")
print(result)
print("-" * 40)
except Exception as e:
print(f"Error processing {image_path}: {str(e)}")
continue
if __name__ == "__main__":
main()

193
demo_element_hf.py Normal file

@@ -0,0 +1,193 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import argparse
import glob
import os
import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel
from utils.utils import *
class DOLPHIN:
def __init__(self, model_id_or_path):
"""Initialize the Hugging Face model
Args:
model_id_or_path: Path to local model or Hugging Face model ID
"""
# Load model from local path or Hugging Face hub
self.processor = AutoProcessor.from_pretrained(model_id_or_path)
self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path)
self.model.eval()
# Set device and precision
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)
self.model = self.model.half() # Always use half precision by default
# set tokenizer
self.tokenizer = self.processor.tokenizer
def chat(self, prompt, image):
"""Process an image with the given prompt
Args:
prompt: Text prompt to guide the model
image: PIL Image to process
Returns:
Generated text from the model
"""
# Prepare image
pixel_values = self.processor(image, return_tensors="pt").pixel_values
pixel_values = pixel_values.half()
# Prepare prompt
prompt = f"<s>{prompt} <Answer/>"
prompt_ids = self.tokenizer(
prompt,
add_special_tokens=False,
return_tensors="pt"
).input_ids.to(self.device)
decoder_attention_mask = torch.ones_like(prompt_ids)
# Generate text
outputs = self.model.generate(
pixel_values=pixel_values.to(self.device),
decoder_input_ids=prompt_ids,
decoder_attention_mask=decoder_attention_mask,
min_length=1,
max_length=4096,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
use_cache=True,
bad_words_ids=[[self.tokenizer.unk_token_id]],
return_dict_in_generate=True,
do_sample=False,
num_beams=1,
)
# Process the output
sequence = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
sequence = sequence.replace(prompt, "").replace("<pad>", "").replace("</s>", "").strip()
return sequence
def process_element(image_path, model, element_type, save_dir=None):
"""Process a single element image (text, table, formula)
Args:
image_path: Path to the element image
model: DOLPHIN model instance (Hugging Face framework)
element_type: Type of element ('text', 'table', 'formula')
save_dir: Directory to save results (default: same as input directory)
Returns:
Parsed content of the element and recognition results
"""
# Load and prepare image
pil_image = Image.open(image_path).convert("RGB")
pil_image = crop_margin(pil_image)
# Select appropriate prompt based on element type
if element_type == "table":
prompt = "Parse the table in the image."
label = "tab"
elif element_type == "formula":
prompt = "Read text in the image."
label = "formula"
else: # Default to text
prompt = "Read text in the image."
label = "text"
# Process the element
result = model.chat(prompt, pil_image)
# Create recognition result in the same format as the document parser
recognition_result = [
{
"label": label,
"text": result.strip(),
}
]
# Save results if save_dir is provided
if save_dir:
save_outputs(recognition_result, image_path, save_dir)
print(f"Results saved to {save_dir}")
return result, recognition_result
def main():
parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model")
parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model")
parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images")
parser.add_argument(
"--element_type",
type=str,
choices=["text", "table", "formula"],
default="text",
help="Type of element to process (text, table, formula)",
)
parser.add_argument(
"--save_dir",
type=str,
default=None,
help="Directory to save parsing results (default: same as input directory)",
)
parser.add_argument("--print_results", action="store_true", help="Print recognition results to console")
args = parser.parse_args()
# Load Model
model = DOLPHIN(args.model_path)
# Set save directory
save_dir = args.save_dir or (
args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
)
setup_output_dirs(save_dir)
# Collect Images
if os.path.isdir(args.input_path):
image_files = []
for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
image_files = sorted(image_files)
else:
if not os.path.exists(args.input_path):
raise FileNotFoundError(f"Input path {args.input_path} does not exist")
image_files = [args.input_path]
total_samples = len(image_files)
print(f"\nTotal samples to process: {total_samples}")
# Process images one by one
for image_path in image_files:
print(f"\nProcessing {image_path}")
try:
result, recognition_result = process_element(
image_path=image_path,
model=model,
element_type=args.element_type,
save_dir=save_dir,
)
if args.print_results:
print("\nRecognition result:")
print(result)
print("-" * 40)
except Exception as e:
print(f"Error processing {image_path}: {str(e)}")
continue
if __name__ == "__main__":
main()

171
demo_page.py Normal file

@@ -0,0 +1,171 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import argparse
import glob
import os
import cv2
from omegaconf import OmegaConf
from PIL import Image
from chat import DOLPHIN
from utils.utils import *
def process_page(image_path, model, save_dir, max_batch_size):
"""Parse document images with two stages"""
# Stage 1: Page-level layout and reading order parsing
pil_image = Image.open(image_path).convert("RGB")
layout_output = model.chat("Parse the reading order of this document.", pil_image)
# Stage 2: Element-level content parsing
padded_image, dims = prepare_image(pil_image)
recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size)
# Save outputs
json_path = save_outputs(recognition_results, image_path, save_dir)
return json_path, recognition_results
def process_elements(layout_results, padded_image, dims, model, max_batch_size):
"""Parse all document elements with parallel decoding"""
layout_results = parse_layout_string(layout_results)
text_table_elements = [] # Elements that need processing
figure_results = [] # Figure elements (no processing needed)
previous_box = None
reading_order = 0
# Collect elements for processing
for bbox, label in layout_results:
try:
# Adjust coordinates
x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
bbox, padded_image, dims, previous_box
)
# Crop and parse element
cropped = padded_image[y1:y2, x1:x2]
if cropped.size > 0:
if label == "fig":
# For figure regions, add empty text result immediately
figure_results.append(
{
"label": label,
"bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
"text": "",
"reading_order": reading_order,
}
)
else:
# For text or table regions, prepare for parsing
pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
prompt = "Parse the table in the image." if label == "tab" else "Read text in the image."
text_table_elements.append(
{
"crop": pil_crop,
"prompt": prompt,
"label": label,
"bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
"reading_order": reading_order,
}
)
reading_order += 1
except Exception as e:
print(f"Error processing bbox with label {label}: {str(e)}")
continue
# Parse text/table elements in parallel
recognition_results = figure_results
if text_table_elements:
crops_list = [elem["crop"] for elem in text_table_elements]
prompts_list = [elem["prompt"] for elem in text_table_elements]
# Inference in batch
batch_results = model.chat(prompts_list, crops_list, max_batch_size=max_batch_size)
# Add batch results to recognition_results
for i, result in enumerate(batch_results):
elem = text_table_elements[i]
recognition_results.append(
{
"label": elem["label"],
"bbox": elem["bbox"],
"text": result.strip(),
"reading_order": elem["reading_order"],
}
)
# Sort elements by reading order
recognition_results.sort(key=lambda x: x.get("reading_order", 0))
return recognition_results
def main():
parser = argparse.ArgumentParser(description="Document processing tool using DOLPHIN model")
parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file")
parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image or directory of images")
parser.add_argument(
"--save_dir",
type=str,
default=None,
help="Directory to save parsing results (default: same as input directory)",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=4,
help="Maximum number of document elements to parse in a single batch (default: 4)",
)
args = parser.parse_args()
# Load Model
config = OmegaConf.load(args.config)
model = DOLPHIN(config)
# Collect Document Images
if os.path.isdir(args.input_path):
image_files = []
for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
image_files = sorted(image_files)
else:
if not os.path.exists(args.input_path):
raise FileNotFoundError(f"Input path {args.input_path} does not exist")
image_files = [args.input_path]
save_dir = args.save_dir or (
args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
)
setup_output_dirs(save_dir)
total_samples = len(image_files)
print(f"\nTotal samples to process: {total_samples}")
# Process All Document Images
for image_path in image_files:
print(f"\nProcessing {image_path}")
try:
json_path, recognition_results = process_page(
image_path=image_path,
model=model,
save_dir=save_dir,
max_batch_size=args.max_batch_size,
)
print(f"Processing completed. Results saved to {save_dir}")
except Exception as e:
print(f"Error processing {image_path}: {str(e)}")
continue
if __name__ == "__main__":
main()

288
demo_page_hf.py Normal file

@@ -0,0 +1,288 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import argparse
import glob
import os
import cv2
import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel
from utils.utils import *
class DOLPHIN:
def __init__(self, model_id_or_path):
"""Initialize the Hugging Face model
Args:
model_id_or_path: Path to local model or Hugging Face model ID
"""
# Load model from local path or Hugging Face hub
self.processor = AutoProcessor.from_pretrained(model_id_or_path)
self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path)
self.model.eval()
# Set device and precision
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)
self.model = self.model.half() # Always use half precision by default
# set tokenizer
self.tokenizer = self.processor.tokenizer
def chat(self, prompt, image):
"""Process an image or batch of images with the given prompt(s)
Args:
prompt: Text prompt or list of prompts to guide the model
image: PIL Image or list of PIL Images to process
Returns:
Generated text or list of texts from the model
"""
# Check if we're dealing with a batch
is_batch = isinstance(image, list)
if not is_batch:
# Single image, wrap it in a list for consistent processing
images = [image]
prompts = [prompt]
else:
# Batch of images
images = image
prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
# Prepare image
batch_inputs = self.processor(images, return_tensors="pt", padding=True)
batch_pixel_values = batch_inputs.pixel_values.half().to(self.device)
# Prepare prompt
prompts = [f"<s>{p} <Answer/>" for p in prompts]
batch_prompt_inputs = self.tokenizer(
prompts,
add_special_tokens=False,
return_tensors="pt"
)
batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)
# Generate text
outputs = self.model.generate(
pixel_values=batch_pixel_values,
decoder_input_ids=batch_prompt_ids,
decoder_attention_mask=batch_attention_mask,
min_length=1,
max_length=4096,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
use_cache=True,
bad_words_ids=[[self.tokenizer.unk_token_id]],
return_dict_in_generate=True,
do_sample=False,
num_beams=1,
repetition_penalty=1.1
)
# Process output
sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
# Clean prompt text from output
results = []
for i, sequence in enumerate(sequences):
cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
results.append(cleaned)
# Return a single result for single image input
if not is_batch:
return results[0]
return results
def process_page(image_path, model, save_dir, max_batch_size=None):
"""Parse document images with two stages"""
# Stage 1: Page-level layout and reading order parsing
pil_image = Image.open(image_path).convert("RGB")
layout_output = model.chat("Parse the reading order of this document.", pil_image)
# Stage 2: Element-level content parsing
padded_image, dims = prepare_image(pil_image)
recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size)
# Save outputs
json_path = save_outputs(recognition_results, image_path, save_dir)
return json_path, recognition_results
def process_elements(layout_results, padded_image, dims, model, max_batch_size=None):
"""Parse all document elements with parallel decoding"""
layout_results = parse_layout_string(layout_results)
# Store text and table elements separately
text_elements = [] # Text elements
table_elements = [] # Table elements
figure_results = [] # Image elements (no processing needed)
previous_box = None
reading_order = 0
# Collect elements to process and group by type
for bbox, label in layout_results:
try:
# Adjust coordinates
x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
bbox, padded_image, dims, previous_box
)
# Crop and parse element
cropped = padded_image[y1:y2, x1:x2]
if cropped.size > 0:
if label == "fig":
# For figure regions, add empty text result immediately
figure_results.append(
{
"label": label,
"bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
"text": "",
"reading_order": reading_order,
}
)
else:
# Prepare element for parsing
pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
element_info = {
"crop": pil_crop,
"label": label,
"bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
"reading_order": reading_order,
}
# Group by type
if label == "tab":
table_elements.append(element_info)
else: # Text elements
text_elements.append(element_info)
reading_order += 1
except Exception as e:
print(f"Error processing bbox with label {label}: {str(e)}")
continue
# Initialize results list
recognition_results = figure_results.copy()
# Process text elements (in batches)
if text_elements:
text_results = process_element_batch(text_elements, model, "Read text in the image.", max_batch_size)
recognition_results.extend(text_results)
# Process table elements (in batches)
if table_elements:
table_results = process_element_batch(table_elements, model, "Parse the table in the image.", max_batch_size)
recognition_results.extend(table_results)
# Sort elements by reading order
recognition_results.sort(key=lambda x: x.get("reading_order", 0))
return recognition_results
def process_element_batch(elements, model, prompt, max_batch_size=None):
"""Process elements of the same type in batches"""
results = []
# Determine batch size
batch_size = len(elements)
if max_batch_size is not None and max_batch_size > 0:
batch_size = min(batch_size, max_batch_size)
# Process in batches
for i in range(0, len(elements), batch_size):
batch_elements = elements[i:i+batch_size]
crops_list = [elem["crop"] for elem in batch_elements]
# Use the same prompt for all elements in the batch
prompts_list = [prompt] * len(crops_list)
# Batch inference
batch_results = model.chat(prompts_list, crops_list)
# Add results
for j, result in enumerate(batch_results):
elem = batch_elements[j]
results.append({
"label": elem["label"],
"bbox": elem["bbox"],
"text": result.strip(),
"reading_order": elem["reading_order"],
})
return results
def main():
parser = argparse.ArgumentParser(description="Document processing tool using DOLPHIN model")
parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model")
parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image or directory of images")
parser.add_argument(
"--save_dir",
type=str,
default=None,
help="Directory to save parsing results (default: same as input directory)",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=16,
help="Maximum number of document elements to parse in a single batch (default: 16)",
)
args = parser.parse_args()
# Load Model
model = DOLPHIN(args.model_path)
# Collect Document Images
if os.path.isdir(args.input_path):
image_files = []
for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
image_files = sorted(image_files)
else:
if not os.path.exists(args.input_path):
raise FileNotFoundError(f"Input path {args.input_path} does not exist")
image_files = [args.input_path]
save_dir = args.save_dir or (
args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
)
setup_output_dirs(save_dir)
total_samples = len(image_files)
print(f"\nTotal samples to process: {total_samples}")
# Process All Document Images
for image_path in image_files:
print(f"\nProcessing {image_path}")
try:
json_path, recognition_results = process_page(
image_path=image_path,
model=model,
save_dir=save_dir,
max_batch_size=args.max_batch_size,
)
print(f"Processing completed. Results saved to {save_dir}")
except Exception as e:
print(f"Error processing {image_path}: {str(e)}")
continue
if __name__ == "__main__":
main()

16
pyproject.toml Normal file

@@ -0,0 +1,16 @@
[tool.black]
line-length = 120
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''

11
requirements.txt Normal file

@@ -0,0 +1,11 @@
albumentations==1.4.0
numpy==1.24.4
omegaconf==2.3.0
opencv-python==4.11.0.86
opencv-python-headless==4.5.5.64
pillow==9.3.0
timm==0.5.4
torch==2.1.0
torchvision==0.16.0
transformers==4.47.0
accelerate==1.6.0

442
utils/markdown_utils.py Normal file

@@ -0,0 +1,442 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import re
import base64
from typing import List, Dict, Any, Optional
"""
Example input:
[
{"label": "tab", "bbox": [0.176, 0.74, 0.824, 0.82], "text": "<table><tr><td></td><td>HellaSwag</td><td>Obqa</td><td>WinoGrande</td><td>ARC-c</td><td>ARC-e</td><td>boolq</td><td>piqa</td><td>Avg</td></tr><tr><td>OPT-1.3B</td><td>53.65</td><td>33.40</td><td>59.59</td><td>29.44</td><td>50.80</td><td>60.83</td><td>72.36</td><td>51.44</td></tr><tr><td>Pythia-1.0B</td><td>47.16</td><td>31.40</td><td>53.43</td><td>27.05</td><td>48.99</td><td>57.83</td><td>69.21</td><td>48.30</td></tr><tr><td>Pythia-1.4B</td><td>52.01</td><td>33.20</td><td>57.38</td><td>28.50</td><td>54.00</td><td>63.27</td><td>70.95</td><td>51.33</td></tr><tr><td>TinyLlama-1.1B</td><td>59.20</td><td>36.00</td><td>59.12</td><td>30.10</td><td>55.25</td><td>57.83</td><td>73.29</td><td>52.99</td></tr></table>", "reading_order": 6},
{"label": "cap", "bbox": [0.28, 0.729, 0.711, 0.74], "text": "Table 2: Zero-shot performance on commonsense reasoning tasks", "reading_order": 7},
{"label": "para", "bbox": [0.176, 0.848, 0.826, 0.873], "text": "We of performance during training We tracked the accuracy of TinyLlama on common-\nsense reasoning benchmarks during its pre-training, as shown in Fig. 2 . Generally, the performance of", "reading_order": 8},
{"label": "fnote", "bbox": [0.176, 0.88, 0.824, 0.912], "text": "${ }^{4}$ Due to a bug in the config file, the learning rate did not decrease immediately after warmup and remained at\nthe maximum value for several steps before we fixed this.", "reading_order": 9},
{"label": "foot", "bbox": [0.496, 0.939, 0.501, 0.95], "text": "14", "reading_order": 10}
]
"""
def extract_table_from_html(html_string):
"""Extract and clean table tags from HTML string"""
try:
table_pattern = re.compile(r'<table.*?>.*?</table>', re.DOTALL)
tables = table_pattern.findall(html_string)
tables = [re.sub(r'<table[^>]*>', '<table>', table) for table in tables]
return '\n'.join(tables)
except Exception as e:
print(f"extract_table_from_html error: {str(e)}")
return f"<table><tr><td>Error extracting table: {str(e)}</td></tr></table>"
class MarkdownConverter:
"""Convert structured recognition results to Markdown format"""
def __init__(self):
# Define heading levels for different section types
self.heading_levels = {
'title': '#',
'sec': '##',
'sub_sec': '###'
}
# Define which labels need special handling
self.special_labels = {
'tab', 'fig', 'title', 'sec', 'sub_sec',
'list', 'formula', 'reference', 'alg'
}
def try_remove_newline(self, text: str) -> str:
try:
# Preprocess text to handle line breaks
text = text.strip()
text = text.replace('-\n', '')
# Handle Chinese text line breaks
def is_chinese(char):
return '\u4e00' <= char <= '\u9fff'
lines = text.split('\n')
processed_lines = []
# Process all lines except the last one
for i in range(len(lines)-1):
current_line = lines[i].strip()
next_line = lines[i+1].strip()
# Always add the current line, but determine if we need a newline
if current_line: # If current line is not empty
if next_line: # If next line is not empty
# For Chinese text handling
if is_chinese(current_line[-1]) and is_chinese(next_line[0]):
processed_lines.append(current_line)
else:
processed_lines.append(current_line + ' ')
else:
# Next line is empty, add current line with newline
processed_lines.append(current_line + '\n')
else:
# Current line is empty, add an empty line
processed_lines.append('\n')
# Add the last line
if lines and lines[-1].strip():
processed_lines.append(lines[-1].strip())
text = ''.join(processed_lines)
return text
except Exception as e:
print(f"try_remove_newline error: {str(e)}")
return text # Return original text on error
def _handle_text(self, text: str) -> str:
"""
Process regular text content, preserving paragraph structure
"""
try:
if not text:
return ""
if text.strip().startswith("\\begin{array}") and text.strip().endswith("\\end{array}"):
text = "$$" + text + "$$"
elif ("_{" in text or "^{" in text or "\\" in text or "_ {" in text or "^ {" in text) and ("$" not in text) and ("\\begin" not in text):
text = "$" + text + "$"
# Process formulas in text before handling other text processing
text = self._process_formulas_in_text(text)
text = self.try_remove_newline(text)
# Return processed text
return text
except Exception as e:
print(f"_handle_text error: {str(e)}")
return text # Return original text on error
def _process_formulas_in_text(self, text: str) -> str:
"""
Process mathematical formulas in text by iteratively finding and replacing formulas.
- Identify inline and block formulas
- Replace newlines within formulas with \\
"""
try:
# Define formula delimiters and their corresponding patterns
delimiters = [
('$$', '$$'), # Block formula with $$
('\\[', '\\]'), # Block formula with \[ \]
('$', '$'), # Inline formula with $
('\\(', '\\)') # Inline formula with \( \)
]
# Process the text by iterating through each delimiter type
result = text
for start_delim, end_delim in delimiters:
# Create a pattern that matches from start to end delimiter
# Using a custom approach to avoid issues with nested delimiters
current_pos = 0
processed_parts = []
while current_pos < len(result):
# Find the next start delimiter
start_pos = result.find(start_delim, current_pos)
if start_pos == -1:
# No more formulas of this type
processed_parts.append(result[current_pos:])
break
# Add text before the formula
processed_parts.append(result[current_pos:start_pos])
# Find the matching end delimiter
end_pos = result.find(end_delim, start_pos + len(start_delim))
if end_pos == -1:
# No matching end delimiter, treat as regular text
processed_parts.append(result[start_pos:])
break
# Extract the formula content (without delimiters)
formula_content = result[start_pos + len(start_delim):end_pos]
# Process the formula content - replace newlines with \\
processed_formula = formula_content.replace('\n', ' \\\\ ')
# Add the processed formula with its delimiters
processed_parts.append(f"{start_delim}{processed_formula}{end_delim}")
# Move past this formula
current_pos = end_pos + len(end_delim)
# Update the result with processed text
result = ''.join(processed_parts)
return result
except Exception as e:
print(f"_process_formulas_in_text error: {str(e)}")
return text # Return original text on error
def _remove_newline_in_heading(self, text: str) -> str:
"""
Remove newline in heading
"""
try:
# Handle Chinese text line breaks
def is_chinese(char):
return '\u4e00' <= char <= '\u9fff'
# Check if the text contains Chinese characters
if any(is_chinese(char) for char in text):
return text.replace('\n', '')
else:
return text.replace('\n', ' ')
except Exception as e:
print(f"_remove_newline_in_heading error: {str(e)}")
return text
def _handle_heading(self, text: str, label: str) -> str:
"""
Convert section headings to appropriate markdown format
"""
try:
level = self.heading_levels.get(label, '#')
text = text.strip()
text = self._remove_newline_in_heading(text)
text = self._handle_text(text)
return f"{level} {text}\n\n"
except Exception as e:
print(f"_handle_heading error: {str(e)}")
return f"# Error processing heading: {text}\n\n"
def _handle_list_item(self, text: str) -> str:
"""
Convert list items to markdown list format
"""
try:
return f"- {text.strip()}\n"
except Exception as e:
print(f"_handle_list_item error: {str(e)}")
return f"- Error processing list item: {text}\n"
def _handle_figure(self, text: str, section_count: int) -> str:
"""
Convert base64 encoded image to markdown image syntax
"""
try:
# Determine image format (assuming PNG if not specified)
img_format = "png"
if text.startswith("data:image/") or (";" in text and "," in text):
# Already in data URI format, embed it directly
return f"![Figure {section_count}]({text})\n\n"
else:
# Raw base64, convert to data URI
data_uri = f"data:image/{img_format};base64,{text}"
return f"![Figure {section_count}]({data_uri})\n\n"
except Exception as e:
print(f"_handle_figure error: {str(e)}")
return f"*[Error processing figure: {str(e)}]*\n\n"
def _handle_table(self, text: str) -> str:
"""
Convert table content to markdown format
"""
try:
markdown_content = []
if '<table' in text.lower() or '<tr' in text.lower():
markdown_table = extract_table_from_html(text)
markdown_content.append(markdown_table + "\n")
else:
table_lines = text.split('\n')
if table_lines:
col_count = len(table_lines[0].split()) if table_lines[0] else 1
header = '| ' + ' | '.join(table_lines[0].split()) + ' |'
markdown_content.append(header)
markdown_content.append('| ' + ' | '.join(['---'] * col_count) + ' |')
for line in table_lines[1:]:
cells = line.split()
while len(cells) < col_count:
cells.append('')
markdown_content.append('| ' + ' | '.join(cells) + ' |')
return '\n'.join(markdown_content) + '\n\n'
except Exception as e:
print(f"_handle_table error: {str(e)}")
return f"*[Error processing table: {str(e)}]*\n\n"
def _handle_algorithm(self, text: str) -> str:
"""
Process algorithm blocks with proper formatting
"""
try:
# Remove algorithm environment tags if present
text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL)
text = text.replace('\\begin{algorithm}', '').replace('\\end{algorithm}', '')
text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '')
# Process the algorithm text
lines = text.strip().split('\n')
# Check if there's a caption or label
caption = ""
algorithm_text = []
for line in lines:
if '\\caption' in line:
# Extract caption text
caption_match = re.search(r'\\caption\{(.*?)\}', line)
if caption_match:
caption = f"**{caption_match.group(1)}**\n\n"
continue
elif '\\label' in line:
continue # Skip label lines
else:
algorithm_text.append(line)
# Join the algorithm text and wrap in code block
formatted_text = '\n'.join(algorithm_text)
# Return the formatted algorithm with caption
return f"{caption}```\n{formatted_text}\n```\n\n"
except Exception as e:
print(f"_handle_algorithm error: {str(e)}")
return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n"
def _handle_formula(self, text: str) -> str:
"""
Handle formula-specific content
"""
try:
# Process the formula content
processed_text = self._process_formulas_in_text(text)
# For formula blocks, ensure they're properly formatted in markdown
if '$$' not in processed_text and '\\[' not in processed_text:
# If no block formula delimiters are present, wrap in $$ for block formula
processed_text = f'$${processed_text}$$'
return f"{processed_text}\n\n"
except Exception as e:
print(f"_handle_formula error: {str(e)}")
return f"*[Error processing formula: {str(e)}]*\n\n"
def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
"""
Convert recognition results to markdown format
"""
try:
markdown_content = []
for section_count, result in enumerate(recognition_results):
try:
label = result.get('label', '')
text = result.get('text', '').strip()
# Skip empty text
if not text:
continue
# Handle different content types
if label in {'title', 'sec', 'sub_sec'}:
markdown_content.append(self._handle_heading(text, label))
elif label == 'list':
markdown_content.append(self._handle_list_item(text))
elif label == 'fig':
markdown_content.append(self._handle_figure(text, section_count))
elif label == 'tab':
markdown_content.append(self._handle_table(text))
elif label == 'alg':
markdown_content.append(self._handle_algorithm(text))
elif label == 'formula':
markdown_content.append(self._handle_formula(text))
elif label not in self.special_labels:
# Handle regular text (paragraphs, etc.)
processed_text = self._handle_text(text)
markdown_content.append(f"{processed_text}\n\n")
except Exception as e:
print(f"Error processing item {section_count}: {str(e)}")
# Add a placeholder for the failed item
markdown_content.append(f"*[Error processing content]*\n\n")
# Join all content and apply post-processing
result = ''.join(markdown_content)
return self._post_process(result)
except Exception as e:
print(f"convert error: {str(e)}")
return f"Error generating markdown content: {str(e)}"
def _post_process(self, markdown_content: str) -> str:
"""
Apply post-processing fixes to the generated markdown content
"""
try:
# Handle author information
author_pattern = re.compile(r'\\author\{(.*?)\}', re.DOTALL)
def process_author_match(match):
# Extract author content
author_content = match.group(1)
# Process the author content
return self._handle_text(author_content)
# Replace \author{...} with processed content
markdown_content = author_pattern.sub(process_author_match, markdown_content)
# Handle special case where author is inside math environment
math_author_pattern = re.compile(r'\$(\\author\{.*?\})\$', re.DOTALL)
match = math_author_pattern.search(markdown_content)
if match:
# Extract the author command
author_cmd = match.group(1)
# Extract content from author command
author_content_match = re.search(r'\\author\{(.*?)\}', author_cmd, re.DOTALL)
if author_content_match:
# Get author content and process it
author_content = author_content_match.group(1)
processed_content = self._handle_text(author_content)
# Replace the entire $\author{...}$ block with processed content
markdown_content = markdown_content.replace(match.group(0), processed_content)
# Replace LaTeX abstract environment with plain text
markdown_content = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}',
r'**Abstract** \1',
markdown_content,
flags=re.DOTALL)
# Replace standalone \begin{abstract} (without matching end)
markdown_content = re.sub(r'\\begin\{abstract\}',
r'**Abstract**',
markdown_content)
# Replace LaTeX equation numbers with tag format, handling cases with extra backslashes
markdown_content = re.sub(r'\\eqno\{\((.*?)\)\}',
r'\\tag{\1}',
markdown_content)
# Find the starting tag of the formula
markdown_content = markdown_content.replace("\[ \\\\", "$$ \\\\")
# Find the ending tag of the formula (ensure this is the only ending tag)
markdown_content = markdown_content.replace("\\\\ \]", "\\\\ $$")
# Fix other common LaTeX issues
replacements = [
# Fix spacing issues in subscripts and superscripts
(r'_ {', r'_{'),
(r'^ {', r'^{'),
# Fix potential issues with multiple consecutive newlines
(r'\n{3,}', r'\n\n')
]
for old, new in replacements:
markdown_content = re.sub(old, new, markdown_content)
return markdown_content
except Exception as e:
print(f"_post_process error: {str(e)}")
return markdown_content # Return original content if post-processing fails

477
utils/model.py Normal file

@@ -0,0 +1,477 @@
"""
Copyright (c) 2022-present NAVER Corp.
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
MIT License
This file has been modified by [ByteDance Ltd. and/or its affiliates] on 20250118.
The original file available at https://github.com/clovaai/donut/blob/master/donut/model.py was released under the MIT license.
This modified file is released under the same license.
"""
import logging
from collections import defaultdict
from typing import List, Optional
import torch
import torch.nn.functional as F
from PIL import Image
from timm.models.swin_transformer import SwinTransformer
from torch import nn
from transformers import (
MBartConfig,
MBartForCausalLM,
StoppingCriteria,
StoppingCriteriaList,
)
from transformers.file_utils import ModelOutput
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
class SwinEncoder(nn.Module):
r"""
Encoder based on SwinTransformer
Sets the initial weights and configuration from a pretrained SwinTransformer and then
modifies the detailed configuration
Args:
input_size: Input image size (height, width)
align_long_axis: Whether to rotate the image when its height is greater than its width
window_size: Window size of the SwinTransformer
encoder_layer: Number of layers in each SwinTransformer stage
patch_size: Patch size of the patch embedding
embed_dim: Embedding dimension of the first stage
num_heads: Number of attention heads per stage
"""
def __init__(
self,
input_size,
align_long_axis: bool = False,
window_size: int = 7,
encoder_layer: List[int] = [2, 2, 14, 2],
patch_size: List[int] = [4, 4],
embed_dim: int = 128,
num_heads: List[int] = [4, 8, 16, 32],
):
super().__init__()
if isinstance(input_size, int):
input_size = [input_size, input_size]
self.input_size = input_size
self.align_long_axis = align_long_axis
self.window_size = window_size
self.encoder_layer = encoder_layer
self.patch_size = patch_size
self.embed_dim = embed_dim
self.num_heads = num_heads
self.model = SwinTransformer(
img_size=self.input_size,
depths=self.encoder_layer,
window_size=self.window_size,
patch_size=self.patch_size,
embed_dim=self.embed_dim,
num_heads=self.num_heads,
num_classes=0,
)
def forward(self, x: torch.Tensor, text_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Args:
x: (batch_size, num_channels, height, width)
"""
x = self.model.patch_embed(x)
x = self.model.pos_drop(x)
x = self.model.layers(x)
return x
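# --- Usage sketch (illustrative, not part of the original file) ----------------------
# Hedged example: builds the encoder at the repository's default 896x896 resolution and
# runs a dummy batch. The exact output layout (token sequence vs. spatial grid) depends
# on the installed timm version, so only the forward call itself is shown.
def _demo_swin_encoder():
    encoder = SwinEncoder(input_size=[896, 896])
    dummy = torch.randn(1, 3, 896, 896)
    features = encoder(dummy)
    return features.shape  # e.g. (1, num_patch_tokens, hidden_dim) on older timm releases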
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def _set_dtype(self, dtype):
self._dtype = dtype
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(dtype=self._dtype))
return ret.type(orig_type)
class BARTDecoder(nn.Module):
"""
Decoder based on Multilingual BART
Sets the initial weights and configuration from a pretrained multilingual BART model,
and modifies the detailed configuration to act as a Donut decoder
Args:
tokenizer:
Tokenizer whose vocabulary and special tokens configure the decoder
decoder_layer:
Number of decoder layers
max_position_embeddings:
The maximum sequence length to be trained
hidden_dimension:
Hidden size (d_model) of the decoder
"""
def __init__(
self,
tokenizer,
decoder_layer: int,
max_position_embeddings: int,
hidden_dimension: int = 1024,
**kwargs,
):
super().__init__()
self.decoder_layer = decoder_layer
self.max_position_embeddings = max_position_embeddings
self.hidden_dimension = hidden_dimension
self.tokenizer = tokenizer
self.model = MBartForCausalLM(
config=MBartConfig(
tie_word_embeddings=True,
is_decoder=True,
is_encoder_decoder=False,
add_cross_attention=True,
decoder_layers=self.decoder_layer,
max_position_embeddings=self.max_position_embeddings,
vocab_size=len(self.tokenizer),
scale_embedding=True,
add_final_layer_norm=True,
d_model=self.hidden_dimension,
)
)
# self.model.config.is_encoder_decoder = True # to get cross-attention
self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference
def add_special_tokens(self, list_of_tokens: List[str]):
"""
Add special tokens to tokenizer and resize the token embeddings
"""
newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
if newly_added_num > 0:
self.model.resize_token_embeddings(len(self.tokenizer))
def add_tokens(self, list_of_tokens: List[str]):
"""
Add regular tokens to the tokenizer and resize the token embeddings
"""
newly_added_num = self.tokenizer.add_tokens(sorted(set(list_of_tokens)))
if newly_added_num > 0:
self.model.resize_token_embeddings(len(self.tokenizer))
def prepare_inputs_for_inference(
self,
input_ids: torch.Tensor,
encoder_outputs: torch.Tensor,
past=None,
past_key_values=None,
use_cache: bool = None,
attention_mask: torch.Tensor = None,
**kwargs,
):
"""
Args:
input_ids: (batch_size, sequence_length)
Returns:
input_ids: (batch_size, sequence_length)
attention_mask: (batch_size, sequence_length)
encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
"""
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
past = past or past_key_values
if past is not None:
input_ids = input_ids[:, -1:]
output = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
"encoder_hidden_states": encoder_outputs.last_hidden_state,
}
return output
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: bool = None,
output_attentions: Optional[torch.Tensor] = None,
output_hidden_states: Optional[torch.Tensor] = None,
return_dict: bool = None,
):
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
encoder_hidden_states=encoder_hidden_states,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@staticmethod
def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
"""
Resize position embeddings
Truncate if sequence length of MBart backbone is greater than given max_length,
else interpolate to max_length
"""
if weight.shape[0] > max_length:
weight = weight[:max_length, ...]
else:
weight = (
F.interpolate(
weight.permute(1, 0).unsqueeze(0),
size=max_length,
mode="linear",
align_corners=False,
)
.squeeze(0)
.permute(1, 0)
)
return weight
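# --- Illustrative sketch (not part of the original file) -----------------------------
# Hedged example of the two branches above: a (positions, dim) table is either truncated
# or linearly interpolated along the position axis to the requested max_length.
def _demo_resize_pos_emb():
    weight = torch.randn(10, 4)                                # 10 positions, dim 4
    shorter = BARTDecoder.resize_bart_abs_pos_emb(weight, 6)   # truncated -> (6, 4)
    longer = BARTDecoder.resize_bart_abs_pos_emb(weight, 20)   # interpolated -> (20, 4)
    return shorter.shape, longer.shape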
class DonutConfig(PretrainedConfig):
def __init__(
self,
decoder_layer: int = 10,
max_position_embeddings: int = None,
max_length: int = 4096,
hidden_dimension: int = 1024,
**kwargs,
):
super().__init__()
self.decoder_layer = decoder_layer
self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings
self.max_length = max_length
self.hidden_dimension = hidden_dimension
class RunningVarTorch:
def __init__(self, L=15, norm=False):
self.values = None
self.L = L
self.norm = norm
def push(self, x: torch.Tensor):
assert x.dim() == 1
if self.values is None:
self.values = x[:, None]
elif self.values.shape[1] < self.L:
self.values = torch.cat((self.values, x[:, None]), 1)
else:
self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)
def variance(self):
if self.values is None:
return
if self.norm:
return torch.var(self.values, 1) / self.values.shape[1]
else:
return torch.var(self.values, 1)
class StoppingCriteriaScores(StoppingCriteria):
def __init__(self, threshold: float = 0.015, window_size: int = 200):
super().__init__()
self.threshold = threshold
self.vars = RunningVarTorch(norm=True)
self.varvars = RunningVarTorch(L=window_size)
self.stop_inds = defaultdict(int)
self.stopped = defaultdict(bool)
self.size = 0
self.window_size = window_size
@torch.no_grad()
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
last_scores = scores[-1]
self.vars.push(last_scores.max(1)[0].float().cpu())
self.varvars.push(self.vars.variance())
self.size += 1
if self.size < self.window_size:
return False
varvar = self.varvars.variance()
for b in range(len(last_scores)):
if varvar[b] < self.threshold:
if self.stop_inds[b] > 0 and not self.stopped[b]:
self.stopped[b] = self.stop_inds[b] >= self.size
else:
self.stop_inds[b] = int(min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095))
else:
self.stop_inds[b] = 0
self.stopped[b] = False
return all(self.stopped.values()) and len(self.stopped) > 0
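# --- Illustrative sketch (not part of the original file) -----------------------------
# Hedged example of the sliding-window variance used above: with L=3, only the last
# three pushed values contribute to variance(), which is the quantity the stopping
# criterion monitors to decide when to stop generation early.
def _demo_running_var():
    rv = RunningVarTorch(L=3)
    for v in (1.0, 2.0, 3.0, 4.0):
        rv.push(torch.tensor([v]))
    return rv.variance()  # variance of the last window [2., 3., 4.] -> tensor([1.])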
def batch(l, b=15):
subs = []
for i in range(len(l) - b):
subs.append(l[i : i + b])
return subs
def subdiv(l, b=10):
subs = []
for i in range(len(l) - b):
subs.append(l[: i + b])
return subs
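# --- Illustrative sketch (not part of the original file) -----------------------------
# Hedged example of the two list helpers above on a short sequence:
# batch() yields fixed-size sliding windows, subdiv() yields growing prefixes.
def _demo_batch_subdiv():
    seq = list(range(6))
    return batch(seq, b=3), subdiv(seq, b=3)
    # batch:  [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
    # subdiv: [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]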
class DonutModel(PreTrainedModel):
config_class = DonutConfig
base_model_prefix = "donut"
def __init__(self, config: DonutConfig, vision_tower=None, tokenizer=None):
super().__init__(config)
self.config = config
self.tokenizer = tokenizer
self.vpm = vision_tower
# build language model
self.llm = BARTDecoder(
tokenizer=tokenizer,
decoder_layer=self.config.decoder_layer,
max_position_embeddings=self.config.max_position_embeddings,
hidden_dimension=self.config.hidden_dimension,
)
self.ids_to_tokens = {id: content for content, id in self.llm.tokenizer.vocab.items()}
def get_input_embeddings(self, tensor):
return self.llm.model.get_input_embeddings()(tensor)
def forward(
self,
inputs: dict,
):
image_tensors = inputs["pixel_values"]
input_ids = inputs["input_ids"].contiguous()
attention_mask = inputs["attention_mask"]
labels = inputs["labels"].contiguous()
encoder_outputs = self.vpm(
image_tensors,
text_embedding=self.llm.model.get_input_embeddings()(input_ids),
)
decoder_outputs = self.llm(
input_ids=input_ids,
encoder_hidden_states=encoder_outputs,
attention_mask=attention_mask,
labels=labels,
)
return decoder_outputs
def get_hidden_states_during_inference(
self,
prompt_ids: torch.Tensor,
image: Image.Image = None,
image_tensors: Optional[torch.Tensor] = None,
):
if image_tensors is None:
image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
if self.device.type != "mps":
image_tensors = image_tensors.to(next(self.parameters()).dtype)
image_tensors = image_tensors.to(self.device)
prompt_ids = prompt_ids.to(self.device)
all_hidden_states = self.vpm.forward_features(
image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
)
return all_hidden_states
def get_attn_weights_during_inference(
self,
prompt_ids: torch.Tensor,
image: Image.Image = None,
image_tensors: Optional[torch.Tensor] = None,
):
if image_tensors is None:
image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
if self.device.type != "mps":
image_tensors = image_tensors.to(next(self.parameters()).dtype)
image_tensors = image_tensors.to(self.device)
prompt_ids = prompt_ids.to(self.device)
last_attn_score = self.vpm.get_last_layer_cross_attn_score(
image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
)
return last_attn_score
def inference(
self,
prompt_ids: torch.Tensor,
image: Image.Image = None,
image_tensors: Optional[torch.Tensor] = None,
return_attentions: bool = False,
early_stopping: bool = True,
):
"""
Generate a token sequence in an auto-regressive manner.
Args:
prompt_ids: (1, prompt_length) token ids of the task prompt
image: input document image (PIL.Image)
image_tensors: (1, num_channels, height, width);
if not provided, `image` is converted to tensors internally
"""
output = {
"predictions": list(),
"sequences": list(),
"repeats": list(),
"repetitions": list(),
}
if image is None and image_tensors is None:
logging.warning("Image not found")
return output
if image_tensors is None:
image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
if self.device.type != "mps":
image_tensors = image_tensors.to(next(self.parameters()).dtype)
image_tensors = image_tensors.to(self.device)
prompt_ids = prompt_ids.to(self.device)
last_hidden_state = self.vpm(image_tensors, text_embedding=self.get_input_embeddings(prompt_ids))
encoder_outputs = ModelOutput(last_hidden_state=last_hidden_state, attentions=None)
if len(encoder_outputs.last_hidden_state.size()) == 1:
encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.unsqueeze(0)
# get decoder output
decoder_output = self.llm.model.generate(
input_ids=prompt_ids,
encoder_outputs=encoder_outputs,
min_length=1,
max_length=self.config.max_length,
pad_token_id=self.llm.tokenizer.pad_token_id,
eos_token_id=self.llm.tokenizer.eos_token_id,
use_cache=True,
return_dict_in_generate=True,
output_scores=True,
output_attentions=return_attentions,
do_sample=False,
num_beams=1,
stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()] if early_stopping else []),
)
output["repetitions"] = decoder_output.sequences.clone()
output["sequences"] = decoder_output.sequences.clone()
output["scores"] = torch.stack(decoder_output.scores, 1).softmax(-1).cpu().max(-1)[0]
output["repetitions"] = self.llm.tokenizer.batch_decode(output["repetitions"], skip_special_tokens=False)
return output
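# --- Usage sketch (illustrative, not part of the original file) ----------------------
# Hedged example of driving inference() with the DolphinProcessor from utils/processor.py;
# the prompt string is a placeholder, and how `model` and `processor` are instantiated
# (checkpoint, tokenizer, vision tower) is outside the scope of this sketch.
def _demo_inference(model: "DonutModel", processor, image):
    prompt_ids = processor.process_prompt_for_inference("Parse the reading order of this document.")
    image_tensors = processor.process_image_for_inference(image)
    output = model.inference(prompt_ids=prompt_ids, image_tensors=image_tensors)
    return model.llm.tokenizer.batch_decode(output["sequences"], skip_special_tokens=True)[0]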

64
utils/processor.py Normal file
View File

@ -0,0 +1,64 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import numpy as np
import torch
from PIL import ImageOps
from utils.utils import *
class DolphinProcessor:
def __init__(
self,
dp_config,
tokenizer,
**kwargs,
) -> None:
self.tokenizer = tokenizer
transform_args = kwargs.get("transform_args", {})
self.max_length = transform_args.get("max_length", 2048)
self.input_size = transform_args.get("input_size", [896, 896]) # height, width
if isinstance(self.input_size, int):
self.input_size = [self.input_size, self.input_size]
try:
self.answer_start_token = self.tokenizer._prompt_end_token
except AttributeError:
print('No answer_start_token found, using "" instead')
self.answer_start_token = ""
self.prefix_answer_space_flag = dp_config.get("prefix_answer_space_flag", True)
self.suffix_prompt_space_flag = dp_config.get("suffix_prompt_space_flag", True)
def process_prompt_for_inference(self, prompt):
prompt = prompt.replace("<image>\n", "")
if not prompt.startswith("<s>"):
prompt = "<s>" + prompt
message_ids = [self.tokenizer.encode(prompt, add_special_tokens=False)]
ids = torch.from_numpy(np.hstack(message_ids, dtype=np.int32))
return ids.unsqueeze(0)
def process_image_for_inference(self, image, return_img_size=False):
image = resize(image, min(self.input_size))
image.thumbnail((self.input_size[1], self.input_size[0]))
origin_w, origin_h = image.size
delta_width = self.input_size[1] - image.width
delta_height = self.input_size[0] - image.height
pad_width = delta_width // 2
pad_height = delta_height // 2
padding = (
pad_width,
pad_height,
delta_width - pad_width,
delta_height - pad_height,
)
image = ImageOps.expand(image, padding)
if return_img_size:
return test_transform(image).unsqueeze(0), (origin_w, origin_h)
return test_transform(image).unsqueeze(0)
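# --- Usage sketch (illustrative, not part of the original file) ----------------------
# Hedged example of the image path above: a PIL image is resized, thumbnailed to fit the
# configured input_size, then symmetrically padded to a fixed canvas before the tensor
# transform. How `processor` itself is constructed (dp_config, tokenizer) is not shown.
def _demo_process_image(processor: "DolphinProcessor"):
    from PIL import Image
    page = Image.new("RGB", (1200, 800), "white")
    tensor, (orig_w, orig_h) = processor.process_image_for_inference(page, return_img_size=True)
    return tensor.shape, (orig_w, orig_h)  # (1, 3, input_h, input_w), size before padding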

367
utils/utils.py Normal file
View File

@ -0,0 +1,367 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import copy
import json
import os
import re
from dataclasses import dataclass
from typing import List, Tuple
import albumentations as alb
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms.functional import resize
from utils.markdown_utils import MarkdownConverter
def alb_wrapper(transform):
def f(im):
return transform(image=np.asarray(im))["image"]
return f
test_transform = alb_wrapper(
alb.Compose(
[
alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
ToTensorV2(),
]
)
)
def check_coord_valid(x1, y1, x2, y2, image_size=None, abs_coord=True):
# print(f"check_coord_valid: {x1}, {y1}, {x2}, {y2}, {image_size}, {abs_coord}")
if x2 <= x1 or y2 <= y1:
return False, f"[{x1}, {y1}, {x2}, {y2}]"
if x1 < 0 or y1 < 0:
return False, f"[{x1}, {y1}, {x2}, {y2}]"
if not abs_coord:
if x2 > 1 or y2 > 1:
return False, f"[{x1}, {y1}, {x2}, {y2}]"
elif image_size is not None: # has image size
if x2 > image_size[0] or y2 > image_size[1]:
return False, f"[{x1}, {y1}, {x2}, {y2}]"
return True, None
def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2):
"""
Args:
image: cv2 image (numpy array) or a path to an image file
boxes: list of boxes [[x1, y1, x2, y2]] in absolute pixel coordinates
"""
if isinstance(image, str):
image = cv2.imread(image)
img_h, img_w = image.shape[:2]
new_boxes = []
for box in boxes:
best_box = copy.deepcopy(box)
def check_edge(img, current_box, i, is_vertical):
edge = current_box[i]
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
if is_vertical:
line = binary[current_box[1] : current_box[3] + 1, edge]
else:
line = binary[edge, current_box[0] : current_box[2] + 1]
transitions = np.abs(np.diff(line))
return np.sum(transitions) / len(transitions)
# Only widen the box
edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
current_box = copy.deepcopy(box)
# make sure the box is within the image
current_box[0] = min(max(current_box[0], 0), img_w - 1)
current_box[1] = min(max(current_box[1], 0), img_h - 1)
current_box[2] = min(max(current_box[2], 0), img_w - 1)
current_box[3] = min(max(current_box[3], 0), img_h - 1)
for i, direction, is_vertical in edges:
best_score = check_edge(image, current_box, i, is_vertical)
if best_score <= threshold:
continue
for step in range(max_pixels):
current_box[i] += direction
if i == 0 or i == 2:
current_box[i] = min(max(current_box[i], 0), img_w - 1)
else:
current_box[i] = min(max(current_box[i], 0), img_h - 1)
score = check_edge(image, current_box, i, is_vertical)
if score < best_score:
best_score = score
best_box = copy.deepcopy(current_box)
if score <= threshold:
break
new_boxes.append(best_box)
return new_boxes
def parse_layout_string(bbox_str):
"""Parse layout string using regular expressions"""
pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
matches = re.finditer(pattern, bbox_str)
parsed_results = []
for match in matches:
coords = [float(match.group(i)) for i in range(1, 5)]
label = match.group(5).strip()
parsed_results.append((coords, label))
return parsed_results
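# --- Illustrative sketch (not part of the original file) -----------------------------
# Hedged example of the regex above on a made-up layout string: each "[x1, y1, x2, y2] label"
# pair becomes a (coords, label) tuple.
def _demo_parse_layout_string():
    s = "[0.1, 0.05, 0.9, 0.2] title [0.1, 0.25, 0.9, 0.95] para"
    return parse_layout_string(s)
    # [([0.1, 0.05, 0.9, 0.2], 'title'), ([0.1, 0.25, 0.9, 0.95], 'para')]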
@dataclass
class ImageDimensions:
"""Class to store image dimensions"""
original_w: int
original_h: int
padded_w: int
padded_h: int
def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]:
"""Map coordinates from padded image back to original image
Args:
x1, y1, x2, y2: Coordinates in padded image
dims: Image dimensions object
Returns:
tuple: (x1, y1, x2, y2) coordinates in original image
"""
try:
# Calculate padding offsets
top = (dims.padded_h - dims.original_h) // 2
left = (dims.padded_w - dims.original_w) // 2
# Map back to original coordinates
orig_x1 = max(0, x1 - left)
orig_y1 = max(0, y1 - top)
orig_x2 = min(dims.original_w, x2 - left)
orig_y2 = min(dims.original_h, y2 - top)
# Ensure we have a valid box (width and height > 0)
if orig_x2 <= orig_x1:
orig_x2 = min(orig_x1 + 1, dims.original_w)
if orig_y2 <= orig_y1:
orig_y2 = min(orig_y1 + 1, dims.original_h)
return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
except Exception as e:
print(f"map_to_original_coordinates error: {str(e)}")
# Return safe coordinates
return 0, 0, min(100, dims.original_w), min(100, dims.original_h)
def map_to_relevant_coordinates(abs_coords, dims: ImageDimensions):
"""
Map absolute coordinates to relative (normalized) coordinates,
e.g. [100, 100, 200, 200] on a 1000x500 image -> (0.1, 0.2, 0.2, 0.4)
"""
try:
x1, y1, x2, y2 = abs_coords
return round(x1 / dims.original_w, 3), round(y1 / dims.original_h, 3), round(x2 / dims.original_w, 3), round(y2 / dims.original_h, 3)
except Exception as e:
print(f"map_to_relevant_coordinates error: {str(e)}")
return 0.0, 0.0, 1.0, 1.0 # Return full image coordinates
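# --- Illustrative sketch (not part of the original file) -----------------------------
# Hedged example of the two mappings above, using a 1000x800 page padded to 1000x1000
# (100 px of vertical padding on each side).
def _demo_coordinate_mapping():
    dims = ImageDimensions(original_w=1000, original_h=800, padded_w=1000, padded_h=1000)
    abs_box = map_to_original_coordinates(100, 150, 400, 450, dims)  # -> (100, 50, 400, 350)
    rel_box = map_to_relevant_coordinates(abs_box, dims)             # -> (0.1, 0.062, 0.4, 0.438)
    return abs_box, rel_box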
def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
"""Process and adjust coordinates
Args:
coords: Normalized coordinates [x1, y1, x2, y2]
padded_image: Padded image
dims: Image dimensions object
previous_box: Previous box coordinates for overlap adjustment
Returns:
tuple: (x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box)
"""
try:
# Convert normalized coordinates to absolute coordinates
x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)
# Ensure coordinates are within image bounds before adjustment
x1 = max(0, min(x1, dims.padded_w - 1))
y1 = max(0, min(y1, dims.padded_h - 1))
x2 = max(0, min(x2, dims.padded_w))
y2 = max(0, min(y2, dims.padded_h))
# Ensure width and height are at least 1 pixel
if x2 <= x1:
x2 = min(x1 + 1, dims.padded_w)
if y2 <= y1:
y2 = min(y1 + 1, dims.padded_h)
# Extend box boundaries
new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
x1, y1, x2, y2 = new_boxes[0]
# Ensure coordinates are still within image bounds after adjustment
x1 = max(0, min(x1, dims.padded_w - 1))
y1 = max(0, min(y1, dims.padded_h - 1))
x2 = max(0, min(x2, dims.padded_w))
y2 = max(0, min(y2, dims.padded_h))
# Ensure width and height are at least 1 pixel after adjustment
if x2 <= x1:
x2 = min(x1 + 1, dims.padded_w)
if y2 <= y1:
y2 = min(y1 + 1, dims.padded_h)
# Check for overlap with previous box and adjust
if previous_box is not None:
prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
y1 = prev_y2
# Ensure y1 is still valid
y1 = min(y1, dims.padded_h - 1)
# Make sure y2 is still greater than y1
if y2 <= y1:
y2 = min(y1 + 1, dims.padded_h)
# Update previous box
new_previous_box = [x1, y1, x2, y2]
# Map to original coordinates
orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(
x1, y1, x2, y2, dims
)
return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
except Exception as e:
print(f"process_coordinates error: {str(e)}")
# Return safe values
orig_x1, orig_y1, orig_x2, orig_y2 = 0, 0, min(100, dims.original_w), min(100, dims.original_h)
return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]
def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]:
"""Load and prepare image with padding while maintaining aspect ratio
Args:
image: PIL image
Returns:
tuple: (padded_image, image_dimensions)
"""
try:
# Convert PIL image to OpenCV format
image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
original_h, original_w = image.shape[:2]
# Calculate padding to make square image
max_size = max(original_h, original_w)
top = (max_size - original_h) // 2
bottom = max_size - original_h - top
left = (max_size - original_w) // 2
right = max_size - original_w - left
# Apply padding
padded_image = cv2.copyMakeBorder(image, top, bottom, left, right,
cv2.BORDER_CONSTANT, value=(0, 0, 0))
padded_h, padded_w = padded_image.shape[:2]
dimensions = ImageDimensions(
original_w=original_w,
original_h=original_h,
padded_w=padded_w,
padded_h=padded_h
)
return padded_image, dimensions
except Exception as e:
print(f"prepare_image error: {str(e)}")
# Create a minimal valid image and dimensions
h, w = image.height, image.width
dimensions = ImageDimensions(
original_w=w,
original_h=h,
padded_w=w,
padded_h=h
)
# Return a black image of the same size
return np.zeros((h, w, 3), dtype=np.uint8), dimensions
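# --- Illustrative sketch (not part of the original file) -----------------------------
# Hedged example: a 600x400 PIL image is padded to a 600x600 square, and the returned
# ImageDimensions records both the original and padded sizes.
def _demo_prepare_image():
    img = Image.new("RGB", (600, 400), "white")
    padded, dims = prepare_image(img)
    return padded.shape, (dims.original_w, dims.original_h, dims.padded_w, dims.padded_h)
    # ((600, 600, 3), (600, 400, 600, 600))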
def setup_output_dirs(save_dir):
"""Create necessary output directories"""
os.makedirs(save_dir, exist_ok=True)
os.makedirs(os.path.join(save_dir, "markdown"), exist_ok=True)
os.makedirs(os.path.join(save_dir, "recognition_json"), exist_ok=True)
def save_outputs(recognition_results, image_path, save_dir):
"""Save JSON and markdown outputs"""
basename = os.path.splitext(os.path.basename(image_path))[0]
# Save JSON file
json_path = os.path.join(save_dir, "recognition_json", f"{basename}.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump(recognition_results, f, ensure_ascii=False, indent=2)
# Generate and save markdown file
markdown_converter = MarkdownConverter()
markdown_content = markdown_converter.convert(recognition_results)
markdown_path = os.path.join(save_dir, "markdown", f"{basename}.md")
with open(markdown_path, "w", encoding="utf-8") as f:
f.write(markdown_content)
return json_path
def crop_margin(img: Image.Image) -> Image.Image:
"""Crop margins from image"""
try:
width, height = img.size
if width == 0 or height == 0:
print("Warning: Image has zero width or height")
return img
data = np.array(img.convert("L"))
data = data.astype(np.uint8)
max_val = data.max()
min_val = data.min()
if max_val == min_val:
return img
data = (data - min_val) / (max_val - min_val) * 255
gray = 255 * (data < 200).astype(np.uint8)
coords = cv2.findNonZero(gray) # Find all non-zero points (text)
if coords is None:
return img
a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
# Ensure crop coordinates are within image bounds
a = max(0, a)
b = max(0, b)
w = min(w, width - a)
h = min(h, height - b)
# Only crop if we have a valid region
if w > 0 and h > 0:
return img.crop((a, b, a + w, b + h))
return img
except Exception as e:
print(f"crop_margin error: {str(e)}")
return img # Return original image on error
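# --- Illustrative sketch (not part of the original file) -----------------------------
# Hedged example: a dark block standing in for text on a white canvas is detected after
# normalisation, so the crop tightens to the block's bounding box.
def _demo_crop_margin():
    img = Image.new("L", (200, 100), 255)  # white canvas
    img.paste(0, (50, 30, 150, 70))        # dark region standing in for text
    return crop_margin(img).size           # (100, 40)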