add dolphin inference by tensorrt-llm

parent: ce591d9136
commit: c247e5e1f3
deployment/ReadMe.md (new file, 12 lines)

<h1 align="center">
🚀 Dolphin Inference/Serving
</h1>

## vLLM

> [Doc](./vllm/README.md)

## TensorRT-LLM

> [Doc](./tensorrt_llm/README.md)

## Others
deployment/tensorrt_llm/ReadMe.md (new file, 87 lines)

<h1 align="center">
🚀 Dolphin TensorRT-LLM Demo
</h1>

## ✅ Introduction
The Dolphin model employs a **Swin Encoder + MBart Decoder** architecture. In the HuggingFace Transformers [Config](https://huggingface.co/ByteDance/Dolphin/blob/main/config.json), its `architectures` field is specified as `"VisionEncoderDecoderModel"`. Dolphin, Nougat, and Donut share the same model architecture, and TensorRT-LLM already supports the Nougat model. Following Nougat's conversion script, we have successfully implemented Dolphin on TensorRT-LLM. **Note:** `input_ids` must be of `int32` type; otherwise TensorRT-LLM will produce incorrect results.
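As a minimal sketch of the int32 requirement (mirroring the cast in `dolphin_runner.py`; the prompt template `<s>... <Answer/>` is the one used there):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ByteDance/Dolphin")
prompts = ["<s>Read text in the image. <Answer/>"]

# HuggingFace tokenizers return int64 token ids by default.
prompt_ids = tokenizer(prompts, add_special_tokens=False, return_tensors="pt").input_ids

# Cast to int32 before handing the ids to TensorRT-LLM; int64 ids produce wrong outputs.
prompt_ids = prompt_ids.to(torch.int32)
print(prompt_ids.dtype)  # torch.int32
```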
## 🛠️ Installation
> We have only tested TensorRT-LLM 0.18.1 on Linux.

Installation guide: https://nvidia.github.io/TensorRT-LLM/0.18.1/installation/linux.html
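For reference, a typical install from the linked guide looks like the sketch below; consult the guide itself for the authoritative steps and the exact Python/CUDA prerequisites:

```bash
# MPI headers are a prerequisite for the TensorRT-LLM wheel.
sudo apt-get -y install libopenmpi-dev

# Install the tested version; NVIDIA's package index hosts the CUDA wheels.
pip3 install tensorrt_llm==0.18.1 --extra-index-url https://pypi.nvidia.com
```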
## ⚡ Offline Inference
```
export MODEL_NAME="Dolphin"

# predict elements reading order
python run_dolphin.py \
    --batch_size 1 \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_new_tokens 4096 \
    --repetition_penalty 1.0 \
    --input_text "Parse the reading order of this document." \
    --image_path "../../demo/page_imgs/page_1.jpeg"

# recognize text/latex
python run_dolphin.py \
    --batch_size 1 \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_new_tokens 4096 \
    --repetition_penalty 1.0 \
    --input_text "Read text in the image." \
    --image_path "../../demo/element_imgs/block_formula.jpeg"

python run_dolphin.py \
    --batch_size 1 \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_new_tokens 4096 \
    --repetition_penalty 1.0 \
    --input_text "Read text in the image." \
    --image_path "../../demo/element_imgs/para_1.jpg"

# recognize table
python run_dolphin.py \
    --batch_size 1 \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_new_tokens 4096 \
    --repetition_penalty 1.0 \
    --input_text "Parse the table in the image." \
    --image_path "../../demo/element_imgs/table_1.jpeg"
```
## ⚡ Online Inference
```
# 1. Start API server
export MODEL_NAME="Dolphin"

python api_server.py \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_batch_size 16

# 2. Predict
# predict elements reading order
python deployment/tensorrt_llm/api_client.py --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document."

# recognize text/latex
python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image."
python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image."

# recognize table
python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image."
```
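Under the hood, `api_client.py` simply POSTs JSON to the server's `/generate` route, so any HTTP client works. An equivalent raw call (a sketch; assumes GNU `base64` for the `-w0` flag and the server on its default port 8000):

```bash
curl -X POST http://localhost:8000/generate \
    -H "Content-Type: application/json" \
    -d "{\"prompt\": \"Read text in the image.\", \"image_base64\": \"$(base64 -w0 ./demo/element_imgs/para_1.jpg)\"}"
# The server replies with a JSON object of the form {"text": "..."}.
```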
deployment/tensorrt_llm/api_client.py (new file, 100 lines)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for the Dolphin TensorRT-LLM API server.

Start the demo server first:
    python api_server.py --hf_model_dir <dir> --visual_engine_dir <dir> \
        --llm_engine_dir <dir> --max_batch_size 16

NOTE: The API server is used only for demonstration and simple performance
benchmarks. It is not intended for production use.
"""

import argparse
import base64
import json
from argparse import Namespace
from collections.abc import Iterable

import requests


def clear_line(n: int = 1) -> None:
    LINE_UP = "\033[1A"
    LINE_CLEAR = "\x1b[2K"
    for _ in range(n):
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def encode_image_base64(image_path: str) -> str:
    """Encode a local image file to base64 format."""
    with open(image_path, "rb") as f:
        image_data = f.read()
    result = base64.b64encode(image_data).decode("utf-8")
    return result


def post_http_request(
    prompt: str, image_path: str, api_url: str, stream: bool = False
) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "image_base64": encode_image_base64(image_path),
    }
    response = requests.post(api_url, headers=headers, json=pload, stream=stream)
    return response


def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
    for chunk in response.iter_lines(
        chunk_size=8192, decode_unicode=False, delimiter=b"\n"
    ):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


def get_response(response: requests.Response) -> list[str]:
    data = json.loads(response.content)
    output = data["text"]
    return output


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.")
    parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg")
    parser.add_argument("--stream", action="store_true")
    return parser.parse_args()


def main(args: Namespace):
    prompt = args.prompt
    image_path = args.image_path
    api_url = f"http://{args.host}:{args.port}/generate"
    stream = args.stream

    print(f"Prompt: {prompt!r}\n", flush=True)
    response = post_http_request(prompt, image_path, api_url, stream)

    if stream:
        num_printed_lines = 0
        for h in get_streaming_response(response):
            clear_line(num_printed_lines)
            num_printed_lines = 0
            for i, line in enumerate(h):
                num_printed_lines += 1
                print(f"Response {i}: {line!r}", flush=True)
    else:
        output = get_response(response)
        print(f"Response: {output!r}", flush=True)


if __name__ == "__main__":
    args = parse_args()
    main(args)
deployment/tensorrt_llm/api_server.py (new file, 112 lines)

#!/usr/bin/env python
# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/apps/fastapi_server.py

import asyncio
import base64
import io
import logging
import signal
from http import HTTPStatus
from typing import Optional

import click
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from PIL import Image

from tensorrt_llm.executor import CppExecutorError, RequestError
from dolphin_runner import DolphinRunner, InferenceConfig

TIMEOUT_KEEP_ALIVE = 5  # seconds


async def decode_image(image_base64: str) -> Image.Image:
    image_data = base64.b64decode(image_base64)
    image = Image.open(io.BytesIO(image_data))
    return image


class LlmServer:
    def __init__(self, runner: DolphinRunner):
        self.runner = runner
        self.app = FastAPI()
        self.register_routes()

    def register_routes(self):
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/generate", self.generate, methods=["POST"])

    async def health(self) -> Response:
        return Response(status_code=200)

    async def generate(self, request: Request) -> Response:
        """Generate a completion for the request.

        The request should be a JSON object with the following fields:
        - prompt: the prompt to use for the generation.
        - image_base64: the image to use for the generation.
        """
        request_dict = await request.json()

        prompt = request_dict.pop("prompt", "")
        logging.info(f"request prompt: {prompt}")
        image_base64 = request_dict.pop("image_base64", "")
        image = await decode_image(image_base64)

        try:
            # Generate with the configured max_new_tokens.
            output_texts = self.runner.run([prompt], [image],
                                           self.runner.args.max_new_tokens)
            output_texts = [texts[0] for texts in output_texts]
            return JSONResponse({"text": output_texts[0]})
        except RequestError as e:
            return JSONResponse(content=str(e),
                                status_code=HTTPStatus.BAD_REQUEST)
        except CppExecutorError:
            # If an internal executor error is raised, shut down the server
            signal.raise_signal(signal.SIGINT)

    async def __call__(self, host, port):
        config = uvicorn.Config(self.app,
                                host=host,
                                port=port,
                                log_level="info",
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()


@click.command()
@click.option("--hf_model_dir", type=str, required=True)
@click.option("--visual_engine_dir", type=str, required=True)
@click.option("--llm_engine_dir", type=str, required=True)
@click.option("--max_batch_size", type=int, default=16)
@click.option("--max_new_tokens", type=int, default=4024)
@click.option("--host", type=str, default=None)
@click.option("--port", type=int, default=8000)
def entrypoint(hf_model_dir: str,
               visual_engine_dir: str,
               llm_engine_dir: str,
               max_batch_size: int,
               max_new_tokens: int,
               host: Optional[str] = None,
               port: int = 8000):
    host = host or "0.0.0.0"
    port = port or 8000
    logging.info(f"Starting server at {host}:{port}")

    config = InferenceConfig(
        max_new_tokens=max_new_tokens,
        batch_size=max_batch_size,
        log_level="info",
        hf_model_dir=hf_model_dir,
        visual_engine_dir=visual_engine_dir,
        llm_engine_dir=llm_engine_dir,
    )

    dolphin_runner = DolphinRunner(config)
    server = LlmServer(runner=dolphin_runner)

    asyncio.run(server(host, port))


if __name__ == "__main__":
    entrypoint()
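The server also exposes a `/health` route; a quick liveness probe once the engines have loaded (a sketch, assuming the default port):

```bash
curl -i http://localhost:8000/health   # expect an HTTP 200 with an empty body
```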
deployment/tensorrt_llm/convert/__init__.py (new empty file)

deployment/tensorrt_llm/convert/build_visual_engine.py (new file, 14 lines)

# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.2/examples/multimodal/build_visual_engine.py

import argparse

from tensorrt_llm.tools.multimodal_builder import (VisionEngineBuilder,
                                                   add_multimodal_arguments)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser = add_multimodal_arguments(parser)
    args = parser.parse_args()

    builder = VisionEngineBuilder(args)
    builder.build()
deployment/tensorrt_llm/convert/convert_checkpoint.py (new file, 1528 lines)
File diff suppressed because it is too large.
deployment/tensorrt_llm/convert/helper.py (new file, 95 lines)

# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/enc_dec/helper.py

import typing
from typing import Union

import numpy as np
import torch  # pytype: disable=import-error

from tensorrt_llm._utils import str_dtype_to_torch


def split(v: Union[np.ndarray, torch.Tensor],
          tp_size: int,
          tp_rank: int,
          dim=0):
    if tp_size == 1:
        if isinstance(v, np.ndarray):
            return np.ascontiguousarray(v.copy())
        else:
            return v.clone().detach()
    assert len(v.shape) > 1 or dim == 0
    if isinstance(v, np.ndarray):
        return np.ascontiguousarray(
            np.split(v, tp_size, axis=dim)[tp_rank].copy())
    else:
        assert v.shape[dim] % tp_size == 0, \
            f'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.'
        split_size = v.shape[dim] // tp_size
        return v.split(split_size, dim=dim)[tp_rank].clone().detach()


def reshape(v: torch.Tensor, shape=None):
    if shape is None:
        return v.contiguous()
    else:
        return v.reshape(shape).contiguous()


def fuse_qkv_one_layer(params, attn_module_name, trtllm_layer_name, tp_size,
                       tp_rank, model_type, weight_shape, bias_shape):

    qkv_module_names = get_qkv_module_name(model_type)

    weight = {}

    # fuse weights of q, k, v
    q_w = params[f'{attn_module_name}.{qkv_module_names["q"]}.weight']
    k_w = params[f'{attn_module_name}.{qkv_module_names["k"]}.weight']
    v_w = params[f'{attn_module_name}.{qkv_module_names["v"]}.weight']

    # fuse qkv weight
    shape = q_w.shape  # (do, din)
    qkv_w = torch.cat([q_w, k_w, v_w],
                      dim=0).reshape([3, shape[0], shape[1]])  # (3, do, din)
    qkv_w = split(qkv_w, tp_size, tp_rank, dim=1)
    weight[f'{trtllm_layer_name}.qkv.weight'] = reshape(qkv_w,
                                                        shape=weight_shape)

    # fuse qkv biases if present
    if f'{attn_module_name}.{qkv_module_names["q"]}.bias' in params.keys(
    ) and params[f'{attn_module_name}.{qkv_module_names["q"]}.bias'] is not None:
        q_b = params[f'{attn_module_name}.{qkv_module_names["q"]}.bias']
        k_b = params[f'{attn_module_name}.{qkv_module_names["k"]}.bias']
        v_b = params[f'{attn_module_name}.{qkv_module_names["v"]}.bias']
        shape = q_b.shape[0]  # (do,)
        qkv_b = torch.cat([q_b, k_b, v_b], dim=0).reshape([3, shape])  # (3, do)
        qkv_b = split(qkv_b, tp_size, tp_rank, dim=1)
        weight[f'{trtllm_layer_name}.qkv.bias'] = reshape(qkv_b,
                                                          shape=bias_shape)
    return weight


def get_qkv_module_name(model_type):
    if model_type in ["t5", "blip2"]:
        q = "q"
        k = "k"
        v = "v"
    elif model_type == "bart" or model_type == "nmt":
        q = "q_proj"
        k = "k_proj"
        v = "v_proj"
    elif model_type == "pix2struct":
        q = "query"
        k = "key"
        v = "value"
    return {"q": q, "k": k, "v": v}


def convert_weight_to_dtype(params: typing.Dict[str, torch.Tensor],
                            dtype: typing.Optional[np.dtype] = None):
    if dtype is not None:
        assert isinstance(dtype,
                          str), f"dtype must be str, but get type {type(dtype)}"
        for name in params.keys():
            params[name] = params[name].to(str_dtype_to_torch(dtype))
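As a toy illustration of `fuse_qkv_one_layer` (hypothetical module names and shapes, bart-style projections; not part of the actual conversion flow):

```python
import torch
from helper import fuse_qkv_one_layer

# Fake bart-style attention weights: three (8, 8) projections, no biases.
params = {
    f"decoder.layers.0.self_attn.{name}.weight": torch.randn(8, 8)
    for name in ("q_proj", "k_proj", "v_proj")
}
fused = fuse_qkv_one_layer(
    params, "decoder.layers.0.self_attn", "decoder_layers.0.attention",
    tp_size=1, tp_rank=0, model_type="bart",
    weight_shape=(24, 8), bias_shape=None)
print(fused["decoder_layers.0.attention.qkv.weight"].shape)  # torch.Size([24, 8])
```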
deployment/tensorrt_llm/convert_dolphin.sh (new file, 47 lines)

#!/usr/bin/env bash
set -ex

############################################################################################
# Reference: https://github.com/NVIDIA/TensorRT-LLM/tree/v0.18.2/examples/multimodal#nougat
############################################################################################

export LD_LIBRARY_PATH=/usr/local/lib/python3.10/site-packages/tensorrt_libs/:/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:$LD_LIBRARY_PATH

# 1. Download Huggingface weights
export MODEL_NAME="Dolphin"
git clone https://huggingface.co/ByteDance/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}

export MAX_BATCH_SIZE=16
export MAX_SEQ_LEN=4096
export MAX_INPUT_LEN=10
export MAX_ENCODER_INPUT_LEN=784

# 2. Convert Huggingface weights into TRT-LLM checkpoints and build TRT engines using scripts in examples/enc_dec
python ./convert/convert_checkpoint.py --model_type bart \
    --model_dir tmp/hf_models/${MODEL_NAME} \
    --output_dir tmp/trt_models/${MODEL_NAME}/bfloat16 \
    --tp_size 1 \
    --pp_size 1 \
    --dtype bfloat16 \
    --nougat

trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/decoder \
    --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/decoder \
    --paged_kv_cache disable \
    --moe_plugin disable \
    --gemm_plugin bfloat16 \
    --bert_attention_plugin bfloat16 \
    --gpt_attention_plugin bfloat16 \
    --remove_input_padding enable \
    --max_beam_width 1 \
    --max_batch_size ${MAX_BATCH_SIZE} \
    --max_seq_len ${MAX_SEQ_LEN} \
    --max_input_len ${MAX_INPUT_LEN} \
    --max_encoder_input_len $((${MAX_BATCH_SIZE} * ${MAX_ENCODER_INPUT_LEN}))  # max_batch_size * num_visual_features

# 3. Generate TensorRT engines for visual components and combine everything into the final pipeline.
python ./convert/build_visual_engine.py --model_type nougat \
    --model_path tmp/hf_models/${MODEL_NAME} \
    --max_batch_size ${MAX_BATCH_SIZE}
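After both builds complete, the run scripts below expect the engines at the following locations; a quick sanity check (a sketch, assuming the output directories used above):

```bash
ls tmp/trt_engines/Dolphin/vision_encoder            # expect model.engine and config.json
ls tmp/trt_engines/Dolphin/1-gpu/bfloat16/decoder    # expect the decoder engine
```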
deployment/tensorrt_llm/dolphin_runner.py (new file, 219 lines)

"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import json
import os
from typing import Optional

import tensorrt_llm
import tensorrt_llm.profiler as profiler
import torch
from PIL import Image
from pydantic import BaseModel, Field
from tensorrt_llm import logger
from tensorrt_llm import mpi_rank
from tensorrt_llm.runtime import MultimodalModelRunner
from transformers import AutoTokenizer, DonutProcessor


class InferenceConfig(BaseModel):
    max_new_tokens: int = Field(128, description="Maximum new tokens to generate")
    batch_size: int = Field(1, description="Batch size for inference")
    log_level: str = Field("info", description="Logging level")
    visual_engine_dir: Optional[str] = Field(None, description="Directory for visual engine files")
    visual_engine_name: str = Field("model.engine", description="Visual engine filename")
    llm_engine_dir: Optional[str] = Field(None, description="Directory for LLM engine files")
    hf_model_dir: Optional[str] = Field(None, description="Hugging Face model directory")
    input_text: Optional[str] = Field(None, description="Input text for inference")
    num_beams: int = Field(1, description="Number of beams for beam search")
    top_k: int = Field(1, description="Top-k sampling value")
    top_p: float = Field(0.0, description="Top-p (nucleus) sampling value")
    temperature: float = Field(1.0, description="Sampling temperature")
    repetition_penalty: float = Field(1.0, description="Repetition penalty factor")
    run_profiling: bool = Field(False, description="Enable profiling mode")
    profiling_iterations: int = Field(20, description="Number of profiling iterations")
    check_accuracy: bool = Field(False, description="Enable accuracy checking")
    video_path: Optional[str] = Field(None, description="Path to input video file")
    video_num_frames: Optional[int] = Field(None, description="Number of video frames to process")
    image_path: Optional[str] = Field(None, description="Path to input image file")
    path_sep: str = Field(",", description="Path separator character")
    prompt_sep: str = Field(",", description="Prompt separator character")
    enable_context_fmha_fp32_acc: Optional[bool] = Field(
        None,
        description="Enable FP32 accumulation for context FMHA"
    )
    enable_chunked_context: bool = Field(False, description="Enable chunked context processing")
    use_py_session: bool = Field(False, description="Use Python session instead of C++")
    kv_cache_free_gpu_memory_fraction: float = Field(
        0.9,
        description="Fraction of GPU memory free for KV cache",
        ge=0.0, le=1.0
    )
    cross_kv_cache_fraction: float = Field(
        0.5,
        description="Fraction of cross-attention KV cache",
        ge=0.0, le=1.0
    )
    multi_block_mode: bool = Field(True, description="Enable multi-block processing mode")


class DolphinRunner(MultimodalModelRunner):
    def __init__(self, args):
        self.args = args

        self.runtime_rank = mpi_rank()
        device_id = self.runtime_rank % torch.cuda.device_count()
        torch.cuda.set_device(device_id)
        self.device = "cuda:%d" % (device_id)

        self.stream = torch.cuda.Stream(torch.cuda.current_device())
        torch.cuda.set_stream(self.stream)

        # parse model type from visual engine config
        with open(os.path.join(self.args.visual_engine_dir, "config.json"),
                  "r") as f:
            config = json.load(f)
        self.model_type = config['builder_config']['model_type']
        self.vision_precision = config['builder_config']['precision']
        self.decoder_llm = not (
            't5' in self.model_type
            or self.model_type in ['nougat', 'pix2struct']
        )  # BLIP2-T5, pix2struct and Nougat use encoder-decoder models as LLMs

        if self.model_type == "mllama":
            self.vision_input_names = [
                "pixel_values",
                "aspect_ratio_ids",
                "aspect_ratio_mask",
            ]
            self.vision_output_names = [
                "output",
            ]
        else:
            self.vision_input_names = ["input"]
            self.vision_output_names = ["output"]

        self.use_py_session = True

        self.init_image_encoder()
        self.init_tokenizer()
        self.init_processor()
        self.init_llm()

    def init_tokenizer(self):
        assert self.model_type == 'nougat'
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.hf_model_dir)
        self.tokenizer.padding_side = "right"

    def init_processor(self):
        assert self.model_type == 'nougat'
        self.processor = DonutProcessor.from_pretrained(self.args.hf_model_dir, use_fast=True)

    def run(self, input_texts, input_images, max_new_tokens):
        prompts = [f"<s>{text.strip()} <Answer/>" for text in input_texts]
        images = self.processor(input_images, return_tensors="pt")['pixel_values'].to("cuda")
        prompt_ids = self.tokenizer(prompts, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda")
        prompt_ids = prompt_ids.to(
            torch.int32)  # Important! If the type of prompt_ids is not int32, the output will be wrong.

        logger.info("---------------------------------------------------------")
        logger.info(f"images size: {images.size()}")
        logger.info(f"prompt_ids: {prompt_ids}, size: {prompt_ids.size()}, dtype: {prompt_ids.dtype}")
        logger.info("---------------------------------------------------------")

        output_texts = self.generate(input_texts,
                                     [None] * len(input_texts),
                                     images,
                                     prompt_ids,
                                     max_new_tokens,
                                     warmup=False,
                                     )

        return output_texts

    def generate(self,
                 pre_prompt,
                 post_prompt,
                 image,
                 decoder_input_ids,
                 max_new_tokens,
                 warmup=False,
                 other_vision_inputs={},
                 other_decoder_inputs={}):
        if not warmup:
            profiler.start("Generate")
        input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
            warmup, pre_prompt, post_prompt, image, other_vision_inputs)

        if warmup:
            return None

        # use prompt tuning to pass multimodal features
        # model.generate() expects the following params (see layers/embedding.py):
        # args[0]: prompt embedding table, [batch_size, multimodal_len, hidden_size], later flattened to [batch_size * multimodal_len, hidden_size]
        # args[1]: prompt task ids, [batch_size]. in multimodal case, arange(batch_size), i.e. in VILA batching mode 2, each image is treated separately in the batch instead of concated together (although the prompt embedding table has to be concated)
        # args[2]: prompt task vocab size, [1]. assuming all table has the same length, which in multimodal case equals to multimodal_len
        profiler.start("LLM")
        if self.model_type in ['nougat', 'pix2struct']:
            # Trim encoder input_ids to match visual features shape
            ids_shape = (min(self.args.batch_size, len(pre_prompt)), visual_features.shape[1])
            if self.model_type == 'nougat':
                input_ids = torch.zeros(ids_shape, dtype=torch.int32)
            elif self.model_type == 'pix2struct':
                input_ids = torch.ones(ids_shape, dtype=torch.int32)

        output_ids = self.model.generate(
            input_ids,
            decoder_input_ids,
            max_new_tokens,
            num_beams=self.args.num_beams,
            bos_token_id=self.tokenizer.bos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            debug_mode=False,
            prompt_embedding_table=ptuning_args[0],
            prompt_tasks=ptuning_args[1],
            prompt_vocab_size=ptuning_args[2],
        )
        profiler.stop("LLM")

        if mpi_rank() == 0:
            # Extract a list of tensors of shape beam_width x output_ids.
            output_beams_list = [
                self.tokenizer.batch_decode(
                    output_ids[batch_idx, :, decoder_input_ids.shape[1]:],
                    skip_special_tokens=False) for batch_idx in range(
                        min(self.args.batch_size, decoder_input_ids.shape[0]))
            ]

            stripped_text = [[
                output_beams_list[batch_idx][beam_idx].replace("</s>", "").replace("<pad>", "").strip()
                for beam_idx in range(self.args.num_beams)
            ] for batch_idx in range(
                min(self.args.batch_size, decoder_input_ids.shape[0]))]
            profiler.stop("Generate")
            return stripped_text
        else:
            profiler.stop("Generate")
            return None


if __name__ == "__main__":
    config = InferenceConfig(
        max_new_tokens=4024,
        batch_size=16,
        log_level="info",
        hf_model_dir="./tmp/hf_models/Dolphin",
        visual_engine_dir="./tmp/trt_engines/Dolphin/vision_encoder",
        llm_engine_dir="./tmp/trt_engines/Dolphin/1-gpu/bfloat16",
    )

    model = DolphinRunner(config)

    image_path = "../../demo/page_imgs/page_1.jpeg"
    prompt = "Parse the reading order of this document."
    image = Image.open(image_path).convert("RGB")
    output_texts = model.run([prompt], [image], 4024)
    output_texts = [texts[0] for texts in output_texts]
    print(output_texts)
deployment/tensorrt_llm/run_dolphin.py (new file, 106 lines)

"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import argparse
import os

import tensorrt_llm
import tensorrt_llm.profiler as profiler
from PIL import Image
from tensorrt_llm import logger
from tensorrt_llm import mpi_rank
from tensorrt_llm.runtime import MultimodalModelRunner

from dolphin_runner import DolphinRunner
from utils import add_common_args

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def print_result(model, input_text, output_text, args):
    logger.info("---------------------------------------------------------")
    logger.info(f"\n[Q] {input_text}")
    for i in range(len(output_text)):
        logger.info(f"\n[A]: {output_text[i]}")

    if args.num_beams == 1:
        output_ids = model.tokenizer(output_text[0][0],
                                     add_special_tokens=False)['input_ids']
        logger.info(f"Generated {len(output_ids)} tokens")

    if args.check_accuracy:
        if model.model_type != 'nougat':
            if model.model_type == "vila":
                for i in range(len(args.image_path.split(args.path_sep))):
                    if i % 2 == 0:
                        assert output_text[i][0].lower(
                        ) == "the image captures a bustling city intersection teeming with life. from the perspective of a car's dashboard camera, we see"
                    else:
                        assert output_text[i][0].lower(
                        ) == "the image captures the iconic merlion statue in singapore, a renowned worldwide landmark. the merlion, a mythical"
            elif model.model_type == "llava":
                for i in range(len(args.image_path.split(args.path_sep))):
                    assert output_text[i][0].lower() == 'singapore'
            elif model.model_type == 'fuyu':
                assert output_text[0][0].lower() == '4'
            elif model.model_type == "pix2struct":
                assert "characteristic | cat food, day | cat food, wet | cat treats" in output_text[
                    0][0].lower()
            elif model.model_type in [
                    'blip2', 'neva', 'phi-3-vision', 'llava_next'
            ]:
                assert 'singapore' in output_text[0][0].lower()
            elif model.model_type == 'video-neva':
                assert 'robot' in output_text[0][0].lower()
            elif model.model_type == 'kosmos-2':
                assert 'snowman' in output_text[0][0].lower()
            elif model.model_type == "mllama":
                if "If I had to write a haiku for this one" in input_text:
                    assert "it would be:.\\nPeter Rabbit is a rabbit.\\nHe lives in a" in output_text[
                        0][0] or "Here is a haiku for the image:\n\n" in output_text[
                            0][0], f"expected results: 'it would be:.\\nPeter Rabbit is a rabbit.\\nHe lives in a', generated results: '{output_text[0][0]}'"
                elif "The key to life is" in input_text:
                    assert "to find your passion and pursue it with all your heart." in output_text[
                        0][0] or "not to be found in the external world," in output_text[
                            0][0], f"expected results: 'to find your passion and pursue it with all your heart.', generated results: '{output_text[0][0]}'"
            elif model.model_type == 'llava_onevision':
                if args.video_path is None:
                    assert 'singapore' in output_text[0][0].lower()
                else:
                    assert 'the video is funny because the child\'s actions are' in output_text[
                        0][0].lower()
            elif model.model_type == "qwen2_vl":
                assert 'dog' in output_text[0][0].lower()
            else:
                assert output_text[0][0].lower() == 'singapore'

    if args.run_profiling:
        msec_per_batch = lambda name: 1000 * profiler.elapsed_time_in_sec(
            name) / args.profiling_iterations
        logger.info('Latencies per batch (msec)')
        logger.info('TRT vision encoder: %.1f' % (msec_per_batch('Vision')))
        logger.info('TRTLLM LLM generate: %.1f' % (msec_per_batch('LLM')))
        logger.info('Multimodal generate: %.1f' % (msec_per_batch('Generate')))

    logger.info("---------------------------------------------------------")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser = add_common_args(parser)
    args = parser.parse_args()
    logger.set_level(args.log_level)

    model = DolphinRunner(args)

    input_image = Image.open(args.image_path[0]).convert('RGB')
    num_iters = args.profiling_iterations if args.run_profiling else 1

    for _ in range(num_iters):
        output_texts = model.run(args.input_text, [input_image], args.max_new_tokens)

    runtime_rank = tensorrt_llm.mpi_rank()
    if runtime_rank == 0:
        print_result(model, args.input_text, output_texts, args)
deployment/tensorrt_llm/run_dolphin.sh (new file, 47 lines)

#!/usr/bin/env bash
set -ex

export MODEL_NAME="Dolphin"

python run_dolphin.py \
    --batch_size 1 \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_new_tokens 4096 \
    --repetition_penalty 1.0 \
    --input_text "Parse the reading order of this document." \
    --image_path "../../demo/page_imgs/page_1.jpeg"

python run_dolphin.py \
    --batch_size 1 \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_new_tokens 4096 \
    --repetition_penalty 1.0 \
    --input_text "Read text in the image." \
    --image_path "../../demo/element_imgs/block_formula.jpeg"

python run_dolphin.py \
    --batch_size 1 \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_new_tokens 4096 \
    --repetition_penalty 1.0 \
    --input_text "Read text in the image." \
    --image_path "../../demo/element_imgs/para_1.jpg"

python run_dolphin.py \
    --batch_size 1 \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_new_tokens 4096 \
    --repetition_penalty 1.0 \
    --input_text "Parse the table in the image." \
    --image_path "../../demo/element_imgs/table_1.jpeg"
deployment/tensorrt_llm/start_dolphin_server.sh (new file, 10 lines)

#!/usr/bin/env bash
set -ex

export MODEL_NAME="Dolphin"

python api_server.py \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_batch_size 16
deployment/tensorrt_llm/utils.py (new file, 95 lines)

def add_common_args(parser):
    parser.add_argument('--max_new_tokens', type=int, default=128)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument('--visual_engine_dir',
                        type=str,
                        default=None,
                        help='Directory containing visual TRT engines')
    parser.add_argument('--visual_engine_name',
                        type=str,
                        default='model.engine',
                        help='Name of visual TRT engine')
    parser.add_argument('--llm_engine_dir',
                        type=str,
                        default=None,
                        help='Directory containing TRT-LLM engines')
    parser.add_argument('--hf_model_dir',
                        type=str,
                        default=None,
                        help='Directory containing tokenizer')
    parser.add_argument('--input_text',
                        type=str,
                        nargs='+',
                        default=None,
                        help='Text prompt to LLM')
    parser.add_argument('--num_beams',
                        type=int,
                        default=1,
                        help='Use beam search if num_beams > 1')
    parser.add_argument('--top_k', type=int, default=1)
    parser.add_argument('--top_p', type=float, default=0.0)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--repetition_penalty', type=float, default=1.0)
    parser.add_argument('--run_profiling',
                        action='store_true',
                        help='Profile runtime over several iterations')
    parser.add_argument('--profiling_iterations',
                        type=int,
                        default=20,
                        help='Number of iterations to run profiling')
    parser.add_argument('--check_accuracy',
                        action='store_true',
                        help='Check correctness of text output')
    parser.add_argument('--image_path',
                        type=str,
                        nargs='+',
                        default=None,
                        help='List of input image paths, separated by symbol')
    parser.add_argument('--path_sep',
                        type=str,
                        default=',',
                        help='Path separator symbol')
    parser.add_argument('--prompt_sep',
                        type=str,
                        default=',',
                        help='Prompt separator symbol')
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        action='store_true',
                        default=None,
                        help='Enable FMHA runner FP32 accumulation.')
    parser.add_argument(
        '--enable_chunked_context',
        action='store_true',
        help='Enables chunked context (only available with cpp session).',
    )
    parser.add_argument(
        '--use_py_session',
        default=False,
        action='store_true',
        help='Whether or not to use Python runtime session. By default C++ '
        'runtime session is used for the LLM.')
    parser.add_argument(
        '--kv_cache_free_gpu_memory_fraction',
        default=0.9,
        type=float,
        help='Specify the free gpu memory fraction.',
    )
    parser.add_argument(
        '--cross_kv_cache_fraction',
        default=0.5,
        type=float,
        help='Specify the kv cache fraction reserved for cross attention. '
        'Only applicable for encoder-decoder models. By default 0.5 for '
        'self and 0.5 for cross.',
    )
    parser.add_argument(
        '--multi_block_mode',
        # custom boolean function to convert the input string to a boolean
        type=lambda s: s.lower() in ("yes", "true", "t", "1"),
        default=True,
        help='Distribute the work across multiple CUDA thread-blocks on the '
        'GPU for masked MHA kernel.')
    return parser