# Dolphin/utils/processor.py

"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""
import numpy as np
import torch
from PIL import ImageOps
from torchvision import transforms
from torchvision.transforms.functional import resize

# Standard ImageNet per-channel normalization statistics.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


class DolphinProcessor:
    """Prepares text prompts and images for Dolphin model inference."""

    def __init__(
        self,
        dp_config,
        tokenizer,
        **kwargs,
    ) -> None:
        self.tokenizer = tokenizer
        transform_args = kwargs.get("transform_args", {})
        self.max_length = transform_args.get("max_length", 2048)
        self.input_size = transform_args.get("input_size", [896, 896])  # (height, width)
        if isinstance(self.input_size, int):
            self.input_size = [self.input_size, self.input_size]

        # Token that marks where the answer begins; fall back to "" for
        # tokenizers that do not define one.
        try:
            self.answer_start_token = self.tokenizer._prompt_end_token
        except AttributeError:
            print('No answer_start_token found, using "" instead')
            self.answer_start_token = ""

        self.prefix_answer_space_flag = dp_config.get("prefix_answer_space_flag", True)
        self.suffix_prompt_space_flag = dp_config.get("suffix_prompt_space_flag", True)

        # Convert PIL images to tensors and normalize with the ImageNet statistics.
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)]
        )
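
    # Example (illustrative): the configuration read above is typically supplied as
    #   DolphinProcessor(
    #       {"prefix_answer_space_flag": True, "suffix_prompt_space_flag": True},
    #       tokenizer,
    #       transform_args={"max_length": 2048, "input_size": [896, 896]},
    #   )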

    def process_prompt_for_inference(self, prompt):
        """Tokenize a text prompt into a (1, seq_len) tensor of input ids."""
        prompt = prompt.replace("<image>\n", "")
        if not prompt.startswith("<s>"):
            prompt = "<s>" + prompt
        message_ids = [self.tokenizer.encode(prompt, add_special_tokens=False)]
        # Flatten the id lists into one int32 array (the dtype keyword of
        # np.hstack requires NumPy >= 1.24).
        ids = torch.from_numpy(np.hstack(message_ids, dtype=np.int32))
        return ids.unsqueeze(0)
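
    # Example (sketch, assuming a tokenizer where "<s>" maps to id 0):
    #
    #     >>> proc.process_prompt_for_inference("<image>\nParse the table.")
    #     tensor([[0, ...]], dtype=torch.int32)  # "<s>" prepended, shape (1, seq_len)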

    def process_image_for_inference(self, image, return_img_size=False):
        """Resize and pad a PIL image to `input_size`, then normalize to a tensor."""
        # Match the shorter edge to the target size, then shrink in place so the
        # image fits entirely within (width, height).
        image = resize(image, min(self.input_size))
        image.thumbnail((self.input_size[1], self.input_size[0]))
        origin_w, origin_h = image.size  # content size before padding

        # Pad symmetrically to the exact target canvas; ImageOps.expand takes
        # (left, top, right, bottom) borders.
        delta_width = self.input_size[1] - image.width
        delta_height = self.input_size[0] - image.height
        pad_width = delta_width // 2
        pad_height = delta_height // 2
        padding = (
            pad_width,
            pad_height,
            delta_width - pad_width,
            delta_height - pad_height,
        )
        image = ImageOps.expand(image, padding)

        if return_img_size:
            return self.transform(image).unsqueeze(0), (origin_w, origin_h)
        return self.transform(image).unsqueeze(0)
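

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal end-to-end example.
# The "ByteDance/Dolphin" checkpoint name, the prompt text, and "page.png" are
# illustrative assumptions, not values taken from this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("ByteDance/Dolphin")  # assumed checkpoint
    processor = DolphinProcessor({"prefix_answer_space_flag": True}, tokenizer)

    prompt_ids = processor.process_prompt_for_inference("Read text in the image.")  # assumed prompt
    pixel_values, (w, h) = processor.process_image_for_inference(
        Image.open("page.png").convert("RGB"), return_img_size=True  # hypothetical image file
    )
    print(prompt_ids.shape, pixel_values.shape, (w, h))  # e.g. (1, N), (1, 3, 896, 896), (w, h)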