[init] initial commit
442  utils/markdown_utils.py  Normal file
@@ -0,0 +1,442 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import re
import base64
from typing import List, Dict, Any, Optional


"""
Example input:
[
    {"label": "tab", "bbox": [0.176, 0.74, 0.824, 0.82], "text": "<table><tr><td></td><td>HellaSwag</td><td>Obqa</td><td>WinoGrande</td><td>ARC-c</td><td>ARC-e</td><td>boolq</td><td>piqa</td><td>Avg</td></tr><tr><td>OPT-1.3B</td><td>53.65</td><td>33.40</td><td>59.59</td><td>29.44</td><td>50.80</td><td>60.83</td><td>72.36</td><td>51.44</td></tr><tr><td>Pythia-1.0B</td><td>47.16</td><td>31.40</td><td>53.43</td><td>27.05</td><td>48.99</td><td>57.83</td><td>69.21</td><td>48.30</td></tr><tr><td>Pythia-1.4B</td><td>52.01</td><td>33.20</td><td>57.38</td><td>28.50</td><td>54.00</td><td>63.27</td><td>70.95</td><td>51.33</td></tr><tr><td>TinyLlama-1.1B</td><td>59.20</td><td>36.00</td><td>59.12</td><td>30.10</td><td>55.25</td><td>57.83</td><td>73.29</td><td>52.99</td></tr></table>", "reading_order": 6},
    {"label": "cap", "bbox": [0.28, 0.729, 0.711, 0.74], "text": "Table 2: Zero-shot performance on commonsense reasoning tasks", "reading_order": 7},
    {"label": "para", "bbox": [0.176, 0.848, 0.826, 0.873], "text": "We of performance during training We tracked the accuracy of TinyLlama on common-\nsense reasoning benchmarks during its pre-training, as shown in Fig. 2 . Generally, the performance of", "reading_order": 8},
    {"label": "fnote", "bbox": [0.176, 0.88, 0.824, 0.912], "text": "${ }^{4}$ Due to a bug in the config file, the learning rate did not decrease immediately after warmup and remained at\nthe maximum value for several steps before we fixed this.", "reading_order": 9},
    {"label": "foot", "bbox": [0.496, 0.939, 0.501, 0.95], "text": "14", "reading_order": 10}
]
"""


def extract_table_from_html(html_string):
    """Extract and clean table tags from HTML string"""
    try:
        table_pattern = re.compile(r'<table.*?>.*?</table>', re.DOTALL)
        tables = table_pattern.findall(html_string)
        tables = [re.sub(r'<table[^>]*>', '<table>', table) for table in tables]
        return '\n'.join(tables)
    except Exception as e:
        print(f"extract_table_from_html error: {str(e)}")
        return f"<table><tr><td>Error extracting table: {str(e)}</td></tr></table>"
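# Illustrative example (not part of the original file): extract_table_from_html keeps only the
# <table>...</table> portions of a recognition result and strips attributes from the opening tag, e.g.
#   extract_table_from_html('<p>intro</p><table border="1"><tr><td>a</td></tr></table>')
#   -> '<table><tr><td>a</td></tr></table>'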


class MarkdownConverter:
    """Convert structured recognition results to Markdown format"""

    def __init__(self):
        # Define heading levels for different section types
        self.heading_levels = {
            'title': '#',
            'sec': '##',
            'sub_sec': '###'
        }

        # Define which labels need special handling
        self.special_labels = {
            'tab', 'fig', 'title', 'sec', 'sub_sec',
            'list', 'formula', 'reference', 'alg'
        }

    def try_remove_newline(self, text: str) -> str:
        try:
            # Preprocess text to handle line breaks
            text = text.strip()
            text = text.replace('-\n', '')

            # Handle Chinese text line breaks
            def is_chinese(char):
                return '\u4e00' <= char <= '\u9fff'

            lines = text.split('\n')
            processed_lines = []

            # Process all lines except the last one
            for i in range(len(lines)-1):
                current_line = lines[i].strip()
                next_line = lines[i+1].strip()

                # Always add the current line, but determine if we need a newline
                if current_line:  # If current line is not empty
                    if next_line:  # If next line is not empty
                        # For Chinese text handling
                        if is_chinese(current_line[-1]) and is_chinese(next_line[0]):
                            processed_lines.append(current_line)
                        else:
                            processed_lines.append(current_line + ' ')
                    else:
                        # Next line is empty, add current line with newline
                        processed_lines.append(current_line + '\n')
                else:
                    # Current line is empty, add an empty line
                    processed_lines.append('\n')

            # Add the last line
            if lines and lines[-1].strip():
                processed_lines.append(lines[-1].strip())

            text = ''.join(processed_lines)

            return text
        except Exception as e:
            print(f"try_remove_newline error: {str(e)}")
            return text  # Return original text on error

    def _handle_text(self, text: str) -> str:
        """
        Process regular text content, preserving paragraph structure
        """
        try:
            if not text:
                return ""

            if text.strip().startswith("\\begin{array}") and text.strip().endswith("\\end{array}"):
                text = "$$" + text + "$$"
            elif ("_{" in text or "^{" in text or "\\" in text or "_ {" in text or "^ {" in text) and ("$" not in text) and ("\\begin" not in text):
                text = "$" + text + "$"

            # Process formulas in text before handling other text processing
            text = self._process_formulas_in_text(text)

            text = self.try_remove_newline(text)

            # Return processed text
            return text
        except Exception as e:
            print(f"_handle_text error: {str(e)}")
            return text  # Return original text on error
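
    # Illustrative example (not part of the original file): a bare LaTeX fragment such as
    # "x_{i}^{2} + \alpha" contains "_{" but no "$" and no "\begin", so _handle_text wraps it
    # as "$x_{i}^{2} + \alpha$" before formula and newline processing.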

    def _process_formulas_in_text(self, text: str) -> str:
        """
        Process mathematical formulas in text by iteratively finding and replacing formulas.
        - Identify inline and block formulas
        - Replace newlines within formulas with \\
        """
        try:
            # Define formula delimiters and their corresponding patterns
            delimiters = [
                ('$$', '$$'),      # Block formula with $$
                ('\\[', '\\]'),    # Block formula with \[ \]
                ('$', '$'),        # Inline formula with $
                ('\\(', '\\)')     # Inline formula with \( \)
            ]

            # Process the text by iterating through each delimiter type
            result = text

            for start_delim, end_delim in delimiters:
                # Create a pattern that matches from start to end delimiter
                # Using a custom approach to avoid issues with nested delimiters
                current_pos = 0
                processed_parts = []

                while current_pos < len(result):
                    # Find the next start delimiter
                    start_pos = result.find(start_delim, current_pos)
                    if start_pos == -1:
                        # No more formulas of this type
                        processed_parts.append(result[current_pos:])
                        break

                    # Add text before the formula
                    processed_parts.append(result[current_pos:start_pos])

                    # Find the matching end delimiter
                    end_pos = result.find(end_delim, start_pos + len(start_delim))
                    if end_pos == -1:
                        # No matching end delimiter, treat as regular text
                        processed_parts.append(result[start_pos:])
                        break

                    # Extract the formula content (without delimiters)
                    formula_content = result[start_pos + len(start_delim):end_pos]

                    # Process the formula content - replace newlines with \\
                    processed_formula = formula_content.replace('\n', ' \\\\ ')

                    # Add the processed formula with its delimiters
                    processed_parts.append(f"{start_delim}{processed_formula}{end_delim}")

                    # Move past this formula
                    current_pos = end_pos + len(end_delim)

                # Update the result with processed text
                result = ''.join(processed_parts)
            return result
        except Exception as e:
            print(f"_process_formulas_in_text error: {str(e)}")
            return text  # Return original text on error
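
    # Illustrative example (not part of the original file): for an input such as "$a +\nb$",
    # the scan above finds the "$...$" pair and rewrites it as "$a + \\ b$", so line breaks
    # inside a formula survive as LaTeX line breaks in the markdown output.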

    def _remove_newline_in_heading(self, text: str) -> str:
        """
        Remove newline in heading
        """
        try:
            # Handle Chinese text line breaks
            def is_chinese(char):
                return '\u4e00' <= char <= '\u9fff'

            # Check if the text contains Chinese characters
            if any(is_chinese(char) for char in text):
                return text.replace('\n', '')
            else:
                return text.replace('\n', ' ')

        except Exception as e:
            print(f"_remove_newline_in_heading error: {str(e)}")
            return text

    def _handle_heading(self, text: str, label: str) -> str:
        """
        Convert section headings to appropriate markdown format
        """
        try:
            level = self.heading_levels.get(label, '#')
            text = text.strip()
            text = self._remove_newline_in_heading(text)
            text = self._handle_text(text)
            return f"{level} {text}\n\n"
        except Exception as e:
            print(f"_handle_heading error: {str(e)}")
            return f"# Error processing heading: {text}\n\n"

    def _handle_list_item(self, text: str) -> str:
        """
        Convert list items to markdown list format
        """
        try:
            return f"- {text.strip()}\n"
        except Exception as e:
            print(f"_handle_list_item error: {str(e)}")
            return f"- Error processing list item: {text}\n"

    def _handle_figure(self, text: str, section_count: int) -> str:
        """
        Convert base64 encoded image to markdown image syntax
        """
        try:
            # Determine image format (assuming PNG if not specified)
            img_format = "png"
            if text.startswith("data:image/"):
                # Extract format from data URI
                img_format = text.split(";")[0].split("/")[1]
            elif ";" in text and "," in text:
                # Already in data URI format
                return f"![Figure {section_count}]({text})\n\n"
            else:
                # Raw base64, convert to data URI
                data_uri = f"data:image/{img_format};base64,{text}"
                return f"![Figure {section_count}]({data_uri})\n\n"

            # First branch above: text already is a data URI, embed it directly
            return f"![Figure {section_count}]({text})\n\n"
        except Exception as e:
            print(f"_handle_figure error: {str(e)}")
            return f"*[Error processing figure: {str(e)}]*\n\n"

    def _handle_table(self, text: str) -> str:
        """
        Convert table content to markdown format
        """
        try:
            markdown_content = []
            if '<table' in text.lower() or '<tr' in text.lower():
                markdown_table = extract_table_from_html(text)
                markdown_content.append(markdown_table + "\n")
            else:
                table_lines = text.split('\n')
                if table_lines:
                    col_count = len(table_lines[0].split()) if table_lines[0] else 1
                    header = '| ' + ' | '.join(table_lines[0].split()) + ' |'
                    markdown_content.append(header)
                    markdown_content.append('| ' + ' | '.join(['---'] * col_count) + ' |')
                    for line in table_lines[1:]:
                        cells = line.split()
                        while len(cells) < col_count:
                            cells.append('')
                        markdown_content.append('| ' + ' | '.join(cells) + ' |')
            return '\n'.join(markdown_content) + '\n\n'
        except Exception as e:
            print(f"_handle_table error: {str(e)}")
            return f"*[Error processing table: {str(e)}]*\n\n"
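
    # Illustrative example (not part of the original file): for a plain-text table the fallback
    # above splits rows on whitespace, so "A B\n1 2" becomes
    #   | A | B |
    #   | --- | --- |
    #   | 1 | 2 |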

    def _handle_algorithm(self, text: str) -> str:
        """
        Process algorithm blocks with proper formatting
        """
        try:
            # Remove algorithm environment tags if present
            text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL)
            text = text.replace('\\begin{algorithm}', '').replace('\\end{algorithm}', '')
            text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '')

            # Process the algorithm text
            lines = text.strip().split('\n')

            # Check if there's a caption or label
            caption = ""
            algorithm_text = []

            for line in lines:
                if '\\caption' in line:
                    # Extract caption text
                    caption_match = re.search(r'\\caption\{(.*?)\}', line)
                    if caption_match:
                        caption = f"**{caption_match.group(1)}**\n\n"
                    continue
                elif '\\label' in line:
                    continue  # Skip label lines
                else:
                    algorithm_text.append(line)

            # Join the algorithm text and wrap in code block
            formatted_text = '\n'.join(algorithm_text)

            # Return the formatted algorithm with caption
            return f"{caption}```\n{formatted_text}\n```\n\n"
        except Exception as e:
            print(f"_handle_algorithm error: {str(e)}")
            return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n"

    def _handle_formula(self, text: str) -> str:
        """
        Handle formula-specific content
        """
        try:
            # Process the formula content
            processed_text = self._process_formulas_in_text(text)

            # For formula blocks, ensure they're properly formatted in markdown
            if '$$' not in processed_text and '\\[' not in processed_text:
                # If no block formula delimiters are present, wrap in $$ for block formula
                processed_text = f'$${processed_text}$$'

            return f"{processed_text}\n\n"
        except Exception as e:
            print(f"_handle_formula error: {str(e)}")
            return f"*[Error processing formula: {str(e)}]*\n\n"

    def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
        """
        Convert recognition results to markdown format
        """
        try:
            markdown_content = []

            for section_count, result in enumerate(recognition_results):
                try:
                    label = result.get('label', '')
                    text = result.get('text', '').strip()

                    # Skip empty text
                    if not text:
                        continue

                    # Handle different content types
                    if label in {'title', 'sec', 'sub_sec'}:
                        markdown_content.append(self._handle_heading(text, label))
                    elif label == 'list':
                        markdown_content.append(self._handle_list_item(text))
                    elif label == 'fig':
                        markdown_content.append(self._handle_figure(text, section_count))
                    elif label == 'tab':
                        markdown_content.append(self._handle_table(text))
                    elif label == 'alg':
                        markdown_content.append(self._handle_algorithm(text))
                    elif label == 'formula':
                        markdown_content.append(self._handle_formula(text))
                    elif label not in self.special_labels:
                        # Handle regular text (paragraphs, etc.)
                        processed_text = self._handle_text(text)
                        markdown_content.append(f"{processed_text}\n\n")

                except Exception as e:
                    print(f"Error processing item {section_count}: {str(e)}")
                    # Add a placeholder for the failed item
                    markdown_content.append("*[Error processing content]*\n\n")

            # Join all content and apply post-processing
            result = ''.join(markdown_content)
            return self._post_process(result)
        except Exception as e:
            print(f"convert error: {str(e)}")
            return f"Error generating markdown content: {str(e)}"

    def _post_process(self, markdown_content: str) -> str:
        """
        Apply post-processing fixes to the generated markdown content
        """
        try:
            # Handle author information
            author_pattern = re.compile(r'\\author\{(.*?)\}', re.DOTALL)

            def process_author_match(match):
                # Extract author content
                author_content = match.group(1)
                # Process the author content
                return self._handle_text(author_content)

            # Replace \author{...} with processed content
            markdown_content = author_pattern.sub(process_author_match, markdown_content)

            # Handle special case where author is inside math environment
            math_author_pattern = re.compile(r'\$(\\author\{.*?\})\$', re.DOTALL)
            match = math_author_pattern.search(markdown_content)
            if match:
                # Extract the author command
                author_cmd = match.group(1)
                # Extract content from author command
                author_content_match = re.search(r'\\author\{(.*?)\}', author_cmd, re.DOTALL)
                if author_content_match:
                    # Get author content and process it
                    author_content = author_content_match.group(1)
                    processed_content = self._handle_text(author_content)
                    # Replace the entire $\author{...}$ block with processed content
                    markdown_content = markdown_content.replace(match.group(0), processed_content)

            # Replace LaTeX abstract environment with plain text
            markdown_content = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}',
                                      r'**Abstract** \1',
                                      markdown_content,
                                      flags=re.DOTALL)

            # Replace standalone \begin{abstract} (without matching end)
            markdown_content = re.sub(r'\\begin\{abstract\}',
                                      r'**Abstract**',
                                      markdown_content)

            # Replace LaTeX equation numbers with tag format, handling cases with extra backslashes
            markdown_content = re.sub(r'\\eqno\{\((.*?)\)\}',
                                      r'\\tag{\1}',
                                      markdown_content)

            # Find the starting tag of the formula
            markdown_content = markdown_content.replace("\\[ \\\\", "$$ \\\\")

            # Find the ending tag of the formula (ensure this is the only ending tag)
            markdown_content = markdown_content.replace("\\\\ \\]", "\\\\ $$")

            # Fix other common LaTeX issues
            replacements = [
                # Fix spacing issues in subscripts and superscripts
                (r'_ {', r'_{'),
                (r'\^ {', r'^{'),

                # Fix potential issues with multiple consecutive newlines
                (r'\n{3,}', r'\n\n')
            ]

            for old, new in replacements:
                markdown_content = re.sub(old, new, markdown_content)

            return markdown_content
        except Exception as e:
            print(f"_post_process error: {str(e)}")
            return markdown_content  # Return original content if post-processing fails
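A minimal usage sketch of this module (illustrative, not part of the commit): feeding a list of recognition elements in the schema shown in the module docstring through MarkdownConverter.convert yields a Markdown string, with headings mapped to "#"/"##"/"###", hyphenated line breaks rejoined, and tables kept as cleaned HTML.

# Illustrative usage sketch (not part of the commit); assumes elements follow the schema
# shown in the module docstring above.
from utils.markdown_utils import MarkdownConverter

elements = [
    {"label": "sec", "bbox": [0.1, 0.10, 0.9, 0.12], "text": "4 Results", "reading_order": 0},
    {"label": "para", "bbox": [0.1, 0.15, 0.9, 0.30],
     "text": "We tracked the accuracy of TinyLlama on common-\nsense reasoning benchmarks.", "reading_order": 1},
    {"label": "tab", "bbox": [0.1, 0.35, 0.9, 0.50],
     "text": "<table border=1><tr><td>Model</td><td>Avg</td></tr><tr><td>TinyLlama-1.1B</td><td>52.99</td></tr></table>",
     "reading_order": 2},
]

converter = MarkdownConverter()
markdown = converter.convert(elements)
# "4 Results" becomes a "## ..." heading, the hyphenated line break in the paragraph is
# removed, and the table is emitted as a cleaned <table> block.
print(markdown)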
477  utils/model.py  Normal file
@@ -0,0 +1,477 @@
"""
Copyright (c) 2022-present NAVER Corp.
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
MIT License
This file has been modified by [ByteDance Ltd. and/or its affiliates] on 20250118.
The original file available at https://github.com/clovaai/donut/blob/master/donut/model.py was released under the MIT license.
This modified file is released under the same license.
"""

import logging
from collections import defaultdict
from typing import List, Optional

import torch
import torch.nn.functional as F
from PIL import Image
from timm.models.swin_transformer import SwinTransformer
from torch import nn
from transformers import (
    MBartConfig,
    MBartForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.file_utils import ModelOutput
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel


class SwinEncoder(nn.Module):
    r"""
    Encoder based on SwinTransformer
    Set the initial weights and configuration with a pretrained SwinTransformer and then
    modify the detailed configurations

    Args:
        input_size: Input image size (width, height)
        align_long_axis: Whether to rotate the image if its height is greater than its width
        window_size: Window size (= patch size) of the SwinTransformer
        encoder_layer: Number of layers in each stage of the SwinTransformer encoder
        name_or_path: Name of a pretrained model registered on huggingface.co or a local path;
            otherwise, `swin_base_patch4_window12_384` is used (via `timm`)
    """

    def __init__(
        self,
        input_size,
        align_long_axis: bool = False,
        window_size: int = 7,
        encoder_layer: List[int] = [2, 2, 14, 2],
        patch_size: List[int] = [4, 4],
        embed_dim: int = 128,
        num_heads: List[int] = [4, 8, 16, 32],
    ):
        super().__init__()
        if isinstance(input_size, int):
            input_size = [input_size, input_size]
        self.input_size = input_size
        self.align_long_axis = align_long_axis
        self.window_size = window_size
        self.encoder_layer = encoder_layer
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.model = SwinTransformer(
            img_size=self.input_size,
            depths=self.encoder_layer,
            window_size=self.window_size,
            patch_size=self.patch_size,
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            num_classes=0,
        )

    def forward(self, x: torch.Tensor, text_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch_size, num_channels, height, width)
        """
        x = self.model.patch_embed(x)
        x = self.model.pos_drop(x)
        x = self.model.layers(x)
        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def _set_dtype(self, dtype):
        self._dtype = dtype

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(dtype=self._dtype))
        return ret.type(orig_type)


class BARTDecoder(nn.Module):
    """
    Decoder based on Multilingual BART
    Set the initial weights and configuration with a pretrained multilingual BART model,
    and modify the detailed configurations as a Donut decoder

    Args:
        decoder_layer:
            Number of layers of BARTDecoder
        max_position_embeddings:
            The maximum sequence length to be trained
        name_or_path:
            Name of a pretrained model registered on huggingface.co or a local path;
            otherwise, `facebook/mbart-large-50` is used (via `transformers`)
    """

    def __init__(
        self,
        tokenizer,
        decoder_layer: int,
        max_position_embeddings: int,
        hidden_dimension: int = 1024,
        **kwargs,
    ):
        super().__init__()
        self.decoder_layer = decoder_layer
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dimension = hidden_dimension

        self.tokenizer = tokenizer

        self.model = MBartForCausalLM(
            config=MBartConfig(
                tie_word_embeddings=True,
                is_decoder=True,
                is_encoder_decoder=False,
                add_cross_attention=True,
                decoder_layers=self.decoder_layer,
                max_position_embeddings=self.max_position_embeddings,
                vocab_size=len(self.tokenizer),
                scale_embedding=True,
                add_final_layer_norm=True,
                d_model=self.hidden_dimension,
            )
        )
        # self.model.config.is_encoder_decoder = True  # to get cross-attention
        self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
        self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference

    def add_special_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings
        """
        newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
        if newly_added_num > 0:
            self.model.resize_token_embeddings(len(self.tokenizer))

    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add regular tokens to tokenizer and resize the token embeddings
        """
        newly_added_num = self.tokenizer.add_tokens(sorted(set(list_of_tokens)))
        if newly_added_num > 0:
            self.model.resize_token_embeddings(len(self.tokenizer))

    def prepare_inputs_for_inference(
        self,
        input_ids: torch.Tensor,
        encoder_outputs: torch.Tensor,
        past=None,
        past_key_values=None,
        use_cache: bool = None,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
        """
        Args:
            input_ids: (batch_size, sequence_length)

        Returns:
            input_ids: (batch_size, sequence_length)
            attention_mask: (batch_size, sequence_length)
            encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
        """
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
        past = past or past_key_values
        if past is not None:
            input_ids = input_ids[:, -1:]
        output = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "use_cache": use_cache,
            "encoder_hidden_states": encoder_outputs.last_hidden_state,
        }
        return output

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = None,
        output_attentions: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[torch.Tensor] = None,
        return_dict: bool = None,
    ):
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @staticmethod
    def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
        """
        Resize position embeddings
        Truncate if sequence length of MBart backbone is greater than given max_length,
        else interpolate to max_length
        """
        if weight.shape[0] > max_length:
            weight = weight[:max_length, ...]
        else:
            weight = (
                F.interpolate(
                    weight.permute(1, 0).unsqueeze(0),
                    size=max_length,
                    mode="linear",
                    align_corners=False,
                )
                .squeeze(0)
                .permute(1, 0)
            )
        return weight


class DonutConfig(PretrainedConfig):

    def __init__(
        self,
        decoder_layer: int = 10,
        max_position_embeddings: int = None,
        max_length: int = 4096,
        hidden_dimension: int = 1024,
        **kwargs,
    ):
        super().__init__()
        self.decoder_layer = decoder_layer
        self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings
        self.max_length = max_length
        self.hidden_dimension = hidden_dimension


class RunningVarTorch:
    def __init__(self, L=15, norm=False):
        self.values = None
        self.L = L
        self.norm = norm

    def push(self, x: torch.Tensor):
        assert x.dim() == 1
        if self.values is None:
            self.values = x[:, None]
        elif self.values.shape[1] < self.L:
            self.values = torch.cat((self.values, x[:, None]), 1)
        else:
            self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)

    def variance(self):
        if self.values is None:
            return
        if self.norm:
            return torch.var(self.values, 1) / self.values.shape[1]
        else:
            return torch.var(self.values, 1)


class StoppingCriteriaScores(StoppingCriteria):
    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        last_scores = scores[-1]
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    self.stop_inds[b] = int(min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095))
            else:
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0


def batch(l, b=15):
    subs = []
    for i in range(len(l) - b):
        subs.append(l[i : i + b])
    return subs


def subdiv(l, b=10):
    subs = []
    for i in range(len(l) - b):
        subs.append(l[: i + b])
    return subs


class DonutModel(PreTrainedModel):
    config_class = DonutConfig
    base_model_prefix = "donut"

    def __init__(self, config: DonutConfig, vision_tower=None, tokenizer=None):
        super().__init__(config)
        self.config = config

        self.tokenizer = tokenizer
        self.vpm = vision_tower

        # build language model
        self.llm = BARTDecoder(
            tokenizer=tokenizer,
            decoder_layer=self.config.decoder_layer,
            max_position_embeddings=self.config.max_position_embeddings,
            hidden_dimension=self.config.hidden_dimension,
        )
        self.ids_to_tokens = {id: content for content, id in self.llm.tokenizer.vocab.items()}

    def get_input_embeddings(self, tensor):
        return self.llm.model.get_input_embeddings()(tensor)

    def forward(
        self,
        inputs: dict,
    ):
        image_tensors = inputs["pixel_values"]
        input_ids = inputs["input_ids"].contiguous()
        attention_mask = inputs["attention_mask"]
        labels = inputs["labels"].contiguous()

        encoder_outputs = self.vpm(
            image_tensors,
            text_embedding=self.llm.model.get_input_embeddings()(input_ids),
        )

        decoder_outputs = self.llm(
            input_ids=input_ids,
            encoder_hidden_states=encoder_outputs,
            attention_mask=attention_mask,
            labels=labels,
        )
        return decoder_outputs

    def get_hidden_states_during_inference(
        self,
        prompt_ids: torch.Tensor,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
    ):
        if image_tensors is None:
            image_tensors = self.vpm.prepare_input(image).unsqueeze(0)

        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)

        image_tensors = image_tensors.to(self.device)
        prompt_ids = prompt_ids.to(self.device)
        all_hidden_states = self.vpm.forward_features(
            image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
        )
        return all_hidden_states

    def get_attn_weights_during_inference(
        self,
        prompt_ids: torch.Tensor,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
    ):
        if image_tensors is None:
            image_tensors = self.vpm.prepare_input(image).unsqueeze(0)

        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)

        image_tensors = image_tensors.to(self.device)
        prompt_ids = prompt_ids.to(self.device)
        last_attn_score = self.vpm.get_last_layer_cross_attn_score(
            image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
        )
        return last_attn_score

    def inference(
        self,
        prompt_ids: torch.Tensor,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
        return_attentions: bool = False,
        early_stopping: bool = True,
    ):
        """
        Generate a token sequence in an auto-regressive manner.

        Args:
            prompt_ids: (1, prompt_length) token ids of the task prompt
            image: input document image (PIL.Image)
            image_tensors: (1, num_channels, height, width);
                if not given, `image` is converted via `self.vpm.prepare_input`
        """
        output = {
            "predictions": list(),
            "sequences": list(),
            "repeats": list(),
            "repetitions": list(),
        }
        if image is None and image_tensors is None:
            logging.warning("Image not found")
            return output

        if image_tensors is None:
            image_tensors = self.vpm.prepare_input(image).unsqueeze(0)

        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)

        image_tensors = image_tensors.to(self.device)
        prompt_ids = prompt_ids.to(self.device)
        last_hidden_state = self.vpm(image_tensors, text_embedding=self.get_input_embeddings(prompt_ids))

        encoder_outputs = ModelOutput(last_hidden_state=last_hidden_state, attentions=None)
        if len(encoder_outputs.last_hidden_state.size()) == 1:
            encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.unsqueeze(0)

        # get decoder output
        decoder_output = self.llm.model.generate(
            input_ids=prompt_ids,
            encoder_outputs=encoder_outputs,
            min_length=1,
            max_length=self.config.max_length,
            pad_token_id=self.llm.tokenizer.pad_token_id,
            eos_token_id=self.llm.tokenizer.eos_token_id,
            use_cache=True,
            return_dict_in_generate=True,
            output_scores=True,
            output_attentions=return_attentions,
            do_sample=False,
            num_beams=1,
            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()] if early_stopping else []),
        )

        output["repetitions"] = decoder_output.sequences.clone()
        output["sequences"] = decoder_output.sequences.clone()
        output["scores"] = torch.stack(decoder_output.scores, 1).softmax(-1).cpu().max(-1)[0]

        output["repetitions"] = self.llm.tokenizer.batch_decode(output["repetitions"], skip_special_tokens=False)
        return output
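A small sketch (illustrative, not part of the commit) of the pieces of this file that can be exercised in isolation: `DonutConfig` falling back to `max_length` for `max_position_embeddings`, and `BARTDecoder.resize_bart_abs_pos_emb` truncating or linearly interpolating a position-embedding table to a target length.

# Illustrative sketch (not part of the commit): DonutConfig defaults and
# resize_bart_abs_pos_emb on a dummy embedding table.
import torch
from utils.model import BARTDecoder, DonutConfig

config = DonutConfig(decoder_layer=10, max_length=4096)
assert config.max_position_embeddings == 4096  # falls back to max_length when not set

pos_emb = torch.randn(1026, 1024)  # dummy (num_positions, hidden_dim) table
shorter = BARTDecoder.resize_bart_abs_pos_emb(pos_emb, max_length=512)   # truncated
longer = BARTDecoder.resize_bart_abs_pos_emb(pos_emb, max_length=4096)   # linearly interpolated
print(shorter.shape, longer.shape)  # torch.Size([512, 1024]) torch.Size([4096, 1024])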
64  utils/processor.py  Normal file
@@ -0,0 +1,64 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import numpy as np
import torch
from PIL import ImageOps

from utils.utils import *


class DolphinProcessor:
    def __init__(
        self,
        dp_config,
        tokenizer,
        **kwargs,
    ) -> None:
        self.tokenizer = tokenizer
        transform_args = kwargs.get("transform_args", {})
        self.max_length = transform_args.get("max_length", 2048)
        self.input_size = transform_args.get("input_size", [896, 896])  # height, width
        if isinstance(self.input_size, int):
            self.input_size = [self.input_size, self.input_size]

        try:
            self.answer_start_token = self.tokenizer._prompt_end_token
        except AttributeError:
            print('No answer_start_token found, using "" instead')
            self.answer_start_token = ""

        self.prefix_answer_space_flag = dp_config.get("prefix_answer_space_flag", True)
        self.suffix_prompt_space_flag = dp_config.get("suffix_prompt_space_flag", True)

    def process_prompt_for_inference(self, prompt):
        prompt = prompt.replace("<image>\n", "")
        if not prompt.startswith("<s>"):
            prompt = "<s>" + prompt
        message_ids = [self.tokenizer.encode(prompt, add_special_tokens=False)]
        ids = torch.from_numpy(np.hstack(message_ids, dtype=np.int32))
        return ids.unsqueeze(0)

    def process_image_for_inference(self, image, return_img_size=False):
        image = resize(image, min(self.input_size))

        image.thumbnail((self.input_size[1], self.input_size[0]))
        origin_w, origin_h = image.size

        delta_width = self.input_size[1] - image.width
        delta_height = self.input_size[0] - image.height
        pad_width = delta_width // 2
        pad_height = delta_height // 2
        padding = (
            pad_width,
            pad_height,
            delta_width - pad_width,
            delta_height - pad_height,
        )
        image = ImageOps.expand(image, padding)
        if return_img_size:
            return test_transform(image).unsqueeze(0), (origin_w, origin_h)
        return test_transform(image).unsqueeze(0)
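A minimal usage sketch (illustrative, not part of the commit): the processor turns a task prompt into `(1, prompt_length)` token ids and a PIL page image into a `(1, 3, H, W)` tensor padded to the configured input size, which is the shape `DonutModel.inference` expects. The `tokenizer` and prompt text below are placeholders; the repository's own loading code is not part of this commit.

# Illustrative sketch (not part of the commit). `tokenizer` stands in for the model's
# tokenizer, loaded elsewhere; the prompt string is only an example.
from PIL import Image
from utils.processor import DolphinProcessor

processor = DolphinProcessor(
    dp_config={"prefix_answer_space_flag": True},
    tokenizer=tokenizer,  # assumed to be provided by the surrounding loading code
    transform_args={"input_size": [896, 896], "max_length": 2048},
)

page = Image.open("page.png").convert("RGB")
prompt_ids = processor.process_prompt_for_inference("<image>\nParse the reading order of this document.")
pixel_values, (orig_w, orig_h) = processor.process_image_for_inference(page, return_img_size=True)
# prompt_ids: (1, prompt_length) int32 ids starting with "<s>";
# pixel_values: (1, 3, 896, 896) normalized tensor, padded to the square input size.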
367  utils/utils.py  Normal file
@@ -0,0 +1,367 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import copy
import json
import os
import re
from dataclasses import dataclass
from typing import List, Tuple

import albumentations as alb
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms.functional import resize

from utils.markdown_utils import MarkdownConverter


def alb_wrapper(transform):
    def f(im):
        return transform(image=np.asarray(im))["image"]

    return f


test_transform = alb_wrapper(
    alb.Compose(
        [
            alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ]
    )
)


def check_coord_valid(x1, y1, x2, y2, image_size=None, abs_coord=True):
    # print(f"check_coord_valid: {x1}, {y1}, {x2}, {y2}, {image_size}, {abs_coord}")
    if x2 <= x1 or y2 <= y1:
        return False, f"[{x1}, {y1}, {x2}, {y2}]"
    if x1 < 0 or y1 < 0:
        return False, f"[{x1}, {y1}, {x2}, {y2}]"
    if not abs_coord:
        if x2 > 1 or y2 > 1:
            return False, f"[{x1}, {y1}, {x2}, {y2}]"
    elif image_size is not None:  # has image size
        if x2 > image_size[0] or y2 > image_size[1]:
            return False, f"[{x1}, {y1}, {x2}, {y2}]"
    return True, None


def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2):
    """
    image: cv2 image (numpy array) or path to an image file
    boxes: list of boxes [[x1, y1, x2, y2]] in absolute coordinates
    """
    if isinstance(image, str):
        image = cv2.imread(image)
    img_h, img_w = image.shape[:2]
    new_boxes = []
    for box in boxes:
        best_box = copy.deepcopy(box)

        def check_edge(img, current_box, i, is_vertical):
            edge = current_box[i]
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

            if is_vertical:
                line = binary[current_box[1] : current_box[3] + 1, edge]
            else:
                line = binary[edge, current_box[0] : current_box[2] + 1]

            transitions = np.abs(np.diff(line))
            return np.sum(transitions) / len(transitions)

        # Only widen the box
        edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]

        current_box = copy.deepcopy(box)
        # make sure the box is within the image
        current_box[0] = min(max(current_box[0], 0), img_w - 1)
        current_box[1] = min(max(current_box[1], 0), img_h - 1)
        current_box[2] = min(max(current_box[2], 0), img_w - 1)
        current_box[3] = min(max(current_box[3], 0), img_h - 1)

        for i, direction, is_vertical in edges:
            best_score = check_edge(image, current_box, i, is_vertical)
            if best_score <= threshold:
                continue
            for step in range(max_pixels):
                current_box[i] += direction
                if i == 0 or i == 2:
                    current_box[i] = min(max(current_box[i], 0), img_w - 1)
                else:
                    current_box[i] = min(max(current_box[i], 0), img_h - 1)
                score = check_edge(image, current_box, i, is_vertical)

                if score < best_score:
                    best_score = score
                    best_box = copy.deepcopy(current_box)

                if score <= threshold:
                    break
        new_boxes.append(best_box)

    return new_boxes


def parse_layout_string(bbox_str):
    """Parse layout string using regular expressions"""
    pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
    matches = re.finditer(pattern, bbox_str)

    parsed_results = []
    for match in matches:
        coords = [float(match.group(i)) for i in range(1, 5)]
        label = match.group(5).strip()
        parsed_results.append((coords, label))

    return parsed_results
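# Illustrative example (not part of the original file): parse_layout_string on a model output such as
#   "[0.1, 0.2, 0.8, 0.3] para [0.1, 0.35, 0.9, 0.6] tab"
# returns [([0.1, 0.2, 0.8, 0.3], "para"), ([0.1, 0.35, 0.9, 0.6], "tab")].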


@dataclass
class ImageDimensions:
    """Class to store image dimensions"""
    original_w: int
    original_h: int
    padded_w: int
    padded_h: int


def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]:
    """Map coordinates from padded image back to original image

    Args:
        x1, y1, x2, y2: Coordinates in padded image
        dims: Image dimensions object

    Returns:
        tuple: (x1, y1, x2, y2) coordinates in original image
    """
    try:
        # Calculate padding offsets
        top = (dims.padded_h - dims.original_h) // 2
        left = (dims.padded_w - dims.original_w) // 2

        # Map back to original coordinates
        orig_x1 = max(0, x1 - left)
        orig_y1 = max(0, y1 - top)
        orig_x2 = min(dims.original_w, x2 - left)
        orig_y2 = min(dims.original_h, y2 - top)

        # Ensure we have a valid box (width and height > 0)
        if orig_x2 <= orig_x1:
            orig_x2 = min(orig_x1 + 1, dims.original_w)
        if orig_y2 <= orig_y1:
            orig_y2 = min(orig_y1 + 1, dims.original_h)

        return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
    except Exception as e:
        print(f"map_to_original_coordinates error: {str(e)}")
        # Return safe coordinates
        return 0, 0, min(100, dims.original_w), min(100, dims.original_h)


def map_to_relevant_coordinates(abs_coords, dims: ImageDimensions):
    """
    Convert absolute coordinates to relative coordinates in [0, 1],
    e.g. [100, 100, 200, 200] on a 1000x500 image -> (0.1, 0.2, 0.2, 0.4)
    """
    try:
        x1, y1, x2, y2 = abs_coords
        return round(x1 / dims.original_w, 3), round(y1 / dims.original_h, 3), round(x2 / dims.original_w, 3), round(y2 / dims.original_h, 3)
    except Exception as e:
        print(f"map_to_relevant_coordinates error: {str(e)}")
        return 0.0, 0.0, 1.0, 1.0  # Return full image coordinates


def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
    """Process and adjust coordinates

    Args:
        coords: Normalized coordinates [x1, y1, x2, y2]
        padded_image: Padded image
        dims: Image dimensions object
        previous_box: Previous box coordinates for overlap adjustment

    Returns:
        tuple: (x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box)
    """
    try:
        # Convert normalized coordinates to absolute coordinates
        x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
        x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)

        # Ensure coordinates are within image bounds before adjustment
        x1 = max(0, min(x1, dims.padded_w - 1))
        y1 = max(0, min(y1, dims.padded_h - 1))
        x2 = max(0, min(x2, dims.padded_w))
        y2 = max(0, min(y2, dims.padded_h))

        # Ensure width and height are at least 1 pixel
        if x2 <= x1:
            x2 = min(x1 + 1, dims.padded_w)
        if y2 <= y1:
            y2 = min(y1 + 1, dims.padded_h)

        # Extend box boundaries
        new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
        x1, y1, x2, y2 = new_boxes[0]

        # Ensure coordinates are still within image bounds after adjustment
        x1 = max(0, min(x1, dims.padded_w - 1))
        y1 = max(0, min(y1, dims.padded_h - 1))
        x2 = max(0, min(x2, dims.padded_w))
        y2 = max(0, min(y2, dims.padded_h))

        # Ensure width and height are at least 1 pixel after adjustment
        if x2 <= x1:
            x2 = min(x1 + 1, dims.padded_w)
        if y2 <= y1:
            y2 = min(y1 + 1, dims.padded_h)

        # Check for overlap with previous box and adjust
        if previous_box is not None:
            prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
            if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
                y1 = prev_y2
                # Ensure y1 is still valid
                y1 = min(y1, dims.padded_h - 1)
                # Make sure y2 is still greater than y1
                if y2 <= y1:
                    y2 = min(y1 + 1, dims.padded_h)

        # Update previous box
        new_previous_box = [x1, y1, x2, y2]

        # Map to original coordinates
        orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(
            x1, y1, x2, y2, dims
        )

        return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
    except Exception as e:
        print(f"process_coordinates error: {str(e)}")
        # Return safe values
        orig_x1, orig_y1, orig_x2, orig_y2 = 0, 0, min(100, dims.original_w), min(100, dims.original_h)
        return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]


def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]:
    """Load and prepare image with padding while maintaining aspect ratio

    Args:
        image: PIL image

    Returns:
        tuple: (padded_image, image_dimensions)
    """
    try:
        # Convert PIL image to OpenCV format
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        original_h, original_w = image.shape[:2]

        # Calculate padding to make square image
        max_size = max(original_h, original_w)
        top = (max_size - original_h) // 2
        bottom = max_size - original_h - top
        left = (max_size - original_w) // 2
        right = max_size - original_w - left

        # Apply padding
        padded_image = cv2.copyMakeBorder(image, top, bottom, left, right,
                                          cv2.BORDER_CONSTANT, value=(0, 0, 0))

        padded_h, padded_w = padded_image.shape[:2]

        dimensions = ImageDimensions(
            original_w=original_w,
            original_h=original_h,
            padded_w=padded_w,
            padded_h=padded_h
        )

        return padded_image, dimensions
    except Exception as e:
        print(f"prepare_image error: {str(e)}")
        # Create a minimal valid image and dimensions
        h, w = image.height, image.width
        dimensions = ImageDimensions(
            original_w=w,
            original_h=h,
            padded_w=w,
            padded_h=h
        )
        # Return a black image of the same size
        return np.zeros((h, w, 3), dtype=np.uint8), dimensions


def setup_output_dirs(save_dir):
    """Create necessary output directories"""
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(os.path.join(save_dir, "markdown"), exist_ok=True)
    os.makedirs(os.path.join(save_dir, "recognition_json"), exist_ok=True)


def save_outputs(recognition_results, image_path, save_dir):
    """Save JSON and markdown outputs"""
    basename = os.path.splitext(os.path.basename(image_path))[0]

    # Save JSON file
    json_path = os.path.join(save_dir, "recognition_json", f"{basename}.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(recognition_results, f, ensure_ascii=False, indent=2)

    # Generate and save markdown file
    markdown_converter = MarkdownConverter()
    markdown_content = markdown_converter.convert(recognition_results)
    markdown_path = os.path.join(save_dir, "markdown", f"{basename}.md")
    with open(markdown_path, "w", encoding="utf-8") as f:
        f.write(markdown_content)

    return json_path


def crop_margin(img: Image.Image) -> Image.Image:
    """Crop margins from image"""
    try:
        width, height = img.size
        if width == 0 or height == 0:
            print("Warning: Image has zero width or height")
            return img

        data = np.array(img.convert("L"))
        data = data.astype(np.uint8)
        max_val = data.max()
        min_val = data.min()
        if max_val == min_val:
            return img
        data = (data - min_val) / (max_val - min_val) * 255
        gray = 255 * (data < 200).astype(np.uint8)

        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
        if coords is None:
            return img
        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box

        # Ensure crop coordinates are within image bounds
        a = max(0, a)
        b = max(0, b)
        w = min(w, width - a)
        h = min(h, height - b)

        # Only crop if we have a valid region
        if w > 0 and h > 0:
            return img.crop((a, b, a + w, b + h))
        return img
    except Exception as e:
        print(f"crop_margin error: {str(e)}")
        return img  # Return original image on error
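A minimal end-to-end sketch (illustrative, not part of the commit) of how these helpers compose: pad the page to a square, parse the layout string predicted for that padded page, then map each normalized box to pixel coordinates on the padded image and back to the original page. The layout string literal below is a made-up example of the `parse_layout_string` format.

# Illustrative sketch (not part of the commit); the layout string is a made-up example.
from PIL import Image
from utils.utils import (parse_layout_string, prepare_image, process_coordinates,
                         save_outputs, setup_output_dirs)

page = Image.open("page.png").convert("RGB")
padded_image, dims = prepare_image(page)  # square, zero-padded BGR array + size bookkeeping

layout = "[0.12, 0.10, 0.88, 0.16] sec [0.12, 0.18, 0.88, 0.45] para"
previous_box = None
elements = []
for reading_order, (coords, label) in enumerate(parse_layout_string(layout)):
    x1, y1, x2, y2, ox1, oy1, ox2, oy2, previous_box = process_coordinates(
        coords, padded_image, dims, previous_box
    )
    # (ox1, oy1, ox2, oy2) are pixel coordinates on the original page; crop and recognize here.
    elements.append({"label": label, "bbox": [ox1, oy1, ox2, oy2], "reading_order": reading_order})

setup_output_dirs("./results")
json_path = save_outputs(elements, "page.png", "./results")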