"""
|
|
Copyright (c) 2022-present NAVER Corp.
|
|
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
|
|
MIT License
|
|
This file has been modified by [ByteDance Ltd. and/or its affiliates] on 20250118.
|
|
The original file available at https://github.com/clovaai/donut/blob/master/donut/model.py was released under the MIT license.
|
|
This modified file is released under the same license.
|
|
"""
|
|
|
|

import logging
from collections import defaultdict
from typing import List, Optional

import torch
import torch.nn.functional as F
from PIL import Image
from timm.models.swin_transformer import SwinTransformer
from torch import nn
from transformers import (
    MBartConfig,
    MBartForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.file_utils import ModelOutput
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel

class SwinEncoder(nn.Module):
    r"""
    Encoder based on SwinTransformer.

    Sets the initial weights and configuration from a pretrained SwinTransformer,
    then modifies the detailed configuration.

    Args:
        input_size: Input image size (width, height)
        align_long_axis: Whether to rotate the image if its height is greater than its width
        window_size: Window size of the SwinTransformer
        encoder_layer: Number of layers in each stage of the SwinTransformer encoder
        patch_size: Patch size of the SwinTransformer
        embed_dim: Embedding dimension of the first stage
        num_heads: Number of attention heads in each stage
    """

    def __init__(
        self,
        input_size,
        align_long_axis: bool = False,
        window_size: int = 7,
        encoder_layer: List[int] = [2, 2, 14, 2],
        patch_size: List[int] = [4, 4],
        embed_dim: int = 128,
        num_heads: List[int] = [4, 8, 16, 32],
    ):
        super().__init__()
        if isinstance(input_size, int):
            input_size = [input_size, input_size]
        self.input_size = input_size
        self.align_long_axis = align_long_axis
        self.window_size = window_size
        self.encoder_layer = encoder_layer
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.model = SwinTransformer(
            img_size=self.input_size,
            depths=self.encoder_layer,
            window_size=self.window_size,
            patch_size=self.patch_size,
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            num_classes=0,
        )

    def forward(self, x: torch.Tensor, text_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch_size, num_channels, height, width)
        """
        # text_embedding is accepted for interface compatibility but unused here.
        x = self.model.patch_embed(x)
        x = self.model.pos_drop(x)
        x = self.model.layers(x)
        return x
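
# Illustrative usage sketch (not part of the original file); the input size here
# is a hypothetical example, the real one comes from the model configuration:
#
#     encoder = SwinEncoder(input_size=[896, 672])
#     pixel_values = torch.randn(1, 3, 896, 672)  # (batch, channels, height, width)
#     features = encoder(pixel_values)  # feature sequence fed to the decoder's cross-attention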


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16 inputs by computing in a configured dtype."""

    def _set_dtype(self, dtype):
        # Must be called before forward() so the norm knows which dtype to
        # compute in (e.g. torch.float32 for stability under fp16 inputs).
        self._dtype = dtype

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(dtype=self._dtype))
        return ret.type(orig_type)
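
# Minimal sketch of the dtype round-trip (illustrative, not in the original):
#
#     ln = LayerNorm(1024)
#     ln._set_dtype(torch.float32)             # compute in fp32 ...
#     y = ln(torch.randn(2, 8, 1024).half())   # ... even for fp16 inputs
#     assert y.dtype == torch.float16          # output dtype matches the input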


class BARTDecoder(nn.Module):
    """
    Decoder based on multilingual BART.

    Sets the initial weights and configuration from a pretrained multilingual BART
    model and modifies the detailed configuration to act as a Donut decoder.

    Args:
        tokenizer: Tokenizer whose vocabulary defines the decoder's embedding table
        decoder_layer: Number of decoder layers
        max_position_embeddings: The maximum sequence length to be trained
        hidden_dimension: Hidden size of the decoder (``d_model``)
    """

    def __init__(
        self,
        tokenizer,
        decoder_layer: int,
        max_position_embeddings: int,
        hidden_dimension: int = 1024,
        **kwargs,
    ):
        super().__init__()
        self.decoder_layer = decoder_layer
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dimension = hidden_dimension

        self.tokenizer = tokenizer

        self.model = MBartForCausalLM(
            config=MBartConfig(
                tie_word_embeddings=True,
                is_decoder=True,
                is_encoder_decoder=False,
                add_cross_attention=True,
                decoder_layers=self.decoder_layer,
                max_position_embeddings=self.max_position_embeddings,
                vocab_size=len(self.tokenizer),
                scale_embedding=True,
                add_final_layer_norm=True,
                d_model=self.hidden_dimension,
            )
        )
        # self.model.config.is_encoder_decoder = True  # to get cross-attention
        self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
        self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference

    def add_special_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to the tokenizer and resize the token embeddings.
        """
        newly_added_num = self.tokenizer.add_special_tokens(
            {"additional_special_tokens": sorted(set(list_of_tokens))}
        )
        if newly_added_num > 0:
            self.model.resize_token_embeddings(len(self.tokenizer))

    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add regular tokens to the tokenizer and resize the token embeddings.
        """
        newly_added_num = self.tokenizer.add_tokens(sorted(set(list_of_tokens)))
        if newly_added_num > 0:
            self.model.resize_token_embeddings(len(self.tokenizer))
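
    # Illustrative sketch (not in the original); the token strings are
    # hypothetical examples:
    #
    #     decoder.add_special_tokens(["<s_docvqa>", "</s_answer>"])
    #     decoder.add_tokens(["<tab>"])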

    def prepare_inputs_for_inference(
        self,
        input_ids: torch.Tensor,
        encoder_outputs: torch.Tensor,
        past=None,
        past_key_values=None,
        use_cache: bool = None,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
        """
        Args:
            input_ids: (batch_size, sequence_length)

        Returns:
            input_ids: (batch_size, sequence_length)
            attention_mask: (batch_size, sequence_length)
            encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
        """
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
        past = past or past_key_values
        if past is not None:
            # With cached key/values only the last generated token has to be fed.
            input_ids = input_ids[:, -1:]
        output = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "use_cache": use_cache,
            "encoder_hidden_states": encoder_outputs.last_hidden_state,
        }
        return output
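
    # Illustrative sketch (not in the original): transformers' generate() calls
    # this hook each step, since it is patched in as prepare_inputs_for_generation
    # in __init__ above.
    #
    #     inputs = decoder.prepare_inputs_for_inference(
    #         input_ids=input_ids, encoder_outputs=encoder_outputs
    #     )
    #     logits = decoder(**inputs).logits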

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = None,
        output_attentions: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[torch.Tensor] = None,
        return_dict: bool = None,
    ):
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @staticmethod
    def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
        """
        Resize absolute position embeddings.

        Truncate if the sequence length of the MBart backbone is greater than the
        given max_length, otherwise interpolate to max_length.
        """
        if weight.shape[0] > max_length:
            weight = weight[:max_length, ...]
        else:
            weight = (
                F.interpolate(
                    weight.permute(1, 0).unsqueeze(0),
                    size=max_length,
                    mode="linear",
                    align_corners=False,
                )
                .squeeze(0)
                .permute(1, 0)
            )
        return weight
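
    # Shape sketch (illustrative, not in the original): growing a hypothetical
    # (1026, 1024) MBart position table to max_length=4096 by interpolation.
    #
    #     emb = torch.randn(1026, 1024)
    #     resized = BARTDecoder.resize_bart_abs_pos_emb(emb, 4096)
    #     assert resized.shape == (4096, 1024)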


class DonutConfig(PretrainedConfig):

    def __init__(
        self,
        decoder_layer: int = 10,
        max_position_embeddings: int = None,
        max_length: int = 4096,
        hidden_dimension: int = 1024,
        **kwargs,
    ):
        # Forward remaining kwargs so standard PretrainedConfig fields survive
        # save/load round-trips.
        super().__init__(**kwargs)
        self.decoder_layer = decoder_layer
        self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings
        self.max_length = max_length
        self.hidden_dimension = hidden_dimension
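
# Illustrative sketch (not in the original): a config matching the defaults
# used throughout this file.
#
#     config = DonutConfig(decoder_layer=10, max_length=4096, hidden_dimension=1024)
#     assert config.max_position_embeddings == 4096  # falls back to max_length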


class RunningVarTorch:
    """Running variance over a sliding window of the last ``L`` pushed vectors."""

    def __init__(self, L=15, norm=False):
        self.values = None
        self.L = L
        self.norm = norm

    def push(self, x: torch.Tensor):
        assert x.dim() == 1
        if self.values is None:
            self.values = x[:, None]
        elif self.values.shape[1] < self.L:
            self.values = torch.cat((self.values, x[:, None]), 1)
        else:
            # Window is full: drop the oldest column and append the new one.
            self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)

    def variance(self):
        if self.values is None:
            return
        if self.norm:
            return torch.var(self.values, 1) / self.values.shape[1]
        else:
            return torch.var(self.values, 1)
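
# Illustrative sketch (not in the original): per-batch running variance over a
# window of 3 pushed score vectors.
#
#     rv = RunningVarTorch(L=3)
#     for v in (torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([4.0])):
#         rv.push(v)
#     rv.variance()  # tensor([2.3333]), the unbiased variance of [1, 2, 4]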


class StoppingCriteriaScores(StoppingCriteria):
    """Heuristic stopping criterion: stop once the running variance of the top
    token scores stays below a threshold over a window, which is intended to
    catch degenerate repetition loops during generation."""

    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        last_scores = scores[-1]
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    # Schedule a stopping index a bit past the current step.
                    self.stop_inds[b] = int(min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095))
            else:
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0
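
# Illustrative sketch (not in the original): this criterion plugs into
# transformers' generate(), exactly as DonutModel.inference() does below.
#
#     stopping = StoppingCriteriaList([StoppingCriteriaScores(threshold=0.015)])
#     model.generate(..., output_scores=True, stopping_criteria=stopping)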


def batch(l, b=15):
    # Sliding windows of length b: l[0:b], l[1:b+1], ...
    subs = []
    for i in range(len(l) - b):
        subs.append(l[i : i + b])
    return subs


def subdiv(l, b=10):
    # Growing prefixes of l, starting at length b: l[:b], l[:b+1], ...
    subs = []
    for i in range(len(l) - b):
        subs.append(l[: i + b])
    return subs
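
# Illustrative sketch (not in the original):
#
#     batch([1, 2, 3, 4, 5], b=3)   # [[1, 2, 3], [2, 3, 4]]
#     subdiv([1, 2, 3, 4, 5], b=3)  # [[1, 2, 3], [1, 2, 3, 4]]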


class DonutModel(PreTrainedModel):
    config_class = DonutConfig
    base_model_prefix = "donut"

    def __init__(self, config: DonutConfig, vision_tower=None, tokenizer=None):
        super().__init__(config)
        self.config = config

        self.tokenizer = tokenizer
        self.vpm = vision_tower

        # build language model
        self.llm = BARTDecoder(
            tokenizer=tokenizer,
            decoder_layer=self.config.decoder_layer,
            max_position_embeddings=self.config.max_position_embeddings,
            hidden_dimension=self.config.hidden_dimension,
        )
        self.ids_to_tokens = {token_id: content for content, token_id in self.llm.tokenizer.vocab.items()}

    def get_input_embeddings(self, tensor):
        return self.llm.model.get_input_embeddings()(tensor)

    def forward(
        self,
        inputs: dict,
    ):
        image_tensors = inputs["pixel_values"]
        input_ids = inputs["input_ids"].contiguous()
        attention_mask = inputs["attention_mask"]
        labels = inputs["labels"].contiguous()

        encoder_outputs = self.vpm(
            image_tensors,
            text_embedding=self.llm.model.get_input_embeddings()(input_ids),
        )

        decoder_outputs = self.llm(
            input_ids=input_ids,
            encoder_hidden_states=encoder_outputs,
            attention_mask=attention_mask,
            labels=labels,
        )
        return decoder_outputs
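
    # Illustrative training-step sketch (not in the original); tensor shapes
    # are hypothetical examples.
    #
    #     inputs = {
    #         "pixel_values": pixel_values,      # (B, 3, H, W)
    #         "input_ids": input_ids,            # (B, T)
    #         "attention_mask": attention_mask,  # (B, T)
    #         "labels": labels,                  # (B, T), -100 on ignored positions
    #     }
    #     loss = model(inputs).loss
    #     loss.backward()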

    def get_hidden_states_during_inference(
        self,
        prompt_ids: torch.Tensor,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
    ):
        if image_tensors is None:
            image_tensors = self.vpm.prepare_input(image).unsqueeze(0)

        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)

        image_tensors = image_tensors.to(self.device)
        prompt_ids = prompt_ids.to(self.device)
        all_hidden_states = self.vpm.forward_features(
            image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
        )
        return all_hidden_states

    def get_attn_weights_during_inference(
        self,
        prompt_ids: torch.Tensor,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
    ):
        if image_tensors is None:
            image_tensors = self.vpm.prepare_input(image).unsqueeze(0)

        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)

        image_tensors = image_tensors.to(self.device)
        prompt_ids = prompt_ids.to(self.device)
        last_attn_score = self.vpm.get_last_layer_cross_attn_score(
            image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
        )
        return last_attn_score

    def inference(
        self,
        prompt_ids: torch.Tensor,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
        return_attentions: bool = False,
        early_stopping: bool = True,
    ):
        """
        Generate a token sequence in an auto-regressive manner.

        Args:
            image: input document image (PIL.Image)
            image_tensors: (1, num_channels, height, width)
                prepared from `image` when not fed directly
        """
        output = {
            "predictions": list(),
            "sequences": list(),
            "repeats": list(),
            "repetitions": list(),
        }
        if image is None and image_tensors is None:
            logging.warning("Image not found")
            return output

        if image_tensors is None:
            image_tensors = self.vpm.prepare_input(image).unsqueeze(0)

        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)

        image_tensors = image_tensors.to(self.device)
        prompt_ids = prompt_ids.to(self.device)
        last_hidden_state = self.vpm(image_tensors, text_embedding=self.get_input_embeddings(prompt_ids))

        encoder_outputs = ModelOutput(last_hidden_state=last_hidden_state, attentions=None)
        if len(encoder_outputs.last_hidden_state.size()) == 1:
            encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.unsqueeze(0)

        # get decoder output
        decoder_output = self.llm.model.generate(
            input_ids=prompt_ids,
            encoder_outputs=encoder_outputs,
            min_length=1,
            max_length=self.config.max_length,
            pad_token_id=self.llm.tokenizer.pad_token_id,
            eos_token_id=self.llm.tokenizer.eos_token_id,
            use_cache=True,
            return_dict_in_generate=True,
            output_scores=True,
            output_attentions=return_attentions,
            do_sample=False,
            num_beams=1,
            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()] if early_stopping else []),
        )

        output["repetitions"] = decoder_output.sequences.clone()
        output["sequences"] = decoder_output.sequences.clone()
        output["scores"] = torch.stack(decoder_output.scores, 1).softmax(-1).cpu().max(-1)[0]

        output["repetitions"] = self.llm.tokenizer.batch_decode(output["repetitions"], skip_special_tokens=False)
        return output
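
# Illustrative end-to-end sketch (not in the original). `vision_tower`,
# `tokenizer`, the prompt string, and the image path are hypothetical; the
# code that builds them lives outside this file.
#
#     model = DonutModel(DonutConfig(), vision_tower=vision_tower, tokenizer=tokenizer)
#     prompt_ids = tokenizer("<s>", return_tensors="pt").input_ids
#     result = model.inference(prompt_ids=prompt_ids, image=Image.open("doc.png"))
#     print(result["repetitions"][0])  # decoded sequence, special tokens included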