[init] initial commit
154
.gitignore
vendored
Normal file
@@ -0,0 +1,154 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
coverage.xml
*.mo
*.pot

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
.idea/
*.iml

# VS Code
.vscode/
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json

# macOS
.DS_Store

# Windows
Thumbs.db
ehthumbs.db
Desktop.ini

fusion_result.json
kernel_meta/

34
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,34 @@
repos:
  # 1. isort - automatically sort Python imports
  - repo: https://github.com/pycqa/isort
    rev: 6.0.1  # pin to a fixed version
    hooks:
      - id: isort
        name: isort (python)
        args: [--profile=black]  # configuration compatible with Black
        language: python

  # 2. Black - automatically format Python code
  - repo: https://github.com/psf/black
    rev: 25.1.0  # pin to a fixed version
    hooks:
      - id: black
        language: python

  # 3. flake8 - static checks for Python
  - repo: https://github.com/pycqa/flake8
    rev: 7.2.0
    hooks:
      - id: flake8
        args: [--max-line-length=120, --ignore=E203]  # set max line length to 120
        additional_dependencies: [flake8-bugbear==24.12.12]  # optional: extra checks

  # 4. pre-commit-hooks - general-purpose Git hooks
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace  # strip trailing whitespace
      - id: end-of-file-fixer  # ensure files end with a newline
      - id: check-yaml  # validate YAML file syntax
      - id: check-added-large-files  # block large files from being committed
        args: ["--maxkb=512"]

9
LICENSE
Normal file
@@ -0,0 +1,9 @@
MIT License

Copyright 2025 ByteDance Ltd. and/or its affiliates

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

188
README.md
Normal file
@@ -0,0 +1,188 @@
<div align="center">
  <img src="./assets/dolphin.png" width="300">
</div>

<div align="center">
  <a href="https://arxiv.org/abs/2505.14059">
    <img src="https://img.shields.io/badge/Paper-arXiv-red">
  </a>
  <a href="https://huggingface.co/ByteDance/Dolphin">
    <img src="https://img.shields.io/badge/HuggingFace-Dolphin-yellow">
  </a>
  <a href="http://115.190.42.15:8888/dolphin/">
    <img src="https://img.shields.io/badge/Demo-Dolphin-blue">
  </a>
  <a href="https://github.com/bytedance/Dolphin">
    <img src="https://img.shields.io/badge/Code-Github-green">
  </a>
  <a href="https://opensource.org/licenses/MIT">
    <img src="https://img.shields.io/badge/License-MIT-lightgray">
  </a>
  <br>
</div>

<br>

<div align="center">
  <img src="./assets/demo.gif" width="800">
</div>

# Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting

Dolphin (**Do**cument Image **P**arsing via **H**eterogeneous Anchor Prompt**in**g) is a novel multimodal document image parsing model following an analyze-then-parse paradigm. This repository contains the demo code and pre-trained models for Dolphin.

## 📑 Overview

Document image parsing is challenging due to its complexly intertwined elements such as text paragraphs, figures, formulas, and tables. Dolphin addresses these challenges through a two-stage approach:

1. **🔍 Stage 1**: Comprehensive page-level layout analysis by generating element sequence in natural reading order
2. **🧩 Stage 2**: Efficient parallel parsing of document elements using heterogeneous anchors and task-specific prompts

<div align="center">
  <img src="./assets/framework.png" width="680">
</div>

Dolphin achieves promising performance across diverse page-level and element-level parsing tasks while ensuring superior efficiency through its lightweight architecture and parallel parsing mechanism.
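The demo scripts in this repository implement exactly this two-stage flow. As a rough sketch (using the config-based `DOLPHIN` class from `chat.py`; the crop box below is illustrative, real boxes come from the stage-1 layout):

```python
# Minimal sketch of the analyze-then-parse flow; paths and the crop box are illustrative.
from omegaconf import OmegaConf
from PIL import Image

from chat import DOLPHIN

config = OmegaConf.load("./config/Dolphin.yaml")
model = DOLPHIN(config)

page = Image.open("./demo/page_imgs/page_1.jpeg").convert("RGB")

# Stage 1: page-level layout analysis in natural reading order
layout = model.chat("Parse the reading order of this document.", page)

# Stage 2: element-level parsing with task-specific prompts
# (demo_page.py derives the crops from the stage-1 layout and batches these calls)
element_crop = page.crop((100, 200, 800, 400))  # illustrative box
text = model.chat("Read text in the image.", element_crop)
table_html = model.chat("Parse the table in the image.", element_crop)
```

See `demo_page.py` for the full pipeline, including cropping and batched decoding.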

## 🚀 Demo

Try our demo on [Demo-Dolphin](http://115.190.42.15:8888/dolphin/).


## 📅 Changelog
- 🔥 **2025.05.21** Our demo is released at [link](http://115.190.42.15:8888/dolphin/). Check it out!
- 🔥 **2025.05.20** The pretrained model and inference code of Dolphin are released.
- 🔥 **2025.05.16** Our paper has been accepted by ACL 2025. Paper link: [arXiv](https://arxiv.org/abs/2505.14059).

## 🛠️ Installation

1. Clone the repository:
```bash
git clone https://github.com/ByteDance/Dolphin.git
cd Dolphin
```

2. Install the dependencies:
```bash
pip install -r requirements.txt
```

3. Download the pre-trained models using one of the following options:

**Option A: Original Model Format (config-based)**

Download from [Baidu Yun](https://pan.baidu.com/s/15zcARoX0CTOHKbW8bFZovQ?pwd=9rpx) or [Google Drive](https://drive.google.com/drive/folders/1PQJ3UutepXvunizZEw-uGaQ0BCzf-mie?usp=sharing) and put them in the `./checkpoints` folder.

**Option B: Hugging Face Model Format**

Visit our Hugging Face [model card](https://huggingface.co/ByteDance/Dolphin), or download the model by:

```bash
# Download the model from Hugging Face Hub
git lfs install
git clone https://huggingface.co/ByteDance/Dolphin ./hf_model
# Or use the Hugging Face CLI
huggingface-cli download ByteDance/Dolphin --local-dir ./hf_model
```
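If you want to load the Hugging Face checkpoint directly in your own code (this is the setup used by `demo_page_hf.py` and `demo_element_hf.py`), a minimal sketch, assuming the files were downloaded to `./hf_model`, looks like this:

```python
import torch
from transformers import AutoProcessor, VisionEncoderDecoderModel

processor = AutoProcessor.from_pretrained("./hf_model")
model = VisionEncoderDecoderModel.from_pretrained("./hf_model")
model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model = model.half()  # the demo scripts run in half precision by default
```

The demo scripts wrap this setup in their own `DOLPHIN` class; see the Inference section below.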

## ⚡ Inference

Dolphin provides two inference frameworks with support for two parsing granularities:
- **Page-level Parsing**: Parse the entire document image into a structured JSON and Markdown format
- **Element-level Parsing**: Parse individual document elements (text, table, formula)

### 📄 Page-level Parsing

#### Using Original Framework (config-based)

```bash
# Process a single document image
python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results

# Process all document images in a directory
python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results

# Process with custom batch size for parallel element decoding
python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 8
```

#### Using Hugging Face Framework

```bash
# Process a single document image
python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results

# Process all document images in a directory
python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results

# Process with custom batch size for parallel element decoding
python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 16
```

### 🧩 Element-level Parsing

#### Using Original Framework (config-based)

```bash
# Process a single table image
python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/table_1.jpeg --element_type table

# Process a single formula image
python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula

# Process a single text paragraph image
python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/para_1.jpg --element_type text
```

#### Using Hugging Face Framework

```bash
# Process a single table image
python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/table_1.jpeg --element_type table

# Process a single formula image
python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula

# Process a single text paragraph image
python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/para_1.jpg --element_type text
```
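Both element-level scripts boil down to a single prompt per cropped element. A minimal programmatic sketch using the `DOLPHIN` wrapper defined in `demo_element_hf.py` (the image path points at one of the bundled demo images; run from the repository root):

```python
from PIL import Image

from demo_element_hf import DOLPHIN

model = DOLPHIN("./hf_model")
table_img = Image.open("./demo/element_imgs/table_1.jpeg").convert("RGB")

# The same prompt the demo script uses internally for tables; tables are returned as HTML.
# For text paragraphs and formulas the scripts use "Read text in the image." instead.
table_html = model.chat("Parse the table in the image.", table_img)
print(table_html)
```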

## 🌟 Key Features

- 🔄 Two-stage analyze-then-parse approach based on a single VLM
- 📊 Promising performance on document parsing tasks
- 🔍 Natural reading order element sequence generation
- 🧩 Heterogeneous anchor prompting for different document elements
- ⏱️ Efficient parallel parsing mechanism
- 🤗 Support for Hugging Face Transformers for easier integration


## 📮 Notice
**Call for Bad Cases:** If you have encountered any cases where the model performs poorly, we would greatly appreciate it if you could share them in an issue. We are continuously working to optimize and improve the model.

## 💖 Acknowledgement

We would like to acknowledge the following open-source projects that provided inspiration and reference for this work:
- [Donut](https://github.com/clovaai/donut/)
- [Nougat](https://github.com/facebookresearch/nougat)
- [GOT](https://github.com/Ucas-HaoranWei/GOT-OCR2.0)
- [MinerU](https://github.com/opendatalab/MinerU/tree/master)
- [Swin](https://github.com/microsoft/Swin-Transformer)
- [Hugging Face Transformers](https://github.com/huggingface/transformers)

## 📝 Citation

If you find this code useful for your research, please use the following BibTeX entry.

```bibtex
@inproceedings{dolphin2025,
  title={Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting},
  author={Feng, Hao and Wei, Shu and Fei, Xiang and Shi, Wei and Han, Yingdong and Liao, Lei and Lu, Jinghui and Wu, Binghong and Liu, Qi and Lin, Chunhui and Tang, Jingqun and Liu, Hao and Huang, Can},
  year={2025},
  booktitle={Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (ACL)}
}
```

## Star History

[![Star History Chart](https://api.star-history.com/svg?repos=bytedance/Dolphin&type=Date)](https://www.star-history.com/#bytedance/Dolphin&Date)

BIN
assets/demo.gif
Normal file
After Width: | Height: | Size: 3.1 MiB |
BIN
assets/dolphin.png
Normal file
After Width: | Height: | Size: 81 KiB |
BIN
assets/framework.png
Normal file
After Width: | Height: | Size: 1.9 MiB |
197
chat.py
Normal file
@@ -0,0 +1,197 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import os
import warnings
from collections import OrderedDict

from omegaconf import ListConfig

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")

import torch
from PIL import Image
from transformers import PreTrainedTokenizerFast

from utils.model import DonutConfig, DonutModel, SwinEncoder
from utils.processor import DolphinProcessor


def try_rename_lagacy_weights(ckpt, output_path=""):
    if "state_dict" in ckpt.keys():
        ckpt = ckpt["state_dict"]
    if "module" in ckpt.keys():
        ckpt = ckpt["module"]
    new_ckpt = OrderedDict()
    for k, v in ckpt.items():
        if k.startswith("model."):
            k = k[len("model.") :]
        if k.startswith("encoder"):
            new_ckpt["vpm" + k[len("encoder") :]] = v
        elif k.startswith("decoder"):
            new_ckpt["llm" + k[len("encoder") :]] = v
        else:
            new_ckpt[k] = v
    if output_path:
        torch.save(new_ckpt, output_path)
    return new_ckpt


def convert_listconfig_to_list(config):
    new_config = {}
    for k, v in config.items():
        if isinstance(v, ListConfig):
            new_config[k] = list(v)
        else:
            new_config[k] = v
    return new_config


class DOLPHIN:
    def __init__(self, config, ckpt_path="") -> None:
        self.model_args = config.model
        self.swin_args = config.model.pop("swin_args")
        self.swin_args = convert_listconfig_to_list(self.swin_args)

        vision_tower = SwinEncoder(
            input_size=self.swin_args["img_size"],
            patch_size=self.swin_args["patch_size"],
            embed_dim=self.swin_args["embed_dim"],
            window_size=self.swin_args["window_size"],
            encoder_layer=self.swin_args["encoder_layer"],
            num_heads=self.swin_args["num_heads"],
            align_long_axis=self.swin_args["align_long_axis"],
        )

        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.model_args.tokenizer_path)
        self.tokenizer.pad_token = "<pad>"
        self.tokenizer.bos_token = "<s>"
        self.tokenizer.eos_token = "</s>"
        self.tokenizer.unk_token = "<unk>"

        if self.model_args.get("extra_answer_tokens", False):
            # print("Allowing multitask training: adding <Answer/> to the tokenizer.")
            prompt_end_token = " <Answer/>"
            self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set([prompt_end_token]))})
            self.tokenizer._prompt_end_token = prompt_end_token
            self.tokenizer._prompt_end_token_id = self.tokenizer.convert_tokens_to_ids(prompt_end_token)

        donut_config = DonutConfig(
            decoder_layer=self.model_args.decoder_layer,
            max_length=self.model_args.max_length,
            max_position_embeddings=self.model_args.max_position_embeddings,
            hidden_dimension=self.model_args.hidden_dimension,
        )

        self.model = DonutModel(config=donut_config, vision_tower=vision_tower, tokenizer=self.tokenizer)
        if self.model_args.model_name_or_path:
            ckpt = torch.load(self.model_args.model_name_or_path)
            ckpt = try_rename_lagacy_weights(ckpt)
            self.model.load_state_dict(ckpt, strict=True)

        self.model.to("cuda")
        self.model.eval()
        transform_args = {
            "input_size": self.swin_args["img_size"],
            "max_length": self.model_args.max_length,
        }
        self.processor = DolphinProcessor({}, self.tokenizer, transform_args=transform_args)

    def chat(
        self,
        question,
        image,
        return_raw=False,
        return_score=False,
        return_img_size=False,
        only_return_img_size=False,
        max_batch_size=16,
    ):

        def _preprocess_image(image):
            if isinstance(image, str):
                image = Image.open(image).convert("RGB")
            if return_img_size or only_return_img_size:
                image_tensor, ori_size = self.processor.process_image_for_inference(image, return_img_size=True)
            else:
                image_tensor = self.processor.process_image_for_inference(image, return_img_size=False)
                ori_size = None
            return image_tensor, ori_size

        def _preprocess_prompt(question):
            if self.model_args.get("extra_answer_tokens", False):
                if self.tokenizer._prompt_end_token not in question:
                    question = question + self.tokenizer._prompt_end_token
            prompt_ids = self.processor.process_prompt_for_inference(question)
            return prompt_ids

        def _preprocess_prompt_batch(question):
            if self.model_args.get("extra_answer_tokens", False):
                for i in range(len(question)):
                    if self.tokenizer._prompt_end_token not in question[i]:
                        question[i] = question[i] + self.tokenizer._prompt_end_token
                    if not question[i].startswith("<s>"):
                        question[i] = "<s>" + question[i]
            return question

        def _postprocess(output, question):
            output = output.replace("<s>", "").replace(question, "").replace("</s>", "").replace("<pad>", "")
            if self.model_args.get("extra_answer_tokens", False):
                output = output.split(self.tokenizer._prompt_end_token)[-1]
            return output

        if isinstance(question, list):
            image_tensor_list = []
            for i in image:
                image_tensor, ori_size = _preprocess_image(i)
                image_tensor_list.append(image_tensor)
            image_tensor = torch.cat(image_tensor_list, dim=0)

            question = _preprocess_prompt_batch(question)
            self.processor.tokenizer.padding_side = "left"
            prompt_ids = self.processor.tokenizer(
                question, add_special_tokens=False, return_tensors="pt", padding=True
            ).input_ids
        else:
            image_tensor, ori_size = _preprocess_image(image)
            prompt_ids = _preprocess_prompt(question)

        if only_return_img_size:
            return ori_size

        model_output_batch = []
        for i in range(0, image_tensor.shape[0], max_batch_size):
            image_tensor_batch = image_tensor[i : i + max_batch_size]
            prompt_ids_batch = prompt_ids[i : i + max_batch_size]
            model_output = self.model.inference(image_tensors=image_tensor_batch, prompt_ids=prompt_ids_batch)
            model_output_batch.append(model_output)
        model_output = {}
        for k, v in model_output_batch[0].items():
            if isinstance(v, torch.Tensor):
                model_output[k] = sum(
                    [v_batch[k].cpu().numpy().tolist() for v_batch in model_output_batch],
                    [],
                )
            else:
                model_output[k] = sum([v_batch[k] for v_batch in model_output_batch], [])

        if return_raw:
            if return_img_size:
                return model_output, ori_size
            return model_output
        else:
            if isinstance(question, list):
                output = [_postprocess(model_output["repetitions"][i], question[i]) for i in range(len(question))]
                score = model_output["scores"]
            else:
                output = _postprocess(model_output["repetitions"][0], question)
                score = model_output["scores"][0]
            if return_score:
                return output, score
            if return_img_size:
                return output, ori_size
            return output
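
The `chat` method above accepts either a single prompt/image pair or parallel lists of prompts and images, which is how `demo_page.py` decodes many element crops at once. A minimal usage sketch (run from the repository root; the two demo images stand in for the crops that would normally come from the stage-1 layout):

```python
# Sketch: batched element decoding with the config-based DOLPHIN class above.
from omegaconf import OmegaConf
from PIL import Image

from chat import DOLPHIN

model = DOLPHIN(OmegaConf.load("./config/Dolphin.yaml"))

crops = [
    Image.open("./demo/element_imgs/para_1.jpg").convert("RGB"),
    Image.open("./demo/element_imgs/table_1.jpeg").convert("RGB"),
]
prompts = ["Read text in the image.", "Parse the table in the image."]

outputs = model.chat(prompts, crops, max_batch_size=16)
for prompt, out in zip(prompts, outputs):
    print(prompt, "->", out[:80])
```
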
17
config/Dolphin.yaml
Normal file
@@ -0,0 +1,17 @@
model:
  model_name_or_path: "./checkpoints/dolphin_model.bin"
  tokenizer_path: "./checkpoints/dolphin_tokenizer.json"
  extra_answer_tokens: True  # add <Answer/> token
  max_length: 4096
  decoder_layer: 10
  max_position_embeddings: 4096
  hidden_dimension: 1024
  swin_args:
    name: 'swin'
    img_size: [896, 896]
    patch_size: 4
    embed_dim: 128
    align_long_axis: False
    window_size: 7
    encoder_layer: [2, 2, 14, 2]
    num_heads: [4, 8, 16, 32]
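
This config is what the original (non-Hugging-Face) entry points consume. The demo scripts load it with OmegaConf and hand it to the `DOLPHIN` class from `chat.py`; a small sketch of how the fields are read:

```python
from omegaconf import OmegaConf

config = OmegaConf.load("./config/Dolphin.yaml")
print(config.model.max_length)                 # 4096
print(list(config.model.swin_args.img_size))   # [896, 896]

# chat.DOLPHIN(config) reads config.model directly and pops swin_args
# to build the Swin encoder (see chat.py above).
```

Note that `model_name_or_path` and `tokenizer_path` must point at the files downloaded into `./checkpoints` (see Option A in the README).
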
BIN
demo/element_imgs/block_formula.jpeg
Normal file
After Width: | Height: | Size: 90 KiB |
BIN
demo/element_imgs/line_formula.jpeg
Normal file
After Width: | Height: | Size: 54 KiB |
BIN
demo/element_imgs/para_1.jpg
Normal file
After Width: | Height: | Size: 18 KiB |
BIN
demo/element_imgs/para_2.jpg
Normal file
After Width: | Height: | Size: 68 KiB |
BIN
demo/element_imgs/para_3.jpeg
Normal file
After Width: | Height: | Size: 82 KiB |
BIN
demo/element_imgs/table_1.jpeg
Normal file
After Width: | Height: | Size: 179 KiB |
BIN
demo/element_imgs/table_2.jpeg
Normal file
After Width: | Height: | Size: 396 KiB |
BIN
demo/page_imgs/page_1.jpeg
Normal file
After Width: | Height: | Size: 1.4 MiB |
BIN
demo/page_imgs/page_2.jpeg
Normal file
After Width: | Height: | Size: 1.4 MiB |
BIN
demo/page_imgs/page_3.jpeg
Normal file
After Width: | Height: | Size: 439 KiB |
BIN
demo/page_imgs/page_4.png
Executable file
After Width: | Height: | Size: 363 KiB |
BIN
demo/page_imgs/page_5.jpg
Executable file
After Width: | Height: | Size: 626 KiB |
129
demo_element.py
Normal file
@@ -0,0 +1,129 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import argparse
import glob
import os

from omegaconf import OmegaConf
from PIL import Image

from chat import DOLPHIN
from utils.utils import *


def process_element(image_path, model, element_type, save_dir=None):
    """Process a single element image (text, table, formula)

    Args:
        image_path: Path to the element image
        model: DOLPHIN model instance
        element_type: Type of element ('text', 'table', 'formula')
        save_dir: Directory to save results (default: same as input directory)

    Returns:
        Parsed content of the element and recognition results
    """
    # Load and prepare image
    pil_image = Image.open(image_path).convert("RGB")
    pil_image = crop_margin(pil_image)

    # Select appropriate prompt based on element type
    if element_type == "table":
        prompt = "Parse the table in the image."
        label = "tab"
    elif element_type == "formula":
        prompt = "Read text in the image."
        label = "formula"
    else:  # Default to text
        prompt = "Read text in the image."
        label = "text"

    # Process the element
    result = model.chat(prompt, pil_image)

    # Create recognition result in the same format as the document parser
    recognition_result = [
        {
            "label": label,
            "text": result.strip(),
        }
    ]

    # Save results if save_dir is provided
    if save_dir:
        save_outputs(recognition_result, image_path, save_dir)
        print(f"Results saved to {save_dir}")

    return result, recognition_result


def main():
    parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model")
    parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file")
    parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images")
    parser.add_argument(
        "--element_type",
        type=str,
        choices=["text", "table", "formula"],
        default="text",
        help="Type of element to process (text, table, formula)",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default=None,
        help="Directory to save parsing results (default: same as input directory)",
    )
    parser.add_argument("--print_results", action="store_true", help="Print recognition results to console")
    args = parser.parse_args()

    # Load Model
    config = OmegaConf.load(args.config)
    model = DOLPHIN(config)

    # Set save directory
    save_dir = args.save_dir or (
        args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
    )
    setup_output_dirs(save_dir)

    # Collect Images
    if os.path.isdir(args.input_path):
        image_files = []
        for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
            image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
        image_files = sorted(image_files)
    else:
        if not os.path.exists(args.input_path):
            raise FileNotFoundError(f"Input path {args.input_path} does not exist")
        image_files = [args.input_path]

    total_samples = len(image_files)
    print(f"\nTotal samples to process: {total_samples}")

    # Process images one by one
    for image_path in image_files:
        print(f"\nProcessing {image_path}")
        try:
            result, recognition_result = process_element(
                image_path=image_path,
                model=model,
                element_type=args.element_type,
                save_dir=save_dir,
            )

            if args.print_results:
                print("\nRecognition result:")
                print(result)
                print("-" * 40)

        except Exception as e:
            print(f"Error processing {image_path}: {str(e)}")
            continue


if __name__ == "__main__":
    main()

193
demo_element_hf.py
Normal file
@@ -0,0 +1,193 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import argparse
import glob
import os

import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel

from utils.utils import *


class DOLPHIN:
    def __init__(self, model_id_or_path):
        """Initialize the Hugging Face model

        Args:
            model_id_or_path: Path to local model or Hugging Face model ID
        """
        # Load model from local path or Hugging Face hub
        self.processor = AutoProcessor.from_pretrained(model_id_or_path)
        self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path)
        self.model.eval()

        # Set device and precision
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model = self.model.half()  # Always use half precision by default

        # set tokenizer
        self.tokenizer = self.processor.tokenizer

    def chat(self, prompt, image):
        """Process an image with the given prompt

        Args:
            prompt: Text prompt to guide the model
            image: PIL Image to process

        Returns:
            Generated text from the model
        """
        # Prepare image
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.half()

        # Prepare prompt
        prompt = f"<s>{prompt} <Answer/>"
        prompt_ids = self.tokenizer(
            prompt,
            add_special_tokens=False,
            return_tensors="pt"
        ).input_ids.to(self.device)

        decoder_attention_mask = torch.ones_like(prompt_ids)

        # Generate text
        outputs = self.model.generate(
            pixel_values=pixel_values.to(self.device),
            decoder_input_ids=prompt_ids,
            decoder_attention_mask=decoder_attention_mask,
            min_length=1,
            max_length=4096,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[self.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            do_sample=False,
            num_beams=1,
        )

        # Process the output
        sequence = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
        sequence = sequence.replace(prompt, "").replace("<pad>", "").replace("</s>", "").strip()

        return sequence


def process_element(image_path, model, element_type, save_dir=None):
    """Process a single element image (text, table, formula)

    Args:
        image_path: Path to the element image
        model: HFModel model instance
        element_type: Type of element ('text', 'table', 'formula')
        save_dir: Directory to save results (default: same as input directory)

    Returns:
        Parsed content of the element and recognition results
    """
    # Load and prepare image
    pil_image = Image.open(image_path).convert("RGB")
    pil_image = crop_margin(pil_image)

    # Select appropriate prompt based on element type
    if element_type == "table":
        prompt = "Parse the table in the image."
        label = "tab"
    elif element_type == "formula":
        prompt = "Read text in the image."
        label = "formula"
    else:  # Default to text
        prompt = "Read text in the image."
        label = "text"

    # Process the element
    result = model.chat(prompt, pil_image)

    # Create recognition result in the same format as the document parser
    recognition_result = [
        {
            "label": label,
            "text": result.strip(),
        }
    ]

    # Save results if save_dir is provided
    if save_dir:
        save_outputs(recognition_result, image_path, save_dir)
        print(f"Results saved to {save_dir}")

    return result, recognition_result


def main():
    parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model")
    parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model")
    parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images")
    parser.add_argument(
        "--element_type",
        type=str,
        choices=["text", "table", "formula"],
        default="text",
        help="Type of element to process (text, table, formula)",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default=None,
        help="Directory to save parsing results (default: same as input directory)",
    )
    parser.add_argument("--print_results", action="store_true", help="Print recognition results to console")
    args = parser.parse_args()

    # Load Model
    model = DOLPHIN(args.model_path)

    # Set save directory
    save_dir = args.save_dir or (
        args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
    )
    setup_output_dirs(save_dir)

    # Collect Images
    if os.path.isdir(args.input_path):
        image_files = []
        for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
            image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
        image_files = sorted(image_files)
    else:
        if not os.path.exists(args.input_path):
            raise FileNotFoundError(f"Input path {args.input_path} does not exist")
        image_files = [args.input_path]

    total_samples = len(image_files)
    print(f"\nTotal samples to process: {total_samples}")

    # Process images one by one
    for image_path in image_files:
        print(f"\nProcessing {image_path}")
        try:
            result, recognition_result = process_element(
                image_path=image_path,
                model=model,
                element_type=args.element_type,
                save_dir=save_dir,
            )

            if args.print_results:
                print("\nRecognition result:")
                print(result)
                print("-" * 40)
        except Exception as e:
            print(f"Error processing {image_path}: {str(e)}")
            continue


if __name__ == "__main__":
    main()

171
demo_page.py
Normal file
@@ -0,0 +1,171 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import argparse
import glob
import os

import cv2
from omegaconf import OmegaConf
from PIL import Image

from chat import DOLPHIN
from utils.utils import *


def process_page(image_path, model, save_dir, max_batch_size):
    """Parse document images with two stages"""
    # Stage 1: Page-level layout and reading order parsing
    pil_image = Image.open(image_path).convert("RGB")
    layout_output = model.chat("Parse the reading order of this document.", pil_image)

    # Stage 2: Element-level content parsing
    padded_image, dims = prepare_image(pil_image)
    recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size)

    # Save outputs
    json_path = save_outputs(recognition_results, image_path, save_dir)

    return json_path, recognition_results


def process_elements(layout_results, padded_image, dims, model, max_batch_size):
    """Parse all document elements with parallel decoding"""
    layout_results = parse_layout_string(layout_results)

    text_table_elements = []  # Elements that need processing
    figure_results = []  # Figure elements (no processing needed)
    previous_box = None
    reading_order = 0

    # Collect elements for processing
    for bbox, label in layout_results:
        try:
            # Adjust coordinates
            x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
                bbox, padded_image, dims, previous_box
            )

            # Crop and parse element
            cropped = padded_image[y1:y2, x1:x2]
            if cropped.size > 0:
                if label == "fig":
                    # For figure regions, add empty text result immediately
                    figure_results.append(
                        {
                            "label": label,
                            "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                            "text": "",
                            "reading_order": reading_order,
                        }
                    )
                else:
                    # For text or table regions, prepare for parsing
                    pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
                    prompt = "Parse the table in the image." if label == "tab" else "Read text in the image."
                    text_table_elements.append(
                        {
                            "crop": pil_crop,
                            "prompt": prompt,
                            "label": label,
                            "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                            "reading_order": reading_order,
                        }
                    )

            reading_order += 1

        except Exception as e:
            print(f"Error processing bbox with label {label}: {str(e)}")
            continue

    # Parse text/table elements in parallel
    recognition_results = figure_results
    if text_table_elements:
        crops_list = [elem["crop"] for elem in text_table_elements]
        prompts_list = [elem["prompt"] for elem in text_table_elements]

        # Inference in batch
        batch_results = model.chat(prompts_list, crops_list, max_batch_size=max_batch_size)

        # Add batch results to recognition_results
        for i, result in enumerate(batch_results):
            elem = text_table_elements[i]
            recognition_results.append(
                {
                    "label": elem["label"],
                    "bbox": elem["bbox"],
                    "text": result.strip(),
                    "reading_order": elem["reading_order"],
                }
            )

    # Sort elements by reading order
    recognition_results.sort(key=lambda x: x.get("reading_order", 0))

    return recognition_results


def main():
    parser = argparse.ArgumentParser(description="Document processing tool using DOLPHIN model")
    parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file")
    parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image or directory of images")
    parser.add_argument(
        "--save_dir",
        type=str,
        default=None,
        help="Directory to save parsing results (default: same as input directory)",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=4,
        help="Maximum number of document elements to parse in a single batch (default: 4)",
    )
    args = parser.parse_args()

    # Load Model
    config = OmegaConf.load(args.config)
    model = DOLPHIN(config)

    # Collect Document Images
    if os.path.isdir(args.input_path):
        image_files = []
        for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
            image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
        image_files = sorted(image_files)
    else:
        if not os.path.exists(args.input_path):
            raise FileNotFoundError(f"Input path {args.input_path} does not exist")
        image_files = [args.input_path]

    save_dir = args.save_dir or (
        args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
    )
    setup_output_dirs(save_dir)

    total_samples = len(image_files)
    print(f"\nTotal samples to process: {total_samples}")

    # Process All Document Images
    for image_path in image_files:
        print(f"\nProcessing {image_path}")
        try:
            json_path, recognition_results = process_page(
                image_path=image_path,
                model=model,
                save_dir=save_dir,
                max_batch_size=args.max_batch_size,
            )

            print(f"Processing completed. Results saved to {save_dir}")

        except Exception as e:
            print(f"Error processing {image_path}: {str(e)}")
            continue


if __name__ == "__main__":
    main()

288
demo_page_hf.py
Normal file
@@ -0,0 +1,288 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import argparse
import glob
import os

import cv2
import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel

from utils.utils import *


class DOLPHIN:
    def __init__(self, model_id_or_path):
        """Initialize the Hugging Face model

        Args:
            model_id_or_path: Path to local model or Hugging Face model ID
        """
        # Load model from local path or Hugging Face hub
        self.processor = AutoProcessor.from_pretrained(model_id_or_path)
        self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path)
        self.model.eval()

        # Set device and precision
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model = self.model.half()  # Always use half precision by default

        # set tokenizer
        self.tokenizer = self.processor.tokenizer

    def chat(self, prompt, image):
        """Process an image or batch of images with the given prompt(s)

        Args:
            prompt: Text prompt or list of prompts to guide the model
            image: PIL Image or list of PIL Images to process

        Returns:
            Generated text or list of texts from the model
        """
        # Check if we're dealing with a batch
        is_batch = isinstance(image, list)

        if not is_batch:
            # Single image, wrap it in a list for consistent processing
            images = [image]
            prompts = [prompt]
        else:
            # Batch of images
            images = image
            prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)

        # Prepare image
        batch_inputs = self.processor(images, return_tensors="pt", padding=True)
        batch_pixel_values = batch_inputs.pixel_values.half().to(self.device)

        # Prepare prompt
        prompts = [f"<s>{p} <Answer/>" for p in prompts]
        batch_prompt_inputs = self.tokenizer(
            prompts,
            add_special_tokens=False,
            return_tensors="pt"
        )

        batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
        batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)

        # Generate text
        outputs = self.model.generate(
            pixel_values=batch_pixel_values,
            decoder_input_ids=batch_prompt_ids,
            decoder_attention_mask=batch_attention_mask,
            min_length=1,
            max_length=4096,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[self.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            do_sample=False,
            num_beams=1,
            repetition_penalty=1.1
        )

        # Process output
        sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)

        # Clean prompt text from output
        results = []
        for i, sequence in enumerate(sequences):
            cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
            results.append(cleaned)

        # Return a single result for single image input
        if not is_batch:
            return results[0]
        return results


def process_page(image_path, model, save_dir, max_batch_size=None):
    """Parse document images with two stages"""
    # Stage 1: Page-level layout and reading order parsing
    pil_image = Image.open(image_path).convert("RGB")
    layout_output = model.chat("Parse the reading order of this document.", pil_image)

    # Stage 2: Element-level content parsing
    padded_image, dims = prepare_image(pil_image)
    recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size)

    # Save outputs
    json_path = save_outputs(recognition_results, image_path, save_dir)

    return json_path, recognition_results


def process_elements(layout_results, padded_image, dims, model, max_batch_size=None):
    """Parse all document elements with parallel decoding"""
    layout_results = parse_layout_string(layout_results)

    # Store text and table elements separately
    text_elements = []  # Text elements
    table_elements = []  # Table elements
    figure_results = []  # Image elements (no processing needed)
    previous_box = None
    reading_order = 0

    # Collect elements to process and group by type
    for bbox, label in layout_results:
        try:
            # Adjust coordinates
            x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
                bbox, padded_image, dims, previous_box
            )

            # Crop and parse element
            cropped = padded_image[y1:y2, x1:x2]
            if cropped.size > 0:
                if label == "fig":
                    # For figure regions, add empty text result immediately
                    figure_results.append(
                        {
                            "label": label,
                            "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                            "text": "",
                            "reading_order": reading_order,
                        }
                    )
                else:
                    # Prepare element for parsing
                    pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
                    element_info = {
                        "crop": pil_crop,
                        "label": label,
                        "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                        "reading_order": reading_order,
                    }

                    # Group by type
                    if label == "tab":
                        table_elements.append(element_info)
                    else:  # Text elements
                        text_elements.append(element_info)

            reading_order += 1

        except Exception as e:
            print(f"Error processing bbox with label {label}: {str(e)}")
            continue

    # Initialize results list
    recognition_results = figure_results.copy()

    # Process text elements (in batches)
    if text_elements:
        text_results = process_element_batch(text_elements, model, "Read text in the image.", max_batch_size)
        recognition_results.extend(text_results)

    # Process table elements (in batches)
    if table_elements:
        table_results = process_element_batch(table_elements, model, "Parse the table in the image.", max_batch_size)
        recognition_results.extend(table_results)

    # Sort elements by reading order
    recognition_results.sort(key=lambda x: x.get("reading_order", 0))

    return recognition_results


def process_element_batch(elements, model, prompt, max_batch_size=None):
    """Process elements of the same type in batches"""
    results = []

    # Determine batch size
    batch_size = len(elements)
    if max_batch_size is not None and max_batch_size > 0:
        batch_size = min(batch_size, max_batch_size)

    # Process in batches
    for i in range(0, len(elements), batch_size):
        batch_elements = elements[i:i+batch_size]
        crops_list = [elem["crop"] for elem in batch_elements]

        # Use the same prompt for all elements in the batch
        prompts_list = [prompt] * len(crops_list)

        # Batch inference
        batch_results = model.chat(prompts_list, crops_list)

        # Add results
        for j, result in enumerate(batch_results):
            elem = batch_elements[j]
            results.append({
                "label": elem["label"],
                "bbox": elem["bbox"],
                "text": result.strip(),
                "reading_order": elem["reading_order"],
            })

    return results


def main():
    parser = argparse.ArgumentParser(description="Document processing tool using DOLPHIN model")
    parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model")
    parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image or directory of images")
    parser.add_argument(
        "--save_dir",
        type=str,
        default=None,
        help="Directory to save parsing results (default: same as input directory)",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=16,
        help="Maximum number of document elements to parse in a single batch (default: 16)",
    )
    args = parser.parse_args()

    # Load Model
    model = DOLPHIN(args.model_path)

    # Collect Document Images
    if os.path.isdir(args.input_path):
        image_files = []
        for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
            image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
        image_files = sorted(image_files)
    else:
        if not os.path.exists(args.input_path):
            raise FileNotFoundError(f"Input path {args.input_path} does not exist")
        image_files = [args.input_path]

    save_dir = args.save_dir or (
        args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
    )
    setup_output_dirs(save_dir)

    total_samples = len(image_files)
    print(f"\nTotal samples to process: {total_samples}")

    # Process All Document Images
    for image_path in image_files:
        print(f"\nProcessing {image_path}")
        try:
            json_path, recognition_results = process_page(
                image_path=image_path,
                model=model,
                save_dir=save_dir,
                max_batch_size=args.max_batch_size,
            )

            print(f"Processing completed. Results saved to {save_dir}")

        except Exception as e:
            print(f"Error processing {image_path}: {str(e)}")
            continue


if __name__ == "__main__":
    main()

16
pyproject.toml
Normal file
@@ -0,0 +1,16 @@
[tool.black]
line-length = 120
include = '\.pyi?$'
exclude = '''
/(
    \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | _build
  | buck-out
  | build
  | dist
)/
'''

11
requirements.txt
Normal file
@@ -0,0 +1,11 @@
albumentations==1.4.0
numpy==1.24.4
omegaconf==2.3.0
opencv-python==4.11.0.86
opencv-python-headless==4.5.5.64
pillow==9.3.0
timm==0.5.4
torch==2.1.0
torchvision==0.16.0
transformers==4.47.0
accelerate==1.6.0

442
utils/markdown_utils.py
Normal file
@ -0,0 +1,442 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import re
import base64
from typing import List, Dict, Any, Optional


"""
Example input:
[
{"label": "tab", "bbox": [0.176, 0.74, 0.824, 0.82], "text": "<table><tr><td></td><td>HellaSwag</td><td>Obqa</td><td>WinoGrande</td><td>ARC-c</td><td>ARC-e</td><td>boolq</td><td>piqa</td><td>Avg</td></tr><tr><td>OPT-1.3B</td><td>53.65</td><td>33.40</td><td>59.59</td><td>29.44</td><td>50.80</td><td>60.83</td><td>72.36</td><td>51.44</td></tr><tr><td>Pythia-1.0B</td><td>47.16</td><td>31.40</td><td>53.43</td><td>27.05</td><td>48.99</td><td>57.83</td><td>69.21</td><td>48.30</td></tr><tr><td>Pythia-1.4B</td><td>52.01</td><td>33.20</td><td>57.38</td><td>28.50</td><td>54.00</td><td>63.27</td><td>70.95</td><td>51.33</td></tr><tr><td>TinyLlama-1.1B</td><td>59.20</td><td>36.00</td><td>59.12</td><td>30.10</td><td>55.25</td><td>57.83</td><td>73.29</td><td>52.99</td></tr></table>", "reading_order": 6},
{"label": "cap", "bbox": [0.28, 0.729, 0.711, 0.74], "text": "Table 2: Zero-shot performance on commonsense reasoning tasks", "reading_order": 7},
{"label": "para", "bbox": [0.176, 0.848, 0.826, 0.873], "text": "We of performance during training We tracked the accuracy of TinyLlama on common-\nsense reasoning benchmarks during its pre-training, as shown in Fig. 2 . Generally, the performance of", "reading_order": 8},
{"label": "fnote", "bbox": [0.176, 0.88, 0.824, 0.912], "text": "${ }^{4}$ Due to a bug in the config file, the learning rate did not decrease immediately after warmup and remained at\nthe maximum value for several steps before we fixed this.", "reading_order": 9},
{"label": "foot", "bbox": [0.496, 0.939, 0.501, 0.95], "text": "14", "reading_order": 10}
]
"""


def extract_table_from_html(html_string):
    """Extract and clean table tags from HTML string"""
    try:
        table_pattern = re.compile(r'<table.*?>.*?</table>', re.DOTALL)
        tables = table_pattern.findall(html_string)
        tables = [re.sub(r'<table[^>]*>', '<table>', table) for table in tables]
        return '\n'.join(tables)
    except Exception as e:
        print(f"extract_table_from_html error: {str(e)}")
        return f"<table><tr><td>Error extracting table: {str(e)}</td></tr></table>"
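A small usage sketch of `extract_table_from_html`; the input string is made up for illustration.

```python
# Attributes on <table> are stripped and text outside the table is dropped.
html = '<p>intro</p><table border="1"><tr><td>A</td><td>B</td></tr></table>'
print(extract_table_from_html(html))
# -> <table><tr><td>A</td><td>B</td></tr></table>
```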
class MarkdownConverter:
    """Convert structured recognition results to Markdown format"""

    def __init__(self):
        # Define heading levels for different section types
        self.heading_levels = {
            'title': '#',
            'sec': '##',
            'sub_sec': '###'
        }

        # Define which labels need special handling
        self.special_labels = {
            'tab', 'fig', 'title', 'sec', 'sub_sec',
            'list', 'formula', 'reference', 'alg'
        }

    def try_remove_newline(self, text: str) -> str:
        try:
            # Preprocess text to handle line breaks
            text = text.strip()
            text = text.replace('-\n', '')

            # Handle Chinese text line breaks
            def is_chinese(char):
                return '\u4e00' <= char <= '\u9fff'

            lines = text.split('\n')
            processed_lines = []

            # Process all lines except the last one
            for i in range(len(lines)-1):
                current_line = lines[i].strip()
                next_line = lines[i+1].strip()

                # Always add the current line, but determine if we need a newline
                if current_line:  # If current line is not empty
                    if next_line:  # If next line is not empty
                        # For Chinese text handling
                        if is_chinese(current_line[-1]) and is_chinese(next_line[0]):
                            processed_lines.append(current_line)
                        else:
                            processed_lines.append(current_line + ' ')
                    else:
                        # Next line is empty, add current line with newline
                        processed_lines.append(current_line + '\n')
                else:
                    # Current line is empty, add an empty line
                    processed_lines.append('\n')

            # Add the last line
            if lines and lines[-1].strip():
                processed_lines.append(lines[-1].strip())

            text = ''.join(processed_lines)

            return text
        except Exception as e:
            print(f"try_remove_newline error: {str(e)}")
            return text  # Return original text on error
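A quick illustration of what `try_remove_newline` does to recognizer line breaks; the strings are invented examples.

```python
converter = MarkdownConverter()
# Hyphenated wraps are joined, ordinary wraps become spaces:
print(converter.try_remove_newline("common-\nsense reasoning\nbenchmarks"))
# -> "commonsense reasoning benchmarks"
# Adjacent Chinese characters are joined without an inserted space.
```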
    def _handle_text(self, text: str) -> str:
        """
        Process regular text content, preserving paragraph structure
        """
        try:
            if not text:
                return ""

            if text.strip().startswith("\\begin{array}") and text.strip().endswith("\\end{array}"):
                text = "$$" + text + "$$"
            elif ("_{" in text or "^{" in text or "\\" in text or "_ {" in text or "^ {" in text) and ("$" not in text) and ("\\begin" not in text):
                text = "$" + text + "$"

            # Process formulas in text before handling other text processing
            text = self._process_formulas_in_text(text)

            text = self.try_remove_newline(text)

            # Return processed text
            return text
        except Exception as e:
            print(f"_handle_text error: {str(e)}")
            return text  # Return original text on error

    def _process_formulas_in_text(self, text: str) -> str:
        """
        Process mathematical formulas in text by iteratively finding and replacing formulas.
        - Identify inline and block formulas
        - Replace newlines within formulas with \\
        """
        try:
            # Define formula delimiters and their corresponding patterns
            delimiters = [
                ('$$', '$$'),      # Block formula with $$
                ('\\[', '\\]'),    # Block formula with \[ \]
                ('$', '$'),        # Inline formula with $
                ('\\(', '\\)')     # Inline formula with \( \)
            ]

            # Process the text by iterating through each delimiter type
            result = text

            for start_delim, end_delim in delimiters:
                # Create a pattern that matches from start to end delimiter
                # Using a custom approach to avoid issues with nested delimiters
                current_pos = 0
                processed_parts = []

                while current_pos < len(result):
                    # Find the next start delimiter
                    start_pos = result.find(start_delim, current_pos)
                    if start_pos == -1:
                        # No more formulas of this type
                        processed_parts.append(result[current_pos:])
                        break

                    # Add text before the formula
                    processed_parts.append(result[current_pos:start_pos])

                    # Find the matching end delimiter
                    end_pos = result.find(end_delim, start_pos + len(start_delim))
                    if end_pos == -1:
                        # No matching end delimiter, treat as regular text
                        processed_parts.append(result[start_pos:])
                        break

                    # Extract the formula content (without delimiters)
                    formula_content = result[start_pos + len(start_delim):end_pos]

                    # Process the formula content - replace newlines with \\
                    processed_formula = formula_content.replace('\n', ' \\\\ ')

                    # Add the processed formula with its delimiters
                    processed_parts.append(f"{start_delim}{processed_formula}{end_delim}")

                    # Move past this formula
                    current_pos = end_pos + len(end_delim)

                # Update the result with processed text
                result = ''.join(processed_parts)
            return result
        except Exception as e:
            print(f"_process_formulas_in_text error: {str(e)}")
            return text  # Return original text on error
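An invented example of the formula pass: newlines inside a `$$ ... $$` block become LaTeX line breaks while the delimiters themselves are preserved.

```python
src = "$$L = L_{ce}\n+ \\lambda L_{aux}$$"
out = MarkdownConverter()._process_formulas_in_text(src)
# out == "$$L = L_{ce} \\\\ + \\lambda L_{aux}$$"  (the newline became a LaTeX line break)
```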
    def _remove_newline_in_heading(self, text: str) -> str:
        """
        Remove newline in heading
        """
        try:
            # Handle Chinese text line breaks
            def is_chinese(char):
                return '\u4e00' <= char <= '\u9fff'

            # Check if the text contains Chinese characters
            if any(is_chinese(char) for char in text):
                return text.replace('\n', '')
            else:
                return text.replace('\n', ' ')

        except Exception as e:
            print(f"_remove_newline_in_heading error: {str(e)}")
            return text

    def _handle_heading(self, text: str, label: str) -> str:
        """
        Convert section headings to appropriate markdown format
        """
        try:
            level = self.heading_levels.get(label, '#')
            text = text.strip()
            text = self._remove_newline_in_heading(text)
            text = self._handle_text(text)
            return f"{level} {text}\n\n"
        except Exception as e:
            print(f"_handle_heading error: {str(e)}")
            return f"# Error processing heading: {text}\n\n"

    def _handle_list_item(self, text: str) -> str:
        """
        Convert list items to markdown list format
        """
        try:
            return f"- {text.strip()}\n"
        except Exception as e:
            print(f"_handle_list_item error: {str(e)}")
            return f"- Error processing list item: {text}\n"
    def _handle_figure(self, text: str, section_count: int) -> str:
        """
        Convert base64 encoded image to markdown image syntax
        """
        try:
            # Determine image format (assuming PNG if not specified)
            img_format = "png"
            if text.startswith("data:image/"):
                # Extract format from data URI
                img_format = text.split(";")[0].split("/")[1]
            elif ";" in text and "," in text:
                # Already in data URI format
                return f"![Figure {section_count}]({text})\n\n"
            else:
                # Raw base64, convert to data URI
                data_uri = f"data:image/{img_format};base64,{text}"
                return f"![Figure {section_count}]({data_uri})\n\n"
        except Exception as e:
            print(f"_handle_figure error: {str(e)}")
            return f"*[Error processing figure: {str(e)}]*\n\n"

    def _handle_table(self, text: str) -> str:
        """
        Convert table content to markdown format
        """
        try:
            markdown_content = []
            if '<table' in text.lower() or '<tr' in text.lower():
                markdown_table = extract_table_from_html(text)
                markdown_content.append(markdown_table + "\n")
            else:
                table_lines = text.split('\n')
                if table_lines:
                    col_count = len(table_lines[0].split()) if table_lines[0] else 1
                    header = '| ' + ' | '.join(table_lines[0].split()) + ' |'
                    markdown_content.append(header)
                    markdown_content.append('| ' + ' | '.join(['---'] * col_count) + ' |')
                    for line in table_lines[1:]:
                        cells = line.split()
                        while len(cells) < col_count:
                            cells.append('')
                        markdown_content.append('| ' + ' | '.join(cells) + ' |')
            return '\n'.join(markdown_content) + '\n\n'
        except Exception as e:
            print(f"_handle_table error: {str(e)}")
            return f"*[Error processing table: {str(e)}]*\n\n"
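A made-up example of the whitespace-table fallback in `_handle_table` (the HTML path simply passes the cleaned `<table>` through).

```python
txt = "Model Acc\nOPT-1.3B 51.44\nTinyLlama-1.1B 52.99"
print(MarkdownConverter()._handle_table(txt))
# | Model | Acc |
# | --- | --- |
# | OPT-1.3B | 51.44 |
# | TinyLlama-1.1B | 52.99 |
```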
    def _handle_algorithm(self, text: str) -> str:
        """
        Process algorithm blocks with proper formatting
        """
        try:
            # Remove algorithm environment tags if present
            text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL)
            text = text.replace('\\begin{algorithm}', '').replace('\\end{algorithm}', '')
            text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '')

            # Process the algorithm text
            lines = text.strip().split('\n')

            # Check if there's a caption or label
            caption = ""
            algorithm_text = []

            for line in lines:
                if '\\caption' in line:
                    # Extract caption text
                    caption_match = re.search(r'\\caption\{(.*?)\}', line)
                    if caption_match:
                        caption = f"**{caption_match.group(1)}**\n\n"
                    continue
                elif '\\label' in line:
                    continue  # Skip label lines
                else:
                    algorithm_text.append(line)

            # Join the algorithm text and wrap in code block
            formatted_text = '\n'.join(algorithm_text)

            # Return the formatted algorithm with caption
            return f"{caption}```\n{formatted_text}\n```\n\n"
        except Exception as e:
            print(f"_handle_algorithm error: {str(e)}")
            return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n"

    def _handle_formula(self, text: str) -> str:
        """
        Handle formula-specific content
        """
        try:
            # Process the formula content
            processed_text = self._process_formulas_in_text(text)

            # For formula blocks, ensure they're properly formatted in markdown
            if '$$' not in processed_text and '\\[' not in processed_text:
                # If no block formula delimiters are present, wrap in $$ for block formula
                processed_text = f'$${processed_text}$$'

            return f"{processed_text}\n\n"
        except Exception as e:
            print(f"_handle_formula error: {str(e)}")
            return f"*[Error processing formula: {str(e)}]*\n\n"
    def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
        """
        Convert recognition results to markdown format
        """
        try:
            markdown_content = []

            for section_count, result in enumerate(recognition_results):
                try:
                    label = result.get('label', '')
                    text = result.get('text', '').strip()

                    # Skip empty text
                    if not text:
                        continue

                    # Handle different content types
                    if label in {'title', 'sec', 'sub_sec'}:
                        markdown_content.append(self._handle_heading(text, label))
                    elif label == 'list':
                        markdown_content.append(self._handle_list_item(text))
                    elif label == 'fig':
                        markdown_content.append(self._handle_figure(text, section_count))
                    elif label == 'tab':
                        markdown_content.append(self._handle_table(text))
                    elif label == 'alg':
                        markdown_content.append(self._handle_algorithm(text))
                    elif label == 'formula':
                        markdown_content.append(self._handle_formula(text))
                    elif label not in self.special_labels:
                        # Handle regular text (paragraphs, etc.)
                        processed_text = self._handle_text(text)
                        markdown_content.append(f"{processed_text}\n\n")
                except Exception as e:
                    print(f"Error processing item {section_count}: {str(e)}")
                    # Add a placeholder for the failed item
                    markdown_content.append(f"*[Error processing content]*\n\n")

            # Join all content and apply post-processing
            result = ''.join(markdown_content)
            return self._post_process(result)
        except Exception as e:
            print(f"convert error: {str(e)}")
            return f"Error generating markdown content: {str(e)}"

    def _post_process(self, markdown_content: str) -> str:
        """
        Apply post-processing fixes to the generated markdown content
        """
        try:
            # Handle author information
            author_pattern = re.compile(r'\\author\{(.*?)\}', re.DOTALL)

            def process_author_match(match):
                # Extract author content
                author_content = match.group(1)
                # Process the author content
                return self._handle_text(author_content)

            # Replace \author{...} with processed content
            markdown_content = author_pattern.sub(process_author_match, markdown_content)

            # Handle special case where author is inside math environment
            math_author_pattern = re.compile(r'\$(\\author\{.*?\})\$', re.DOTALL)
            match = math_author_pattern.search(markdown_content)
            if match:
                # Extract the author command
                author_cmd = match.group(1)
                # Extract content from author command
                author_content_match = re.search(r'\\author\{(.*?)\}', author_cmd, re.DOTALL)
                if author_content_match:
                    # Get author content and process it
                    author_content = author_content_match.group(1)
                    processed_content = self._handle_text(author_content)
                    # Replace the entire $\author{...}$ block with processed content
                    markdown_content = markdown_content.replace(match.group(0), processed_content)

            # Replace LaTeX abstract environment with plain text
            markdown_content = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}',
                                      r'**Abstract** \1',
                                      markdown_content,
                                      flags=re.DOTALL)

            # Replace standalone \begin{abstract} (without matching end)
            markdown_content = re.sub(r'\\begin\{abstract\}',
                                      r'**Abstract**',
                                      markdown_content)

            # Replace LaTeX equation numbers with tag format, handling cases with extra backslashes
            markdown_content = re.sub(r'\\eqno\{\((.*?)\)\}',
                                      r'\\tag{\1}',
                                      markdown_content)

            # Find the starting tag of the formula
            markdown_content = markdown_content.replace("\[ \\\\", "$$ \\\\")

            # Find the ending tag of the formula (ensure this is the only ending tag)
            markdown_content = markdown_content.replace("\\\\ \]", "\\\\ $$")

            # Fix other common LaTeX issues
            replacements = [
                # Fix spacing issues in subscripts and superscripts
                (r'_ {', r'_{'),
                (r'^ {', r'^{'),

                # Fix potential issues with multiple consecutive newlines
                (r'\n{3,}', r'\n\n')
            ]

            for old, new in replacements:
                markdown_content = re.sub(old, new, markdown_content)

            return markdown_content
        except Exception as e:
            print(f"_post_process error: {str(e)}")
            return markdown_content  # Return original content if post-processing fails
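Putting the converter together, roughly how recognition JSON like the sample at the top of this file is rendered; the two-element input here is a trimmed, invented sample.

```python
converter = MarkdownConverter()
results = [
    {"label": "sec", "text": "2 Pretraining", "reading_order": 0},
    {"label": "para", "text": "We tracked the accuracy of TinyLlama\nduring pre-training.", "reading_order": 1},
]
print(converter.convert(results))
# ## 2 Pretraining
#
# We tracked the accuracy of TinyLlama during pre-training.
```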
477
utils/model.py
Normal file
@ -0,0 +1,477 @@
"""
Copyright (c) 2022-present NAVER Corp.
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
MIT License
This file has been modified by [ByteDance Ltd. and/or its affiliates] on 20250118.
The original file available at https://github.com/clovaai/donut/blob/master/donut/model.py was released under the MIT license.
This modified file is released under the same license.
"""

import logging
from collections import defaultdict
from typing import List, Optional

import torch
import torch.nn.functional as F
from PIL import Image
from timm.models.swin_transformer import SwinTransformer
from torch import nn
from transformers import (
    MBartConfig,
    MBartForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.file_utils import ModelOutput
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel


class SwinEncoder(nn.Module):
    r"""
    Encoder based on SwinTransformer
    Set the initial weights and configuration with a pretrained SwinTransformer and then
    modify the detailed configurations

    Args:
        input_size: Input image size (width, height)
        align_long_axis: Whether to rotate image if height is greater than width
        window_size: Window size(=patch size) of SwinTransformer
        encoder_layer: Number of layers of SwinTransformer encoder
        name_or_path: Name of a pretrained model name either registered in huggingface.co. or saved in local.
            otherwise, `swin_base_patch4_window12_384` will be set (using `timm`).
    """

    def __init__(
        self,
        input_size,
        align_long_axis: bool = False,
        window_size: int = 7,
        encoder_layer: List[int] = [2, 2, 14, 2],
        patch_size: int = [4, 4],
        embed_dim: int = 128,
        num_heads: List[int] = [4, 8, 16, 32],
    ):
        super().__init__()
        if isinstance(input_size, int):
            input_size = [input_size, input_size]
        self.input_size = input_size
        self.align_long_axis = align_long_axis
        self.window_size = window_size
        self.encoder_layer = encoder_layer
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.model = SwinTransformer(
            img_size=self.input_size,
            depths=self.encoder_layer,
            window_size=self.window_size,
            patch_size=self.patch_size,
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            num_classes=0,
        )

    def forward(self, x: torch.Tensor, text_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch_size, num_channels, height, width)
        """
        x = self.model.patch_embed(x)
        x = self.model.pos_drop(x)
        x = self.model.layers(x)
        return x
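A shape-level sketch of the encoder as configured here. The 896x896 input size mirrors the processor default later in this diff; the token count is an expectation based on a 4x4 patch size and three 2x downsampling stages in timm's SwinTransformer, not something asserted by the repo.

```python
# Rough shape check (values are expectations, not guarantees):
# 896/4 = 224 patch tokens per side; three merges leave 28*28 = 784 tokens of dim 128*8 = 1024.
encoder = SwinEncoder(input_size=[896, 896], encoder_layer=[2, 2, 14, 2], window_size=7)
dummy = torch.zeros(1, 3, 896, 896)
feats = encoder(dummy)
print(feats.shape)  # expected: torch.Size([1, 784, 1024])
```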
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def _set_dtype(self, dtype):
        self._dtype = dtype

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(dtype=self._dtype))
        return ret.type(orig_type)


class BARTDecoder(nn.Module):
    """
    Decoder based on Multilingual BART
    Set the initial weights and configuration with a pretrained multilingual BART model,
    and modify the detailed configurations as a Donut decoder

    Args:
        decoder_layer:
            Number of layers of BARTDecoder
        max_position_embeddings:
            The maximum sequence length to be trained
        name_or_path:
            Name of a pretrained model name either registered in huggingface.co. or saved in local,
            otherwise, `facebook/mbart-large-50` will be set (using `transformers`)
    """

    def __init__(
        self,
        tokenizer,
        decoder_layer: int,
        max_position_embeddings: int,
        hidden_dimension: int = 1024,
        **kwargs,
    ):
        super().__init__()
        self.decoder_layer = decoder_layer
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dimension = hidden_dimension

        self.tokenizer = tokenizer

        self.model = MBartForCausalLM(
            config=MBartConfig(
                tie_word_embeddings=True,
                is_decoder=True,
                is_encoder_decoder=False,
                add_cross_attention=True,
                decoder_layers=self.decoder_layer,
                max_position_embeddings=self.max_position_embeddings,
                vocab_size=len(self.tokenizer),
                scale_embedding=True,
                add_final_layer_norm=True,
                d_model=self.hidden_dimension,
            )
        )
        # self.model.config.is_encoder_decoder = True  # to get cross-attention
        self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
        self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference

    def add_special_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings
        """
        newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
        if newly_added_num > 0:
            self.model.resize_token_embeddings(len(self.tokenizer))

    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings
        """
        newly_added_num = self.tokenizer.add_tokens(sorted(set(list_of_tokens)))
        if newly_added_num > 0:
            self.model.resize_token_embeddings(len(self.tokenizer))

    def prepare_inputs_for_inference(
        self,
        input_ids: torch.Tensor,
        encoder_outputs: torch.Tensor,
        past=None,
        past_key_values=None,
        use_cache: bool = None,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
        """
        Args:
            input_ids: (batch_size, sequence_length)

        Returns:
            input_ids: (batch_size, sequence_length)
            attention_mask: (batch_size, sequence_length)
            encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
        """
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
        past = past or past_key_values
        if past is not None:
            input_ids = input_ids[:, -1:]
        output = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "use_cache": use_cache,
            "encoder_hidden_states": encoder_outputs.last_hidden_state,
        }
        return output

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = None,
        output_attentions: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[torch.Tensor] = None,
        return_dict: bool = None,
    ):
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @staticmethod
    def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
        """
        Resize position embeddings
        Truncate if sequence length of MBart backbone is greater than given max_length,
        else interpolate to max_length
        """
        if weight.shape[0] > max_length:
            weight = weight[:max_length, ...]
        else:
            weight = (
                F.interpolate(
                    weight.permute(1, 0).unsqueeze(0),
                    size=max_length,
                    mode="linear",
                    align_corners=False,
                )
                .squeeze(0)
                .permute(1, 0)
            )
        return weight
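A small sketch of what `resize_bart_abs_pos_emb` does to mBART's learned position table; the shapes below are illustrative placeholders.

```python
pos_table = torch.randn(1026, 1024)  # placeholder learned position embeddings
longer = BARTDecoder.resize_bart_abs_pos_emb(pos_table, 4098)   # linear interpolation
shorter = BARTDecoder.resize_bart_abs_pos_emb(pos_table, 512)   # simple truncation
print(longer.shape, shorter.shape)  # torch.Size([4098, 1024]) torch.Size([512, 1024])
```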
class DonutConfig(PretrainedConfig):

    def __init__(
        self,
        decoder_layer: int = 10,
        max_position_embeddings: int = None,
        max_length: int = 4096,
        hidden_dimension: int = 1024,
        **kwargs,
    ):
        super().__init__()
        self.decoder_layer = decoder_layer
        self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings
        self.max_length = max_length
        self.hidden_dimension = hidden_dimension


class RunningVarTorch:
    def __init__(self, L=15, norm=False):
        self.values = None
        self.L = L
        self.norm = norm

    def push(self, x: torch.Tensor):
        assert x.dim() == 1
        if self.values is None:
            self.values = x[:, None]
        elif self.values.shape[1] < self.L:
            self.values = torch.cat((self.values, x[:, None]), 1)
        else:
            self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)

    def variance(self):
        if self.values is None:
            return
        if self.norm:
            return torch.var(self.values, 1) / self.values.shape[1]
        else:
            return torch.var(self.values, 1)


class StoppingCriteriaScores(StoppingCriteria):
    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        last_scores = scores[-1]
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    self.stop_inds[b] = int(min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095))
            else:
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0


def batch(l, b=15):
    subs = []
    for i in range(len(l) - b):
        subs.append(l[i : i + b])
    return subs


def subdiv(l, b=10):
    subs = []
    for i in range(len(l) - b):
        subs.append(l[: i + b])
    return subs
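How the variance-based criterion above plugs into generation, as the `inference` method further down also does: when the running variance of the per-step maximum token probability stops changing, the sequence is assumed to be repeating and decoding is cut off early. A sketch, with the `generate` arguments trimmed to the relevant ones and `decoder` standing in for the mBART causal LM.

```python
stopping = StoppingCriteriaList([StoppingCriteriaScores(threshold=0.015, window_size=200)])
# outputs = decoder.generate(..., output_scores=True, return_dict_in_generate=True,
#                            stopping_criteria=stopping)
```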
class DonutModel(PreTrainedModel):
    config_class = DonutConfig
    base_model_prefix = "donut"

    def __init__(self, config: DonutConfig, vision_tower=None, tokenizer=None):
        super().__init__(config)
        self.config = config

        self.tokenizer = tokenizer
        self.vpm = vision_tower

        # build language model
        self.llm = BARTDecoder(
            tokenizer=tokenizer,
            decoder_layer=self.config.decoder_layer,
            max_position_embeddings=self.config.max_position_embeddings,
            hidden_dimension=self.config.hidden_dimension,
        )
        self.ids_to_tokens = {id: content for content, id in self.llm.tokenizer.vocab.items()}

    def get_input_embeddings(self, tensor):
        return self.llm.model.get_input_embeddings()(tensor)

    def forward(
        self,
        inputs: dict,
    ):
        image_tensors = inputs["pixel_values"]
        input_ids = inputs["input_ids"].contiguous()
        attention_mask = inputs["attention_mask"]
        labels = inputs["labels"].contiguous()

        encoder_outputs = self.vpm(
            image_tensors,
            text_embedding=self.llm.model.get_input_embeddings()(input_ids),
        )

        decoder_outputs = self.llm(
            input_ids=input_ids,
            encoder_hidden_states=encoder_outputs,
            attention_mask=attention_mask,
            labels=labels,
        )
        return decoder_outputs

    def get_hidden_states_during_inference(
        self,
        prompt_ids: torch.Tensor,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
    ):
        if image_tensors is None:
            image_tensors = self.vpm.prepare_input(image).unsqueeze(0)

        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)

        image_tensors = image_tensors.to(self.device)
        prompt_ids = prompt_ids.to(self.device)
        all_hidden_states = self.vpm.forward_features(
            image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
        )
        return all_hidden_states

    def get_attn_weights_during_inference(
        self,
        prompt_ids: torch.Tensor,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
    ):
        if image_tensors is None:
            image_tensors = self.vpm.prepare_input(image).unsqueeze(0)

        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)

        image_tensors = image_tensors.to(self.device)
        prompt_ids = prompt_ids.to(self.device)
        last_attn_score = self.vpm.get_last_layer_cross_attn_score(
            image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
        )
        return last_attn_score

    def inference(
        self,
        prompt_ids: torch.Tensor,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
        return_attentions: bool = False,
        early_stopping: bool = True,
    ):
        """
        Generate a token sequence in an auto-regressive manner.

        Args:
            image: input document image (PIL.Image)
            image_tensors: (1, num_channels, height, width)
                convert prompt to tensor if image_tensor is not fed
        """
        output = {
            "predictions": list(),
            "sequences": list(),
            "repeats": list(),
            "repetitions": list(),
        }
        if image is None and image_tensors is None:
            logging.warn("Image not found")
            return output

        if image_tensors is None:
            image_tensors = self.vpm.prepare_input(image).unsqueeze(0)

        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)

        image_tensors = image_tensors.to(self.device)
        prompt_ids = prompt_ids.to(self.device)
        last_hidden_state = self.vpm(image_tensors, text_embedding=self.get_input_embeddings(prompt_ids))

        encoder_outputs = ModelOutput(last_hidden_state=last_hidden_state, attentions=None)
        if len(encoder_outputs.last_hidden_state.size()) == 1:
            encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.unsqueeze(0)

        # get decoder output
        decoder_output = self.llm.model.generate(
            input_ids=prompt_ids,
            encoder_outputs=encoder_outputs,
            min_length=1,
            max_length=self.config.max_length,
            pad_token_id=self.llm.tokenizer.pad_token_id,
            eos_token_id=self.llm.tokenizer.eos_token_id,
            use_cache=True,
            return_dict_in_generate=True,
            output_scores=True,
            output_attentions=return_attentions,
            do_sample=False,
            num_beams=1,
            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()] if early_stopping else []),
        )

        output["repetitions"] = decoder_output.sequences.clone()
        output["sequences"] = decoder_output.sequences.clone()
        output["scores"] = torch.stack(decoder_output.scores, 1).softmax(-1).cpu().max(-1)[0]

        output["repetitions"] = self.llm.tokenizer.batch_decode(output["repetitions"], skip_special_tokens=False)
        return output
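A sketch of calling `inference` directly. It assumes `model` is an already-loaded `DonutModel` and `processor` a `DolphinProcessor` (defined in utils/processor.py below); the prompt string and the page image are placeholders.

```python
# Hypothetical call; model/processor construction from a saved checkpoint happens elsewhere in the repo.
prompt_ids = processor.process_prompt_for_inference("<s>Read text in the image.")  # placeholder prompt
pixel_values = processor.process_image_for_inference(page_image)
out = model.inference(prompt_ids=prompt_ids, image_tensors=pixel_values)
print(out["repetitions"][0])   # decoded sequence, special tokens included
print(out["scores"].shape)     # per-step max token probabilities
```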
64
utils/processor.py
Normal file
@ -0,0 +1,64 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import numpy as np
import torch
from PIL import ImageOps

from utils.utils import *


class DolphinProcessor:
    def __init__(
        self,
        dp_config,
        tokenizer,
        **kwargs,
    ) -> None:

        self.tokenizer = tokenizer
        transform_args = kwargs.get("transform_args", {})
        self.max_length = transform_args.get("max_length", 2048)
        self.input_size = transform_args.get("input_size", [896, 896])  # height, width
        if isinstance(self.input_size, int):
            self.input_size = [self.input_size, self.input_size]

        try:
            self.answer_start_token = self.tokenizer._prompt_end_token
        except AttributeError as err:
            print('No answer_start_token found, use "" instead')
            self.answer_start_token = ""

        self.prefix_answer_space_flag = dp_config.get("prefix_answer_space_flag", True)
        self.suffix_prompt_space_flag = dp_config.get("suffix_prompt_space_flag", True)

    def process_prompt_for_inference(self, prompt):
        prompt = prompt.replace("<image>\n", "")
        if not prompt.startswith("<s>"):
            prompt = "<s>" + prompt
        message_ids = [self.tokenizer.encode(prompt, add_special_tokens=False)]
        ids = torch.from_numpy(np.hstack(message_ids, dtype=np.int32))
        return ids.unsqueeze(0)

    def process_image_for_inference(self, image, return_img_size=False):
        image = resize(image, min(self.input_size))

        image.thumbnail((self.input_size[1], self.input_size[0]))
        origin_w, origin_h = image.size

        delta_width = self.input_size[1] - image.width
        delta_height = self.input_size[0] - image.height
        pad_width = delta_width // 2
        pad_height = delta_height // 2
        padding = (
            pad_width,
            pad_height,
            delta_width - pad_width,
            delta_height - pad_height,
        )
        image = ImageOps.expand(image, padding)
        if return_img_size:
            return test_transform(image).unsqueeze(0), (origin_w, origin_h)
        return test_transform(image).unsqueeze(0)
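How the processor is typically driven, under the assumption that `dp_config` behaves like a dict (e.g. an OmegaConf mapping, which supports `.get`) and that the tokenizer ships with the checkpoint; the tokenizer name, config keys shown, and paths are placeholders.

```python
from PIL import Image
from omegaconf import OmegaConf
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50")  # placeholder; the real one comes with the checkpoint
dp_config = OmegaConf.create({"prefix_answer_space_flag": True, "suffix_prompt_space_flag": True})
processor = DolphinProcessor(dp_config, tokenizer, transform_args={"input_size": [896, 896]})

pixel_values = processor.process_image_for_inference(Image.open("./demo/page_1.png").convert("RGB"))
print(pixel_values.shape)  # expected: torch.Size([1, 3, 896, 896])
```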
367
utils/utils.py
Normal file
@ -0,0 +1,367 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import copy
import json
import os
import re
from dataclasses import dataclass
from typing import List, Tuple

import albumentations as alb
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms.functional import resize

from utils.markdown_utils import MarkdownConverter


def alb_wrapper(transform):
    def f(im):
        return transform(image=np.asarray(im))["image"]

    return f


test_transform = alb_wrapper(
    alb.Compose(
        [
            alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ]
    )
)


def check_coord_valid(x1, y1, x2, y2, image_size=None, abs_coord=True):
    # print(f"check_coord_valid: {x1}, {y1}, {x2}, {y2}, {image_size}, {abs_coord}")
    if x2 <= x1 or y2 <= y1:
        return False, f"[{x1}, {y1}, {x2}, {y2}]"
    if x1 < 0 or y1 < 0:
        return False, f"[{x1}, {y1}, {x2}, {y2}]"
    if not abs_coord:
        if x2 > 1 or y2 > 1:
            return False, f"[{x1}, {y1}, {x2}, {y2}]"
    elif image_size is not None:  # has image size
        if x2 > image_size[0] or y2 > image_size[1]:
            return False, f"[{x1}, {y1}, {x2}, {y2}]"
    return True, None
def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2):
    """
    Image: cv2.image object, or Path
    Input: boxes: list of boxes [[x1, y1, x2, y2]]. Using absolute coordinates.
    """
    if isinstance(image, str):
        image = cv2.imread(image)
    img_h, img_w = image.shape[:2]
    new_boxes = []
    for box in boxes:
        best_box = copy.deepcopy(box)

        def check_edge(img, current_box, i, is_vertical):
            edge = current_box[i]
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

            if is_vertical:
                line = binary[current_box[1] : current_box[3] + 1, edge]
            else:
                line = binary[edge, current_box[0] : current_box[2] + 1]

            transitions = np.abs(np.diff(line))
            return np.sum(transitions) / len(transitions)

        # Only widen the box
        edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]

        current_box = copy.deepcopy(box)
        # make sure the box is within the image
        current_box[0] = min(max(current_box[0], 0), img_w - 1)
        current_box[1] = min(max(current_box[1], 0), img_h - 1)
        current_box[2] = min(max(current_box[2], 0), img_w - 1)
        current_box[3] = min(max(current_box[3], 0), img_h - 1)

        for i, direction, is_vertical in edges:
            best_score = check_edge(image, current_box, i, is_vertical)
            if best_score <= threshold:
                continue
            for step in range(max_pixels):
                current_box[i] += direction
                if i == 0 or i == 2:
                    current_box[i] = min(max(current_box[i], 0), img_w - 1)
                else:
                    current_box[i] = min(max(current_box[i], 0), img_h - 1)
                score = check_edge(image, current_box, i, is_vertical)

                if score < best_score:
                    best_score = score
                    best_box = copy.deepcopy(current_box)

                if score <= threshold:
                    break
        new_boxes.append(best_box)

    return new_boxes
def parse_layout_string(bbox_str):
    """Parse layout string using regular expressions"""
    pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
    matches = re.finditer(pattern, bbox_str)

    parsed_results = []
    for match in matches:
        coords = [float(match.group(i)) for i in range(1, 5)]
        label = match.group(5).strip()
        parsed_results.append((coords, label))

    return parsed_results
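The layout-string format this parser expects, shown on an invented two-element example.

```python
layout = "[0.176, 0.740, 0.824, 0.820] tab [0.280, 0.729, 0.711, 0.740] cap"
print(parse_layout_string(layout))
# [([0.176, 0.74, 0.824, 0.82], 'tab'), ([0.28, 0.729, 0.711, 0.74], 'cap')]
```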
@dataclass
class ImageDimensions:
    """Class to store image dimensions"""
    original_w: int
    original_h: int
    padded_w: int
    padded_h: int


def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]:
    """Map coordinates from padded image back to original image

    Args:
        x1, y1, x2, y2: Coordinates in padded image
        dims: Image dimensions object

    Returns:
        tuple: (x1, y1, x2, y2) coordinates in original image
    """
    try:
        # Calculate padding offsets
        top = (dims.padded_h - dims.original_h) // 2
        left = (dims.padded_w - dims.original_w) // 2

        # Map back to original coordinates
        orig_x1 = max(0, x1 - left)
        orig_y1 = max(0, y1 - top)
        orig_x2 = min(dims.original_w, x2 - left)
        orig_y2 = min(dims.original_h, y2 - top)

        # Ensure we have a valid box (width and height > 0)
        if orig_x2 <= orig_x1:
            orig_x2 = min(orig_x1 + 1, dims.original_w)
        if orig_y2 <= orig_y1:
            orig_y2 = min(orig_y1 + 1, dims.original_h)

        return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
    except Exception as e:
        print(f"map_to_original_coordinates error: {str(e)}")
        # Return safe coordinates
        return 0, 0, min(100, dims.original_w), min(100, dims.original_h)
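A worked example of the padding-offset removal, using invented dimensions for a landscape page padded to a square.

```python
dims = ImageDimensions(original_w=1000, original_h=600, padded_w=1000, padded_h=1000)
# Padding added (1000 - 600) // 2 = 200 rows on top, so padded y-coordinates shift down by 200.
print(map_to_original_coordinates(100, 250, 400, 700, dims))
# -> (100, 50, 400, 500)
```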
def map_to_relevant_coordinates(abs_coords, dims: ImageDimensions):
    """
    From absolute coordinates to relevant coordinates
    e.g. [100, 100, 200, 200] -> [0.1, 0.2, 0.3, 0.4]
    """
    try:
        x1, y1, x2, y2 = abs_coords
        return round(x1 / dims.original_w, 3), round(y1 / dims.original_h, 3), round(x2 / dims.original_w, 3), round(y2 / dims.original_h, 3)
    except Exception as e:
        print(f"map_to_relevant_coordinates error: {str(e)}")
        return 0.0, 0.0, 1.0, 1.0  # Return full image coordinates


def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
    """Process and adjust coordinates

    Args:
        coords: Normalized coordinates [x1, y1, x2, y2]
        padded_image: Padded image
        dims: Image dimensions object
        previous_box: Previous box coordinates for overlap adjustment

    Returns:
        tuple: (x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box)
    """
    try:
        # Convert normalized coordinates to absolute coordinates
        x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
        x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)

        # Ensure coordinates are within image bounds before adjustment
        x1 = max(0, min(x1, dims.padded_w - 1))
        y1 = max(0, min(y1, dims.padded_h - 1))
        x2 = max(0, min(x2, dims.padded_w))
        y2 = max(0, min(y2, dims.padded_h))

        # Ensure width and height are at least 1 pixel
        if x2 <= x1:
            x2 = min(x1 + 1, dims.padded_w)
        if y2 <= y1:
            y2 = min(y1 + 1, dims.padded_h)

        # Extend box boundaries
        new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
        x1, y1, x2, y2 = new_boxes[0]

        # Ensure coordinates are still within image bounds after adjustment
        x1 = max(0, min(x1, dims.padded_w - 1))
        y1 = max(0, min(y1, dims.padded_h - 1))
        x2 = max(0, min(x2, dims.padded_w))
        y2 = max(0, min(y2, dims.padded_h))

        # Ensure width and height are at least 1 pixel after adjustment
        if x2 <= x1:
            x2 = min(x1 + 1, dims.padded_w)
        if y2 <= y1:
            y2 = min(y1 + 1, dims.padded_h)

        # Check for overlap with previous box and adjust
        if previous_box is not None:
            prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
            if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
                y1 = prev_y2
                # Ensure y1 is still valid
                y1 = min(y1, dims.padded_h - 1)
                # Make sure y2 is still greater than y1
                if y2 <= y1:
                    y2 = min(y1 + 1, dims.padded_h)

        # Update previous box
        new_previous_box = [x1, y1, x2, y2]

        # Map to original coordinates
        orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(
            x1, y1, x2, y2, dims
        )

        return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
    except Exception as e:
        print(f"process_coordinates error: {str(e)}")
        # Return safe values
        orig_x1, orig_y1, orig_x2, orig_y2 = 0, 0, min(100, dims.original_w), min(100, dims.original_h)
        return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]
def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]:
    """Load and prepare image with padding while maintaining aspect ratio

    Args:
        image: PIL image

    Returns:
        tuple: (padded_image, image_dimensions)
    """
    try:
        # Convert PIL image to OpenCV format
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        original_h, original_w = image.shape[:2]

        # Calculate padding to make square image
        max_size = max(original_h, original_w)
        top = (max_size - original_h) // 2
        bottom = max_size - original_h - top
        left = (max_size - original_w) // 2
        right = max_size - original_w - left

        # Apply padding
        padded_image = cv2.copyMakeBorder(image, top, bottom, left, right,
                                          cv2.BORDER_CONSTANT, value=(0, 0, 0))

        padded_h, padded_w = padded_image.shape[:2]

        dimensions = ImageDimensions(
            original_w=original_w,
            original_h=original_h,
            padded_w=padded_w,
            padded_h=padded_h
        )

        return padded_image, dimensions
    except Exception as e:
        print(f"prepare_image error: {str(e)}")
        # Create a minimal valid image and dimensions
        h, w = image.height, image.width
        dimensions = ImageDimensions(
            original_w=w,
            original_h=h,
            padded_w=w,
            padded_h=h
        )
        # Return a black image of the same size
        return np.zeros((h, w, 3), dtype=np.uint8), dimensions
def setup_output_dirs(save_dir):
    """Create necessary output directories"""
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(os.path.join(save_dir, "markdown"), exist_ok=True)
    os.makedirs(os.path.join(save_dir, "recognition_json"), exist_ok=True)


def save_outputs(recognition_results, image_path, save_dir):
    """Save JSON and markdown outputs"""
    basename = os.path.splitext(os.path.basename(image_path))[0]

    # Save JSON file
    json_path = os.path.join(save_dir, "recognition_json", f"{basename}.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(recognition_results, f, ensure_ascii=False, indent=2)

    # Generate and save markdown file
    markdown_converter = MarkdownConverter()
    markdown_content = markdown_converter.convert(recognition_results)
    markdown_path = os.path.join(save_dir, "markdown", f"{basename}.md")
    with open(markdown_path, "w", encoding="utf-8") as f:
        f.write(markdown_content)

    return json_path


def crop_margin(img: Image.Image) -> Image.Image:
    """Crop margins from image"""
    try:
        width, height = img.size
        if width == 0 or height == 0:
            print("Warning: Image has zero width or height")
            return img

        data = np.array(img.convert("L"))
        data = data.astype(np.uint8)
        max_val = data.max()
        min_val = data.min()
        if max_val == min_val:
            return img
        data = (data - min_val) / (max_val - min_val) * 255
        gray = 255 * (data < 200).astype(np.uint8)

        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
        if coords is None:
            return img
        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box

        # Ensure crop coordinates are within image bounds
        a = max(0, a)
        b = max(0, b)
        w = min(w, width - a)
        h = min(h, height - b)

        # Only crop if we have a valid region
        if w > 0 and h > 0:
            return img.crop((a, b, a + w, b + h))
        return img
    except Exception as e:
        print(f"crop_margin error: {str(e)}")
        return img  # Return original image on error
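Finally, a sketch of how the output helpers tie together; the paths and the single recognition record are placeholders.

```python
# Hypothetical end-to-end write: one recognized element saved as JSON plus rendered markdown.
save_dir = "./results"
setup_output_dirs(save_dir)
results = [{"label": "para", "bbox": [0.1, 0.1, 0.9, 0.2], "text": "Hello Dolphin", "reading_order": 0}]
json_path = save_outputs(results, image_path="./demo/page_1.png", save_dir=save_dir)
# -> ./results/recognition_json/page_1.json and ./results/markdown/page_1.md
```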