[init] initial commit

fenghao.2019
2025-05-26 23:20:51 +08:00
commit 49f51871c6
31 changed files with 2757 additions and 0 deletions

utils/markdown_utils.py (new file, 442 lines)

@@ -0,0 +1,442 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import re
import base64
from typing import List, Dict, Any, Optional
"""
Example input:
[
{"label": "tab", "bbox": [0.176, 0.74, 0.824, 0.82], "text": "<table><tr><td></td><td>HellaSwag</td><td>Obqa</td><td>WinoGrande</td><td>ARC-c</td><td>ARC-e</td><td>boolq</td><td>piqa</td><td>Avg</td></tr><tr><td>OPT-1.3B</td><td>53.65</td><td>33.40</td><td>59.59</td><td>29.44</td><td>50.80</td><td>60.83</td><td>72.36</td><td>51.44</td></tr><tr><td>Pythia-1.0B</td><td>47.16</td><td>31.40</td><td>53.43</td><td>27.05</td><td>48.99</td><td>57.83</td><td>69.21</td><td>48.30</td></tr><tr><td>Pythia-1.4B</td><td>52.01</td><td>33.20</td><td>57.38</td><td>28.50</td><td>54.00</td><td>63.27</td><td>70.95</td><td>51.33</td></tr><tr><td>TinyLlama-1.1B</td><td>59.20</td><td>36.00</td><td>59.12</td><td>30.10</td><td>55.25</td><td>57.83</td><td>73.29</td><td>52.99</td></tr></table>", "reading_order": 6},
{"label": "cap", "bbox": [0.28, 0.729, 0.711, 0.74], "text": "Table 2: Zero-shot performance on commonsense reasoning tasks", "reading_order": 7},
{"label": "para", "bbox": [0.176, 0.848, 0.826, 0.873], "text": "We of performance during training We tracked the accuracy of TinyLlama on common-\nsense reasoning benchmarks during its pre-training, as shown in Fig. 2 . Generally, the performance of", "reading_order": 8},
{"label": "fnote", "bbox": [0.176, 0.88, 0.824, 0.912], "text": "${ }^{4}$ Due to a bug in the config file, the learning rate did not decrease immediately after warmup and remained at\nthe maximum value for several steps before we fixed this.", "reading_order": 9},
{"label": "foot", "bbox": [0.496, 0.939, 0.501, 0.95], "text": "14", "reading_order": 10}
]
"""
def extract_table_from_html(html_string):
"""Extract and clean table tags from HTML string"""
try:
table_pattern = re.compile(r'<table.*?>.*?</table>', re.DOTALL)
tables = table_pattern.findall(html_string)
tables = [re.sub(r'<table[^>]*>', '<table>', table) for table in tables]
return '\n'.join(tables)
except Exception as e:
print(f"extract_table_from_html error: {str(e)}")
return f"<table><tr><td>Error extracting table: {str(e)}</td></tr></table>"
class MarkdownConverter:
"""Convert structured recognition results to Markdown format"""
def __init__(self):
# Define heading levels for different section types
self.heading_levels = {
'title': '#',
'sec': '##',
'sub_sec': '###'
}
# Define which labels need special handling
self.special_labels = {
'tab', 'fig', 'title', 'sec', 'sub_sec',
'list', 'formula', 'reference', 'alg'
}
def try_remove_newline(self, text: str) -> str:
try:
# Preprocess text to handle line breaks
text = text.strip()
text = text.replace('-\n', '')
# Handle Chinese text line breaks
def is_chinese(char):
return '\u4e00' <= char <= '\u9fff'
lines = text.split('\n')
processed_lines = []
# Process all lines except the last one
for i in range(len(lines)-1):
current_line = lines[i].strip()
next_line = lines[i+1].strip()
# Always add the current line, but determine if we need a newline
if current_line: # If current line is not empty
if next_line: # If next line is not empty
# For Chinese text handling
if is_chinese(current_line[-1]) and is_chinese(next_line[0]):
processed_lines.append(current_line)
else:
processed_lines.append(current_line + ' ')
else:
# Next line is empty, add current line with newline
processed_lines.append(current_line + '\n')
else:
# Current line is empty, add an empty line
processed_lines.append('\n')
# Add the last line
if lines and lines[-1].strip():
processed_lines.append(lines[-1].strip())
text = ''.join(processed_lines)
return text
except Exception as e:
print(f"try_remove_newline error: {str(e)}")
return text # Return original text on error
def _handle_text(self, text: str) -> str:
"""
Process regular text content, preserving paragraph structure
"""
try:
if not text:
return ""
if text.strip().startswith("\\begin{array}") and text.strip().endswith("\\end{array}"):
text = "$$" + text + "$$"
elif ("_{" in text or "^{" in text or "\\" in text or "_ {" in text or "^ {" in text) and ("$" not in text) and ("\\begin" not in text):
text = "$" + text + "$"
# Process formulas in text before handling other text processing
text = self._process_formulas_in_text(text)
text = self.try_remove_newline(text)
# Return processed text
return text
except Exception as e:
print(f"_handle_text error: {str(e)}")
return text # Return original text on error
def _process_formulas_in_text(self, text: str) -> str:
"""
Process mathematical formulas in text by iteratively finding and replacing formulas.
- Identify inline and block formulas
- Replace newlines within formulas with \\
"""
try:
# Define formula delimiters and their corresponding patterns
delimiters = [
('$$', '$$'), # Block formula with $$
('\\[', '\\]'), # Block formula with \[ \]
('$', '$'), # Inline formula with $
('\\(', '\\)') # Inline formula with \( \)
]
# Process the text by iterating through each delimiter type
result = text
for start_delim, end_delim in delimiters:
# Create a pattern that matches from start to end delimiter
# Using a custom approach to avoid issues with nested delimiters
current_pos = 0
processed_parts = []
while current_pos < len(result):
# Find the next start delimiter
start_pos = result.find(start_delim, current_pos)
if start_pos == -1:
# No more formulas of this type
processed_parts.append(result[current_pos:])
break
# Add text before the formula
processed_parts.append(result[current_pos:start_pos])
# Find the matching end delimiter
end_pos = result.find(end_delim, start_pos + len(start_delim))
if end_pos == -1:
# No matching end delimiter, treat as regular text
processed_parts.append(result[start_pos:])
break
# Extract the formula content (without delimiters)
formula_content = result[start_pos + len(start_delim):end_pos]
# Process the formula content - replace newlines with \\
processed_formula = formula_content.replace('\n', ' \\\\ ')
# Add the processed formula with its delimiters
processed_parts.append(f"{start_delim}{processed_formula}{end_delim}")
# Move past this formula
current_pos = end_pos + len(end_delim)
# Update the result with processed text
result = ''.join(processed_parts)
return result
except Exception as e:
print(f"_process_formulas_in_text error: {str(e)}")
return text # Return original text on error
def _remove_newline_in_heading(self, text: str) -> str:
"""
Remove newline in heading
"""
try:
# Handle Chinese text line breaks
def is_chinese(char):
return '\u4e00' <= char <= '\u9fff'
# Check if the text contains Chinese characters
if any(is_chinese(char) for char in text):
return text.replace('\n', '')
else:
return text.replace('\n', ' ')
except Exception as e:
print(f"_remove_newline_in_heading error: {str(e)}")
return text
def _handle_heading(self, text: str, label: str) -> str:
"""
Convert section headings to appropriate markdown format
"""
try:
level = self.heading_levels.get(label, '#')
text = text.strip()
text = self._remove_newline_in_heading(text)
text = self._handle_text(text)
return f"{level} {text}\n\n"
except Exception as e:
print(f"_handle_heading error: {str(e)}")
return f"# Error processing heading: {text}\n\n"
def _handle_list_item(self, text: str) -> str:
"""
Convert list items to markdown list format
"""
try:
return f"- {text.strip()}\n"
except Exception as e:
print(f"_handle_list_item error: {str(e)}")
return f"- Error processing list item: {text}\n"
def _handle_figure(self, text: str, section_count: int) -> str:
"""
Convert base64 encoded image to markdown image syntax
"""
try:
# Determine image format (assuming PNG if not specified)
img_format = "png"
            if text.startswith("data:image/"):
                # Already a data URI: extract the format and return the image link directly
                img_format = text.split(";")[0].split("/")[1]
                return f"![Figure {section_count}]({text})\n\n"
elif ";" in text and "," in text:
# Already in data URI format
return f"![Figure {section_count}]({text})\n\n"
else:
# Raw base64, convert to data URI
data_uri = f"data:image/{img_format};base64,{text}"
return f"![Figure {section_count}]({data_uri})\n\n"
except Exception as e:
print(f"_handle_figure error: {str(e)}")
return f"*[Error processing figure: {str(e)}]*\n\n"
def _handle_table(self, text: str) -> str:
"""
Convert table content to markdown format
"""
try:
markdown_content = []
if '<table' in text.lower() or '<tr' in text.lower():
markdown_table = extract_table_from_html(text)
markdown_content.append(markdown_table + "\n")
else:
table_lines = text.split('\n')
if table_lines:
col_count = len(table_lines[0].split()) if table_lines[0] else 1
header = '| ' + ' | '.join(table_lines[0].split()) + ' |'
markdown_content.append(header)
markdown_content.append('| ' + ' | '.join(['---'] * col_count) + ' |')
for line in table_lines[1:]:
cells = line.split()
while len(cells) < col_count:
cells.append('')
markdown_content.append('| ' + ' | '.join(cells) + ' |')
return '\n'.join(markdown_content) + '\n\n'
except Exception as e:
print(f"_handle_table error: {str(e)}")
return f"*[Error processing table: {str(e)}]*\n\n"
def _handle_algorithm(self, text: str) -> str:
"""
Process algorithm blocks with proper formatting
"""
try:
# Remove algorithm environment tags if present
text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL)
text = text.replace('\\begin{algorithm}', '').replace('\\end{algorithm}', '')
text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '')
# Process the algorithm text
lines = text.strip().split('\n')
# Check if there's a caption or label
caption = ""
algorithm_text = []
for line in lines:
if '\\caption' in line:
# Extract caption text
caption_match = re.search(r'\\caption\{(.*?)\}', line)
if caption_match:
caption = f"**{caption_match.group(1)}**\n\n"
continue
elif '\\label' in line:
continue # Skip label lines
else:
algorithm_text.append(line)
# Join the algorithm text and wrap in code block
formatted_text = '\n'.join(algorithm_text)
# Return the formatted algorithm with caption
return f"{caption}```\n{formatted_text}\n```\n\n"
except Exception as e:
print(f"_handle_algorithm error: {str(e)}")
return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n"
def _handle_formula(self, text: str) -> str:
"""
Handle formula-specific content
"""
try:
# Process the formula content
processed_text = self._process_formulas_in_text(text)
# For formula blocks, ensure they're properly formatted in markdown
if '$$' not in processed_text and '\\[' not in processed_text:
# If no block formula delimiters are present, wrap in $$ for block formula
processed_text = f'$${processed_text}$$'
return f"{processed_text}\n\n"
except Exception as e:
print(f"_handle_formula error: {str(e)}")
return f"*[Error processing formula: {str(e)}]*\n\n"
def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
"""
Convert recognition results to markdown format
"""
try:
markdown_content = []
for section_count, result in enumerate(recognition_results):
try:
label = result.get('label', '')
text = result.get('text', '').strip()
# Skip empty text
if not text:
continue
# Handle different content types
if label in {'title', 'sec', 'sub_sec'}:
markdown_content.append(self._handle_heading(text, label))
elif label == 'list':
markdown_content.append(self._handle_list_item(text))
elif label == 'fig':
markdown_content.append(self._handle_figure(text, section_count))
elif label == 'tab':
markdown_content.append(self._handle_table(text))
elif label == 'alg':
markdown_content.append(self._handle_algorithm(text))
elif label == 'formula':
markdown_content.append(self._handle_formula(text))
elif label not in self.special_labels:
# Handle regular text (paragraphs, etc.)
processed_text = self._handle_text(text)
markdown_content.append(f"{processed_text}\n\n")
except Exception as e:
print(f"Error processing item {section_count}: {str(e)}")
# Add a placeholder for the failed item
markdown_content.append(f"*[Error processing content]*\n\n")
# Join all content and apply post-processing
result = ''.join(markdown_content)
return self._post_process(result)
except Exception as e:
print(f"convert error: {str(e)}")
return f"Error generating markdown content: {str(e)}"
def _post_process(self, markdown_content: str) -> str:
"""
Apply post-processing fixes to the generated markdown content
"""
try:
# Handle author information
author_pattern = re.compile(r'\\author\{(.*?)\}', re.DOTALL)
def process_author_match(match):
# Extract author content
author_content = match.group(1)
# Process the author content
return self._handle_text(author_content)
# Replace \author{...} with processed content
markdown_content = author_pattern.sub(process_author_match, markdown_content)
# Handle special case where author is inside math environment
math_author_pattern = re.compile(r'\$(\\author\{.*?\})\$', re.DOTALL)
match = math_author_pattern.search(markdown_content)
if match:
# Extract the author command
author_cmd = match.group(1)
# Extract content from author command
author_content_match = re.search(r'\\author\{(.*?)\}', author_cmd, re.DOTALL)
if author_content_match:
# Get author content and process it
author_content = author_content_match.group(1)
processed_content = self._handle_text(author_content)
# Replace the entire $\author{...}$ block with processed content
markdown_content = markdown_content.replace(match.group(0), processed_content)
# Replace LaTeX abstract environment with plain text
markdown_content = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}',
r'**Abstract** \1',
markdown_content,
flags=re.DOTALL)
# Replace standalone \begin{abstract} (without matching end)
markdown_content = re.sub(r'\\begin\{abstract\}',
r'**Abstract**',
markdown_content)
# Replace LaTeX equation numbers with tag format, handling cases with extra backslashes
markdown_content = re.sub(r'\\eqno\{\((.*?)\)\}',
r'\\tag{\1}',
markdown_content)
# Find the starting tag of the formula
markdown_content = markdown_content.replace("\[ \\\\", "$$ \\\\")
# Find the ending tag of the formula (ensure this is the only ending tag)
markdown_content = markdown_content.replace("\\\\ \]", "\\\\ $$")
# Fix other common LaTeX issues
replacements = [
# Fix spacing issues in subscripts and superscripts
(r'_ {', r'_{'),
                (r'\^ {', r'^{'),
# Fix potential issues with multiple consecutive newlines
(r'\n{3,}', r'\n\n')
]
for old, new in replacements:
markdown_content = re.sub(old, new, markdown_content)
return markdown_content
except Exception as e:
print(f"_post_process error: {str(e)}")
return markdown_content # Return original content if post-processing fails
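# Minimal usage sketch (illustrative): recognition results shaped like the example at
# the top of this file are fed to MarkdownConverter.convert, which emits markdown text.
if __name__ == "__main__":
    sample_results = [
        {"label": "sec", "bbox": [0.18, 0.10, 0.82, 0.13], "text": "Results", "reading_order": 0},
        {"label": "para", "bbox": [0.18, 0.15, 0.82, 0.30], "text": "TinyLlama reaches an average score of 52.99,\noutperforming the listed baselines.", "reading_order": 1},
        {"label": "tab", "bbox": [0.18, 0.35, 0.82, 0.60], "text": "<table><tr><td>Model</td><td>Avg</td></tr><tr><td>TinyLlama-1.1B</td><td>52.99</td></tr></table>", "reading_order": 2},
    ]
    converter = MarkdownConverter()
    print(converter.convert(sample_results))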

utils/model.py (new file, 477 lines)

@@ -0,0 +1,477 @@
"""
Copyright (c) 2022-present NAVER Corp.
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
MIT License
This file has been modified by [ByteDance Ltd. and/or its affiliates] on 20250118.
The original file available at https://github.com/clovaai/donut/blob/master/donut/model.py was released under the MIT license.
This modified file is released under the same license.
"""
import logging
from collections import defaultdict
from typing import List, Optional
import torch
import torch.nn.functional as F
from PIL import Image
from timm.models.swin_transformer import SwinTransformer
from torch import nn
from transformers import (
MBartConfig,
MBartForCausalLM,
StoppingCriteria,
StoppingCriteriaList,
)
from transformers.file_utils import ModelOutput
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
class SwinEncoder(nn.Module):
r"""
Encoder based on SwinTransformer
Set the initial weights and configuration with a pretrained SwinTransformer and then
modify the detailed configurations
Args:
input_size: Input image size (width, height)
align_long_axis: Whether to rotate image if height is greater than width
window_size: Window size(=patch size) of SwinTransformer
encoder_layer: Number of layers of SwinTransformer encoder
        name_or_path: Name of a pretrained model, either registered on huggingface.co or saved locally;
                      otherwise, `swin_base_patch4_window12_384` is used (via `timm`).
"""
def __init__(
self,
input_size,
align_long_axis: bool = False,
window_size: int = 7,
encoder_layer: List[int] = [2, 2, 14, 2],
        patch_size: List[int] = [4, 4],
embed_dim: int = 128,
num_heads: List[int] = [4, 8, 16, 32],
):
super().__init__()
if isinstance(input_size, int):
input_size = [input_size, input_size]
self.input_size = input_size
self.align_long_axis = align_long_axis
self.window_size = window_size
self.encoder_layer = encoder_layer
self.patch_size = patch_size
self.embed_dim = embed_dim
self.num_heads = num_heads
self.model = SwinTransformer(
img_size=self.input_size,
depths=self.encoder_layer,
window_size=self.window_size,
patch_size=self.patch_size,
embed_dim=self.embed_dim,
num_heads=self.num_heads,
num_classes=0,
)
def forward(self, x: torch.Tensor, text_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Args:
x: (batch_size, num_channels, height, width)
"""
x = self.model.patch_embed(x)
x = self.model.pos_drop(x)
x = self.model.layers(x)
return x
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def _set_dtype(self, dtype):
self._dtype = dtype
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(dtype=self._dtype))
return ret.type(orig_type)
class BARTDecoder(nn.Module):
"""
Decoder based on Multilingual BART
Set the initial weights and configuration with a pretrained multilingual BART model,
and modify the detailed configurations as a Donut decoder
Args:
decoder_layer:
Number of layers of BARTDecoder
max_position_embeddings:
The maximum sequence length to be trained
name_or_path:
            Name of a pretrained model, either registered on huggingface.co or saved locally;
            otherwise, `facebook/mbart-large-50` is used (via `transformers`)
"""
def __init__(
self,
tokenizer,
decoder_layer: int,
max_position_embeddings: int,
hidden_dimension: int = 1024,
**kwargs,
):
super().__init__()
self.decoder_layer = decoder_layer
self.max_position_embeddings = max_position_embeddings
self.hidden_dimension = hidden_dimension
self.tokenizer = tokenizer
self.model = MBartForCausalLM(
config=MBartConfig(
tie_word_embeddings=True,
is_decoder=True,
is_encoder_decoder=False,
add_cross_attention=True,
decoder_layers=self.decoder_layer,
max_position_embeddings=self.max_position_embeddings,
vocab_size=len(self.tokenizer),
scale_embedding=True,
add_final_layer_norm=True,
d_model=self.hidden_dimension,
)
)
# self.model.config.is_encoder_decoder = True # to get cross-attention
self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference
def add_special_tokens(self, list_of_tokens: List[str]):
"""
Add special tokens to tokenizer and resize the token embeddings
"""
newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
if newly_added_num > 0:
self.model.resize_token_embeddings(len(self.tokenizer))
def add_tokens(self, list_of_tokens: List[str]):
"""
        Add tokens to tokenizer and resize the token embeddings
"""
newly_added_num = self.tokenizer.add_tokens(sorted(set(list_of_tokens)))
if newly_added_num > 0:
self.model.resize_token_embeddings(len(self.tokenizer))
def prepare_inputs_for_inference(
self,
input_ids: torch.Tensor,
encoder_outputs: torch.Tensor,
past=None,
past_key_values=None,
use_cache: bool = None,
attention_mask: torch.Tensor = None,
**kwargs,
):
"""
Args:
input_ids: (batch_size, sequence_length)
Returns:
input_ids: (batch_size, sequence_length)
attention_mask: (batch_size, sequence_length)
encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
"""
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
past = past or past_key_values
if past is not None:
input_ids = input_ids[:, -1:]
output = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
"encoder_hidden_states": encoder_outputs.last_hidden_state,
}
return output
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: bool = None,
output_attentions: Optional[torch.Tensor] = None,
output_hidden_states: Optional[torch.Tensor] = None,
return_dict: bool = None,
):
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
encoder_hidden_states=encoder_hidden_states,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@staticmethod
def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
"""
Resize position embeddings
Truncate if sequence length of MBart backbone is greater than given max_length,
else interpolate to max_length
"""
if weight.shape[0] > max_length:
weight = weight[:max_length, ...]
else:
weight = (
F.interpolate(
weight.permute(1, 0).unsqueeze(0),
size=max_length,
mode="linear",
align_corners=False,
)
.squeeze(0)
.permute(1, 0)
)
return weight
class DonutConfig(PretrainedConfig):
def __init__(
self,
decoder_layer: int = 10,
max_position_embeddings: int = None,
max_length: int = 4096,
hidden_dimension: int = 1024,
**kwargs,
):
super().__init__()
self.decoder_layer = decoder_layer
self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings
self.max_length = max_length
self.hidden_dimension = hidden_dimension
class RunningVarTorch:
def __init__(self, L=15, norm=False):
self.values = None
self.L = L
self.norm = norm
def push(self, x: torch.Tensor):
assert x.dim() == 1
if self.values is None:
self.values = x[:, None]
elif self.values.shape[1] < self.L:
self.values = torch.cat((self.values, x[:, None]), 1)
else:
self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)
def variance(self):
if self.values is None:
return
if self.norm:
return torch.var(self.values, 1) / self.values.shape[1]
else:
return torch.var(self.values, 1)
class StoppingCriteriaScores(StoppingCriteria):
def __init__(self, threshold: float = 0.015, window_size: int = 200):
super().__init__()
self.threshold = threshold
self.vars = RunningVarTorch(norm=True)
self.varvars = RunningVarTorch(L=window_size)
self.stop_inds = defaultdict(int)
self.stopped = defaultdict(bool)
self.size = 0
self.window_size = window_size
@torch.no_grad()
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
last_scores = scores[-1]
self.vars.push(last_scores.max(1)[0].float().cpu())
self.varvars.push(self.vars.variance())
self.size += 1
if self.size < self.window_size:
return False
varvar = self.varvars.variance()
for b in range(len(last_scores)):
if varvar[b] < self.threshold:
if self.stop_inds[b] > 0 and not self.stopped[b]:
self.stopped[b] = self.stop_inds[b] >= self.size
else:
self.stop_inds[b] = int(min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095))
else:
self.stop_inds[b] = 0
self.stopped[b] = False
return all(self.stopped.values()) and len(self.stopped) > 0
def batch(l, b=15):
subs = []
for i in range(len(l) - b):
subs.append(l[i : i + b])
return subs
def subdiv(l, b=10):
subs = []
for i in range(len(l) - b):
subs.append(l[: i + b])
return subs
class DonutModel(PreTrainedModel):
config_class = DonutConfig
base_model_prefix = "donut"
def __init__(self, config: DonutConfig, vision_tower=None, tokenizer=None):
super().__init__(config)
self.config = config
self.tokenizer = tokenizer
self.vpm = vision_tower
# build language model
self.llm = BARTDecoder(
tokenizer=tokenizer,
decoder_layer=self.config.decoder_layer,
max_position_embeddings=self.config.max_position_embeddings,
hidden_dimension=self.config.hidden_dimension,
)
self.ids_to_tokens = {id: content for content, id in self.llm.tokenizer.vocab.items()}
def get_input_embeddings(self, tensor):
return self.llm.model.get_input_embeddings()(tensor)
def forward(
self,
inputs: dict,
):
image_tensors = inputs["pixel_values"]
input_ids = inputs["input_ids"].contiguous()
attention_mask = inputs["attention_mask"]
labels = inputs["labels"].contiguous()
encoder_outputs = self.vpm(
image_tensors,
text_embedding=self.llm.model.get_input_embeddings()(input_ids),
)
decoder_outputs = self.llm(
input_ids=input_ids,
encoder_hidden_states=encoder_outputs,
attention_mask=attention_mask,
labels=labels,
)
return decoder_outputs
def get_hidden_states_during_inference(
self,
prompt_ids: torch.Tensor,
image: Image.Image = None,
image_tensors: Optional[torch.Tensor] = None,
):
if image_tensors is None:
image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
if self.device.type != "mps":
image_tensors = image_tensors.to(next(self.parameters()).dtype)
image_tensors = image_tensors.to(self.device)
prompt_ids = prompt_ids.to(self.device)
all_hidden_states = self.vpm.forward_features(
image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
)
return all_hidden_states
def get_attn_weights_during_inference(
self,
prompt_ids: torch.Tensor,
image: Image.Image = None,
image_tensors: Optional[torch.Tensor] = None,
):
if image_tensors is None:
image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
if self.device.type != "mps":
image_tensors = image_tensors.to(next(self.parameters()).dtype)
image_tensors = image_tensors.to(self.device)
prompt_ids = prompt_ids.to(self.device)
last_attn_score = self.vpm.get_last_layer_cross_attn_score(
image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
)
return last_attn_score
def inference(
self,
prompt_ids: torch.Tensor,
image: Image.Image = None,
image_tensors: Optional[torch.Tensor] = None,
return_attentions: bool = False,
early_stopping: bool = True,
):
"""
Generate a token sequence in an auto-regressive manner.
Args:
image: input document image (PIL.Image)
image_tensors: (1, num_channels, height, width)
                computed from `image` if not fed
"""
output = {
"predictions": list(),
"sequences": list(),
"repeats": list(),
"repetitions": list(),
}
if image is None and image_tensors is None:
            logging.warning("Image not found")
return output
if image_tensors is None:
image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
if self.device.type != "mps":
image_tensors = image_tensors.to(next(self.parameters()).dtype)
image_tensors = image_tensors.to(self.device)
prompt_ids = prompt_ids.to(self.device)
last_hidden_state = self.vpm(image_tensors, text_embedding=self.get_input_embeddings(prompt_ids))
encoder_outputs = ModelOutput(last_hidden_state=last_hidden_state, attentions=None)
if len(encoder_outputs.last_hidden_state.size()) == 1:
encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.unsqueeze(0)
# get decoder output
decoder_output = self.llm.model.generate(
input_ids=prompt_ids,
encoder_outputs=encoder_outputs,
min_length=1,
max_length=self.config.max_length,
pad_token_id=self.llm.tokenizer.pad_token_id,
eos_token_id=self.llm.tokenizer.eos_token_id,
use_cache=True,
return_dict_in_generate=True,
output_scores=True,
output_attentions=return_attentions,
do_sample=False,
num_beams=1,
stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()] if early_stopping else []),
)
output["repetitions"] = decoder_output.sequences.clone()
output["sequences"] = decoder_output.sequences.clone()
output["scores"] = torch.stack(decoder_output.scores, 1).softmax(-1).cpu().max(-1)[0]
output["repetitions"] = self.llm.tokenizer.batch_decode(output["repetitions"], skip_special_tokens=False)
return output
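# Minimal usage sketch (illustrative): adapting an MBart absolute position embedding
# table to a longer context with BARTDecoder.resize_bart_abs_pos_emb (relies on the
# `torch` import at the top of this file; the sizes below are placeholders).
if __name__ == "__main__":
    old_pos_emb = torch.randn(1026, 1024)  # e.g. a 1024-position table plus MBart's 2-position offset
    new_pos_emb = BARTDecoder.resize_bart_abs_pos_emb(old_pos_emb, max_length=4098)
    print(new_pos_emb.shape)  # torch.Size([4098, 1024])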

utils/processor.py (new file, 64 lines)

@@ -0,0 +1,64 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import numpy as np
import torch
from PIL import ImageOps
from utils.utils import *
class DolphinProcessor:
def __init__(
self,
dp_config,
tokenizer,
**kwargs,
) -> None:
self.tokenizer = tokenizer
transform_args = kwargs.get("transform_args", {})
self.max_length = transform_args.get("max_length", 2048)
self.input_size = transform_args.get("input_size", [896, 896]) # height, width
if isinstance(self.input_size, int):
self.input_size = [self.input_size, self.input_size]
try:
self.answer_start_token = self.tokenizer._prompt_end_token
except AttributeError as err:
            print('No answer_start_token found, using "" instead')
self.answer_start_token = ""
self.prefix_answer_space_flag = dp_config.get("prefix_answer_space_flag", True)
self.suffix_prompt_space_flag = dp_config.get("suffix_prompt_space_flag", True)
def process_prompt_for_inference(self, prompt):
prompt = prompt.replace("<image>\n", "")
if not prompt.startswith("<s>"):
prompt = "<s>" + prompt
message_ids = [self.tokenizer.encode(prompt, add_special_tokens=False)]
ids = torch.from_numpy(np.hstack(message_ids, dtype=np.int32))
return ids.unsqueeze(0)
def process_image_for_inference(self, image, return_img_size=False):
image = resize(image, min(self.input_size))
image.thumbnail((self.input_size[1], self.input_size[0]))
origin_w, origin_h = image.size
delta_width = self.input_size[1] - image.width
delta_height = self.input_size[0] - image.height
pad_width = delta_width // 2
pad_height = delta_height // 2
padding = (
pad_width,
pad_height,
delta_width - pad_width,
delta_height - pad_height,
)
image = ImageOps.expand(image, padding)
if return_img_size:
return test_transform(image).unsqueeze(0), (origin_w, origin_h)
return test_transform(image).unsqueeze(0)
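# Minimal usage sketch (illustrative; the tokenizer name, prompt, and image path are
# placeholders, in practice they come from the Dolphin checkpoint and the input document).
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50")  # placeholder tokenizer
    processor = DolphinProcessor(dp_config={}, tokenizer=tokenizer,
                                 transform_args={"input_size": [896, 896], "max_length": 2048})
    prompt_ids = processor.process_prompt_for_inference("<image>\nParse the reading order of this document.")
    pixel_values, (w, h) = processor.process_image_for_inference(Image.open("page_0.png"), return_img_size=True)
    print(prompt_ids.shape, pixel_values.shape, (w, h))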

utils/utils.py (new file, 367 lines)

@@ -0,0 +1,367 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import copy
import json
import os
import re
from dataclasses import dataclass
from typing import List, Tuple
import albumentations as alb
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms.functional import resize
from utils.markdown_utils import MarkdownConverter
def alb_wrapper(transform):
def f(im):
return transform(image=np.asarray(im))["image"]
return f
test_transform = alb_wrapper(
alb.Compose(
[
alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
ToTensorV2(),
]
)
)
def check_coord_valid(x1, y1, x2, y2, image_size=None, abs_coord=True):
# print(f"check_coord_valid: {x1}, {y1}, {x2}, {y2}, {image_size}, {abs_coord}")
if x2 <= x1 or y2 <= y1:
return False, f"[{x1}, {y1}, {x2}, {y2}]"
if x1 < 0 or y1 < 0:
return False, f"[{x1}, {y1}, {x2}, {y2}]"
if not abs_coord:
if x2 > 1 or y2 > 1:
return False, f"[{x1}, {y1}, {x2}, {y2}]"
elif image_size is not None: # has image size
if x2 > image_size[0] or y2 > image_size[1]:
return False, f"[{x1}, {y1}, {x2}, {y2}]"
return True, None
def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2):
"""
    image: OpenCV image (numpy array), or a path to an image file
    boxes: list of boxes [[x1, y1, x2, y2]], using absolute pixel coordinates
"""
if isinstance(image, str):
image = cv2.imread(image)
img_h, img_w = image.shape[:2]
new_boxes = []
for box in boxes:
best_box = copy.deepcopy(box)
def check_edge(img, current_box, i, is_vertical):
edge = current_box[i]
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
if is_vertical:
line = binary[current_box[1] : current_box[3] + 1, edge]
else:
line = binary[edge, current_box[0] : current_box[2] + 1]
transitions = np.abs(np.diff(line))
return np.sum(transitions) / len(transitions)
# Only widen the box
edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
current_box = copy.deepcopy(box)
# make sure the box is within the image
current_box[0] = min(max(current_box[0], 0), img_w - 1)
current_box[1] = min(max(current_box[1], 0), img_h - 1)
current_box[2] = min(max(current_box[2], 0), img_w - 1)
current_box[3] = min(max(current_box[3], 0), img_h - 1)
for i, direction, is_vertical in edges:
best_score = check_edge(image, current_box, i, is_vertical)
if best_score <= threshold:
continue
for step in range(max_pixels):
current_box[i] += direction
if i == 0 or i == 2:
current_box[i] = min(max(current_box[i], 0), img_w - 1)
else:
current_box[i] = min(max(current_box[i], 0), img_h - 1)
score = check_edge(image, current_box, i, is_vertical)
if score < best_score:
best_score = score
best_box = copy.deepcopy(current_box)
if score <= threshold:
break
new_boxes.append(best_box)
return new_boxes
def parse_layout_string(bbox_str):
"""Parse layout string using regular expressions"""
pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
matches = re.finditer(pattern, bbox_str)
parsed_results = []
for match in matches:
coords = [float(match.group(i)) for i in range(1, 5)]
label = match.group(5).strip()
parsed_results.append((coords, label))
return parsed_results
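# Example (illustrative): parse_layout_string("[0.176, 0.74, 0.824, 0.82] tab [0.28, 0.729, 0.711, 0.74] cap")
# returns [([0.176, 0.74, 0.824, 0.82], "tab"), ([0.28, 0.729, 0.711, 0.74], "cap")]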
@dataclass
class ImageDimensions:
"""Class to store image dimensions"""
original_w: int
original_h: int
padded_w: int
padded_h: int
def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]:
"""Map coordinates from padded image back to original image
Args:
x1, y1, x2, y2: Coordinates in padded image
dims: Image dimensions object
Returns:
tuple: (x1, y1, x2, y2) coordinates in original image
"""
try:
# Calculate padding offsets
top = (dims.padded_h - dims.original_h) // 2
left = (dims.padded_w - dims.original_w) // 2
# Map back to original coordinates
orig_x1 = max(0, x1 - left)
orig_y1 = max(0, y1 - top)
orig_x2 = min(dims.original_w, x2 - left)
orig_y2 = min(dims.original_h, y2 - top)
# Ensure we have a valid box (width and height > 0)
if orig_x2 <= orig_x1:
orig_x2 = min(orig_x1 + 1, dims.original_w)
if orig_y2 <= orig_y1:
orig_y2 = min(orig_y1 + 1, dims.original_h)
return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
except Exception as e:
print(f"map_to_original_coordinates error: {str(e)}")
# Return safe coordinates
return 0, 0, min(100, dims.original_w), min(100, dims.original_h)
def map_to_relevant_coordinates(abs_coords, dims: ImageDimensions):
"""
    Convert absolute coordinates to relative (normalized) coordinates
    e.g. [100, 100, 200, 200] on a 1000x1000 image -> [0.1, 0.1, 0.2, 0.2]
"""
try:
x1, y1, x2, y2 = abs_coords
return round(x1 / dims.original_w, 3), round(y1 / dims.original_h, 3), round(x2 / dims.original_w, 3), round(y2 / dims.original_h, 3)
except Exception as e:
print(f"map_to_relevant_coordinates error: {str(e)}")
return 0.0, 0.0, 1.0, 1.0 # Return full image coordinates
def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
"""Process and adjust coordinates
Args:
coords: Normalized coordinates [x1, y1, x2, y2]
padded_image: Padded image
dims: Image dimensions object
previous_box: Previous box coordinates for overlap adjustment
Returns:
tuple: (x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box)
"""
try:
# Convert normalized coordinates to absolute coordinates
x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)
# Ensure coordinates are within image bounds before adjustment
x1 = max(0, min(x1, dims.padded_w - 1))
y1 = max(0, min(y1, dims.padded_h - 1))
x2 = max(0, min(x2, dims.padded_w))
y2 = max(0, min(y2, dims.padded_h))
# Ensure width and height are at least 1 pixel
if x2 <= x1:
x2 = min(x1 + 1, dims.padded_w)
if y2 <= y1:
y2 = min(y1 + 1, dims.padded_h)
# Extend box boundaries
new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
x1, y1, x2, y2 = new_boxes[0]
# Ensure coordinates are still within image bounds after adjustment
x1 = max(0, min(x1, dims.padded_w - 1))
y1 = max(0, min(y1, dims.padded_h - 1))
x2 = max(0, min(x2, dims.padded_w))
y2 = max(0, min(y2, dims.padded_h))
# Ensure width and height are at least 1 pixel after adjustment
if x2 <= x1:
x2 = min(x1 + 1, dims.padded_w)
if y2 <= y1:
y2 = min(y1 + 1, dims.padded_h)
# Check for overlap with previous box and adjust
if previous_box is not None:
prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
y1 = prev_y2
# Ensure y1 is still valid
y1 = min(y1, dims.padded_h - 1)
# Make sure y2 is still greater than y1
if y2 <= y1:
y2 = min(y1 + 1, dims.padded_h)
# Update previous box
new_previous_box = [x1, y1, x2, y2]
# Map to original coordinates
orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(
x1, y1, x2, y2, dims
)
return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
except Exception as e:
print(f"process_coordinates error: {str(e)}")
# Return safe values
orig_x1, orig_y1, orig_x2, orig_y2 = 0, 0, min(100, dims.original_w), min(100, dims.original_h)
return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]
def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]:
"""Load and prepare image with padding while maintaining aspect ratio
Args:
image: PIL image
Returns:
tuple: (padded_image, image_dimensions)
"""
try:
# Convert PIL image to OpenCV format
image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
original_h, original_w = image.shape[:2]
# Calculate padding to make square image
max_size = max(original_h, original_w)
top = (max_size - original_h) // 2
bottom = max_size - original_h - top
left = (max_size - original_w) // 2
right = max_size - original_w - left
# Apply padding
padded_image = cv2.copyMakeBorder(image, top, bottom, left, right,
cv2.BORDER_CONSTANT, value=(0, 0, 0))
padded_h, padded_w = padded_image.shape[:2]
dimensions = ImageDimensions(
original_w=original_w,
original_h=original_h,
padded_w=padded_w,
padded_h=padded_h
)
return padded_image, dimensions
except Exception as e:
print(f"prepare_image error: {str(e)}")
# Create a minimal valid image and dimensions
h, w = image.height, image.width
dimensions = ImageDimensions(
original_w=w,
original_h=h,
padded_w=w,
padded_h=h
)
# Return a black image of the same size
return np.zeros((h, w, 3), dtype=np.uint8), dimensions
def setup_output_dirs(save_dir):
"""Create necessary output directories"""
os.makedirs(save_dir, exist_ok=True)
os.makedirs(os.path.join(save_dir, "markdown"), exist_ok=True)
os.makedirs(os.path.join(save_dir, "recognition_json"), exist_ok=True)
def save_outputs(recognition_results, image_path, save_dir):
"""Save JSON and markdown outputs"""
basename = os.path.splitext(os.path.basename(image_path))[0]
# Save JSON file
json_path = os.path.join(save_dir, "recognition_json", f"{basename}.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump(recognition_results, f, ensure_ascii=False, indent=2)
# Generate and save markdown file
markdown_converter = MarkdownConverter()
markdown_content = markdown_converter.convert(recognition_results)
markdown_path = os.path.join(save_dir, "markdown", f"{basename}.md")
with open(markdown_path, "w", encoding="utf-8") as f:
f.write(markdown_content)
return json_path
def crop_margin(img: Image.Image) -> Image.Image:
"""Crop margins from image"""
try:
width, height = img.size
if width == 0 or height == 0:
print("Warning: Image has zero width or height")
return img
data = np.array(img.convert("L"))
data = data.astype(np.uint8)
max_val = data.max()
min_val = data.min()
if max_val == min_val:
return img
data = (data - min_val) / (max_val - min_val) * 255
gray = 255 * (data < 200).astype(np.uint8)
coords = cv2.findNonZero(gray) # Find all non-zero points (text)
if coords is None:
return img
a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
# Ensure crop coordinates are within image bounds
a = max(0, a)
b = max(0, b)
w = min(w, width - a)
h = min(h, height - b)
# Only crop if we have a valid region
if w > 0 and h > 0:
return img.crop((a, b, a + w, b + h))
return img
except Exception as e:
print(f"crop_margin error: {str(e)}")
return img # Return original image on error
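# Minimal end-to-end sketch of the geometry helpers above (illustrative; the image path
# and the layout string are placeholders, the latter normally comes from the layout model).
if __name__ == "__main__":
    page = crop_margin(Image.open("page_0.png").convert("RGB"))
    padded_image, dims = prepare_image(page)
    previous_box = None
    for coords, label in parse_layout_string("[0.1, 0.1, 0.9, 0.2] title [0.1, 0.25, 0.9, 0.9] para"):
        x1, y1, x2, y2, ox1, oy1, ox2, oy2, previous_box = process_coordinates(
            coords, padded_image, dims, previous_box
        )
        print(label, (ox1, oy1, ox2, oy2), map_to_relevant_coordinates((ox1, oy1, ox2, oy2), dims))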