"""
|
|
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
SPDX-License-Identifier: MIT
|
|
"""

import json
import os
from typing import Optional

import tensorrt_llm
import tensorrt_llm.profiler as profiler
import torch
from PIL import Image
from pydantic import BaseModel, Field
from tensorrt_llm import logger, mpi_rank
from tensorrt_llm.runtime import MultimodalModelRunner
from transformers import AutoTokenizer, DonutProcessor


class InferenceConfig(BaseModel):
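    """Validated settings for TensorRT-LLM Dolphin inference.

    Defaults mirror the standard TensorRT-LLM multimodal runner arguments;
    the engine directories and HF model directory must point at prebuilt
    artifacts.
    """
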
    max_new_tokens: int = Field(128, description="Maximum new tokens to generate")
    batch_size: int = Field(1, description="Batch size for inference")
    log_level: str = Field("info", description="Logging level")
    visual_engine_dir: Optional[str] = Field(None, description="Directory for visual engine files")
    visual_engine_name: str = Field("model.engine", description="Visual engine filename")
    llm_engine_dir: Optional[str] = Field(None, description="Directory for LLM engine files")
    hf_model_dir: Optional[str] = Field(None, description="Hugging Face model directory")
    input_text: Optional[str] = Field(None, description="Input text for inference")
    num_beams: int = Field(1, description="Number of beams for beam search")
    top_k: int = Field(1, description="Top-k sampling value")
    top_p: float = Field(0.0, description="Top-p (nucleus) sampling value")
    temperature: float = Field(1.0, description="Sampling temperature")
    repetition_penalty: float = Field(1.0, description="Repetition penalty factor")
    run_profiling: bool = Field(False, description="Enable profiling mode")
    profiling_iterations: int = Field(20, description="Number of profiling iterations")
    check_accuracy: bool = Field(False, description="Enable accuracy checking")
    video_path: Optional[str] = Field(None, description="Path to input video file")
    video_num_frames: Optional[int] = Field(None, description="Number of video frames to process")
    image_path: Optional[str] = Field(None, description="Path to input image file")
    path_sep: str = Field(",", description="Separator between multiple input paths")
    prompt_sep: str = Field(",", description="Separator between multiple prompts")
    enable_context_fmha_fp32_acc: Optional[bool] = Field(
        None, description="Enable FP32 accumulation for context FMHA"
    )
    enable_chunked_context: bool = Field(False, description="Enable chunked context processing")
    use_py_session: bool = Field(False, description="Use the Python session instead of the C++ session")
    kv_cache_free_gpu_memory_fraction: float = Field(
        0.9, description="Fraction of free GPU memory to reserve for the KV cache",
        ge=0.0, le=1.0
    )
    cross_kv_cache_fraction: float = Field(
        0.5, description="Fraction of the KV cache dedicated to cross-attention",
        ge=0.0, le=1.0
    )
    multi_block_mode: bool = Field(True, description="Enable multi-block processing mode")


class DolphinRunner(MultimodalModelRunner):
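    """MultimodalModelRunner specialized for the Dolphin document-parsing model.

    Dolphin ships Donut/Nougat-style encoder-decoder weights, so the tokenizer
    and processor initializers assert that the engine's model_type is 'nougat'.
    """
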
    def __init__(self, args):
        self.args = args

        self.runtime_rank = mpi_rank()
        device_id = self.runtime_rank % torch.cuda.device_count()
        torch.cuda.set_device(device_id)
        self.device = f"cuda:{device_id}"

        self.stream = torch.cuda.Stream(torch.cuda.current_device())
        torch.cuda.set_stream(self.stream)

        # Parse the model type from the visual engine config.
        with open(os.path.join(self.args.visual_engine_dir, "config.json"), "r") as f:
            config = json.load(f)
        self.model_type = config['builder_config']['model_type']
        self.vision_precision = config['builder_config']['precision']
        # BLIP2-T5, pix2struct, and Nougat use encoder-decoder models as LLMs.
        self.decoder_llm = not (
            't5' in self.model_type
            or self.model_type in ['nougat', 'pix2struct']
        )

        if self.model_type == "mllama":
            self.vision_input_names = [
                "pixel_values",
                "aspect_ratio_ids",
                "aspect_ratio_mask",
            ]
            self.vision_output_names = ["output"]
        else:
            self.vision_input_names = ["input"]
            self.vision_output_names = ["output"]

        self.use_py_session = True

        self.init_image_encoder()
        self.init_tokenizer()
        self.init_processor()
        self.init_llm()

    def init_tokenizer(self):
        assert self.model_type == 'nougat'
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.hf_model_dir)
        self.tokenizer.padding_side = "right"

    def init_processor(self):
        assert self.model_type == 'nougat'
        self.processor = DonutProcessor.from_pretrained(self.args.hf_model_dir, use_fast=True)

    def run(self, input_texts, input_images, max_new_tokens):
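        """Run Dolphin inference on batched prompts and PIL images.

        Returns, for each input, the list of decoded beams.
        """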
        # Dolphin prompts follow the "<s>{prompt} <Answer/>" template.
        prompts = [f"<s>{text.strip()} <Answer/>" for text in input_texts]
        images = self.processor(input_images, return_tensors="pt")['pixel_values'].to("cuda")
        prompt_ids = self.tokenizer(prompts, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda")

        # 🚨🚨🚨 Important! If prompt_ids is not int32, the output will be wrong. 🚨🚨🚨
        prompt_ids = prompt_ids.to(torch.int32)

        logger.info("---------------------------------------------------------")
        logger.info(f"images size: {images.size()}")
        logger.info(f"prompt_ids: {prompt_ids}, size: {prompt_ids.size()}, dtype: {prompt_ids.dtype}")
        logger.info("---------------------------------------------------------")

        output_texts = self.generate(input_texts,
                                     [None] * len(input_texts),
                                     images,
                                     prompt_ids,
                                     max_new_tokens,
                                     warmup=False)

        return output_texts

    def generate(self,
                 pre_prompt,
                 post_prompt,
                 image,
                 decoder_input_ids,
                 max_new_tokens,
                 warmup=False,
                 other_vision_inputs={},
                 other_decoder_inputs={}):
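        """Generate text via the encoder-decoder LLM using prompt-tuning features.

        When warmup is True, only preprocessing runs and None is returned.
        """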
        if not warmup:
            profiler.start("Generate")
        input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
            warmup, pre_prompt, post_prompt, image, other_vision_inputs)

        if warmup:
            return None

        # Use prompt tuning to pass the multimodal features to the LLM.
        # model.generate() expects the following params (see layers/embedding.py):
        # args[0]: prompt embedding table, [batch_size, multimodal_len, hidden_size],
        #          later flattened to [batch_size * multimodal_len, hidden_size]
        # args[1]: prompt task ids, [batch_size]. In the multimodal case this is
        #          arange(batch_size), i.e. in VILA batching mode 2 each image is
        #          treated separately in the batch rather than concatenated
        #          (although the prompt embedding table has to be concatenated)
        # args[2]: prompt task vocab size, [1]. Assumes all tables have the same
        #          length, which in the multimodal case equals multimodal_len
        profiler.start("LLM")
        if self.model_type in ['nougat', 'pix2struct']:
            # Trim encoder input_ids to match the visual features shape.
            ids_shape = (min(self.args.batch_size, len(pre_prompt)), visual_features.shape[1])
            if self.model_type == 'nougat':
                input_ids = torch.zeros(ids_shape, dtype=torch.int32)
            elif self.model_type == 'pix2struct':
                input_ids = torch.ones(ids_shape, dtype=torch.int32)

            output_ids = self.model.generate(
                input_ids,
                decoder_input_ids,
                max_new_tokens,
                num_beams=self.args.num_beams,
                bos_token_id=self.tokenizer.bos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                debug_mode=False,
                prompt_embedding_table=ptuning_args[0],
                prompt_tasks=ptuning_args[1],
                prompt_vocab_size=ptuning_args[2],
            )
            profiler.stop("LLM")

            if mpi_rank() == 0:
                # Extract a list of tensors of shape beam_width x output_ids.
                output_beams_list = [
                    self.tokenizer.batch_decode(
                        output_ids[batch_idx, :, decoder_input_ids.shape[1]:],
                        skip_special_tokens=False)
                    for batch_idx in range(min(self.args.batch_size, decoder_input_ids.shape[0]))
                ]

                stripped_text = [[
                    output_beams_list[batch_idx][beam_idx].replace("</s>", "").replace("<pad>", "").strip()
                    for beam_idx in range(self.args.num_beams)
                ] for batch_idx in range(min(self.args.batch_size, decoder_input_ids.shape[0]))]
                profiler.stop("Generate")
                return stripped_text
            else:
                profiler.stop("Generate")
                return None


if __name__ == "__main__":
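    # Example invocation. The paths below assume the Dolphin HF checkpoint and
    # the TensorRT engines have already been exported under ./tmp; adjust as needed.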
    config = InferenceConfig(
        max_new_tokens=4024,
        batch_size=16,
        log_level="info",
        hf_model_dir="./tmp/hf_models/Dolphin",
        visual_engine_dir="./tmp/trt_engines/Dolphin/vision_encoder",
        llm_engine_dir="./tmp/trt_engines/Dolphin/1-gpu/bfloat16",
    )

    model = DolphinRunner(config)

    image_path = "../../demo/page_imgs/page_1.jpeg"
    prompt = "Parse the reading order of this document."
    image = Image.open(image_path).convert("RGB")
    output_texts = model.run([prompt], [image], 4024)
    # Keep only the top beam for each input.
    output_texts = [texts[0] for texts in output_texts]
    print(output_texts)