diff --git a/deployment/ReadMe.md b/deployment/ReadMe.md
new file mode 100644
index 0000000..2acd243
--- /dev/null
+++ b/deployment/ReadMe.md
@@ -0,0 +1,12 @@
+# 🚀 Dolphin Inference/Serving
+
+## vLLM
+> [Doc](./vllm/README.md)
+
+## TensorRT-LLM
+> [Doc](./tensorrt_llm/README.md)
+
+## Others
+
diff --git a/deployment/tensorrt_llm/ReadMe.md b/deployment/tensorrt_llm/ReadMe.md
new file mode 100644
index 0000000..240ca94
--- /dev/null
+++ b/deployment/tensorrt_llm/ReadMe.md
@@ -0,0 +1,87 @@
+# 🚀 Dolphin TensorRT-LLM Demo
+
+## ✅ Introduction
+The Dolphin model employs a **Swin Encoder + MBart Decoder** architecture. In its HuggingFace Transformers [config](https://huggingface.co/ByteDance/Dolphin/blob/main/config.json),
+the `architectures` field is specified as `"VisionEncoderDecoderModel"`. Dolphin, Nougat, and Donut share this architecture, and TensorRT-LLM already supports Nougat.
+Following Nougat's conversion script, we have implemented Dolphin on TensorRT-LLM. Note: `input_ids` MUST be of type `int32`; otherwise TensorRT-LLM produces incorrect results.
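+
+For anyone adapting the scripts, the cast itself is small. Below is a minimal sketch of preparing int32 `input_ids` (the tokenizer call and prompt here are illustrative assumptions, not code from `run_dolphin.py`):
+
+```python
+import torch
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("ByteDance/Dolphin")  # assumed HF repo id
+input_ids = tokenizer("Read text in the image.", return_tensors="pt").input_ids
+input_ids = input_ids.to(torch.int32)  # HF tokenizers return int64 by default; TensorRT-LLM needs int32
+```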
+
+## 🛠️ Installation
+> We have only tested TensorRT-LLM 0.18.1 on Linux.
+
+https://nvidia.github.io/TensorRT-LLM/0.18.1/installation/linux.html
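+
+The guide above covers the CUDA and Python prerequisites. A typical pip-based install looks like the following; treat the command as a sketch and defer to the official guide if they differ:
+
+```bash
+pip3 install tensorrt_llm==0.18.1 --extra-index-url https://pypi.nvidia.com
+```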
+
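+The inference commands below assume the TensorRT engines have already been built with the scripts under `convert/`. The sketch here is hedged: flag values and the output layout are placeholders modeled on TensorRT-LLM's Nougat example, so adjust them to match the `--visual_engine_dir`/`--llm_engine_dir` paths you pass later:
+
+```bash
+export MODEL_NAME="Dolphin"
+
+# 1. Convert the HF checkpoint (MBart decoder) to a TensorRT-LLM checkpoint
+python convert/convert_checkpoint.py --model_type bart --nougat \
+    --model_dir tmp/hf_models/${MODEL_NAME} \
+    --output_dir tmp/trt_models/${MODEL_NAME}/bfloat16 \
+    --tp_size 1 --pp_size 1 --dtype bfloat16
+
+# 2. Build the decoder engine (flag values and output layout are placeholders;
+#    some runners expect a decoder/ subdirectory under the engine path)
+trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/decoder \
+    --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
+    --gemm_plugin bfloat16 --bert_attention_plugin bfloat16 --gpt_attention_plugin bfloat16 \
+    --max_batch_size 16 --max_input_len 1 --max_seq_len 4096 --max_encoder_input_len 1500
+
+# 3. Build the Swin vision-encoder engine (Dolphin shares Nougat's encoder architecture)
+python convert/build_visual_engine.py --model_type nougat \
+    --model_path tmp/hf_models/${MODEL_NAME} \
+    --output_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder
+```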
+
+## ⚡ Offline Inference
+```bash
+export MODEL_NAME="Dolphin"
+
+# predict the reading order of document elements
+python run_dolphin.py \
+ --batch_size 1 \
+ --hf_model_dir tmp/hf_models/${MODEL_NAME} \
+ --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
+ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
+ --max_new_tokens 4096 \
+ --repetition_penalty 1.0 \
+ --input_text "Parse the reading order of this document." \
+ --image_path "../../demo/page_imgs/page_1.jpeg"
+
+# recognize text/latex
+python run_dolphin.py \
+ --batch_size 1 \
+ --hf_model_dir tmp/hf_models/${MODEL_NAME} \
+ --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
+ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
+ --max_new_tokens 4096 \
+ --repetition_penalty 1.0 \
+ --input_text "Read text in the image." \
+ --image_path "../../demo/element_imgs/block_formula.jpeg"
+
+
+python run_dolphin.py \
+ --batch_size 1 \
+ --hf_model_dir tmp/hf_models/${MODEL_NAME} \
+ --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
+ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
+ --max_new_tokens 4096 \
+ --repetition_penalty 1.0 \
+ --input_text "Read text in the image." \
+ --image_path "../../demo/element_imgs/para_1.jpg"
+
+# recognize table
+python run_dolphin.py \
+ --batch_size 1 \
+ --hf_model_dir tmp/hf_models/${MODEL_NAME} \
+ --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
+ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
+ --max_new_tokens 4096 \
+ --repetition_penalty 1.0 \
+ --input_text "Parse the table in the image." \
+ --image_path "../../demo/element_imgs/table_1.jpeg"
+```
+
+
+## ⚡ Online Inference
+```bash
+# 1. Start the API server
+export MODEL_NAME="Dolphin"
+
+python api_server.py \
+ --hf_model_dir tmp/hf_models/${MODEL_NAME} \
+ --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
+ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
+ --max_batch_size 16
+
+# 2. Predict
+# predict the reading order of document elements
+python deployment/tensorrt_llm/api_client.py --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document."
+
+# recognize text/latex
+python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image."
+python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image."
+
+# recognize table
+python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image."
+```
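+
+The server also exposes a `GET /health` route, and `POST /generate` accepts JSON with `prompt` and `image_base64` fields (see api_server.py), so you can call it without the client script. A sketch using curl (`base64 -w 0` is the GNU coreutils flag; adjust on macOS):
+
+```bash
+# check that the server is up
+curl -s http://localhost:8000/health
+
+# send a request to /generate
+IMAGE_B64=$(base64 -w 0 ./demo/page_imgs/page_1.jpeg)
+curl -s http://localhost:8000/generate \
+  -H "Content-Type: application/json" \
+  -d "{\"prompt\": \"Parse the reading order of this document.\", \"image_base64\": \"${IMAGE_B64}\"}"
+```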
\ No newline at end of file
diff --git a/deployment/tensorrt_llm/api_client.py b/deployment/tensorrt_llm/api_client.py
new file mode 100644
index 0000000..75107ed
--- /dev/null
+++ b/deployment/tensorrt_llm/api_client.py
@@ -0,0 +1,100 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Example Python client for `vllm.entrypoints.api_server`
+Start the demo server:
+ python -m vllm.entrypoints.api_server --model
+
+NOTE: The API server is used only for demonstration and simple performance
+benchmarks. It is not intended for production use.
+For production use, we recommend `vllm serve` and the OpenAI client API.
+"""
+
+import argparse
+import base64
+import json
+from argparse import Namespace
+from collections.abc import Iterable
+
+import requests
+
+
+def clear_line(n: int = 1) -> None:
+ LINE_UP = "\033[1A"
+ LINE_CLEAR = "\x1b[2K"
+ for _ in range(n):
+ print(LINE_UP, end=LINE_CLEAR, flush=True)
+
+
+def encode_image_base64(image_path: str) -> str:
+ """Encode local image to base64 format."""
+
+ with open(image_path, "rb") as f:
+ image_data = f.read()
+ result = base64.b64encode(image_data).decode("utf-8")
+
+ return result
+
+
+def post_http_request(
+ prompt: str, image_path: str, api_url: str, stream: bool = False
+) -> requests.Response:
+ headers = {"User-Agent": "Test Client"}
+ pload = {
+ "prompt": prompt,
+ "image_base64": encode_image_base64(image_path),
+ }
+ response = requests.post(api_url, headers=headers, json=pload, stream=stream)
+ return response
+
+
+def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
+ for chunk in response.iter_lines(
+ chunk_size=8192, decode_unicode=False, delimiter=b"\n"
+ ):
+ if chunk:
+ data = json.loads(chunk.decode("utf-8"))
+ output = data["text"]
+ yield output
+
+
+def get_response(response: requests.Response) -> list[str]:
+ data = json.loads(response.content)
+ output = data["text"]
+ return output
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=8000)
+ parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.")
+ parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg")
+ parser.add_argument("--stream", action="store_true")
+ return parser.parse_args()
+
+
+def main(args: Namespace):
+ prompt = args.prompt
+ image_path = args.image_path
+ api_url = f"http://{args.host}:{args.port}/generate"
+ stream = args.stream
+
+ print(f"Prompt: {prompt!r}\n", flush=True)
+ response = post_http_request(prompt, image_path, api_url, stream)
+
+ if stream:
+ num_printed_lines = 0
+ for h in get_streaming_response(response):
+ clear_line(num_printed_lines)
+ num_printed_lines = 0
+ for i, line in enumerate(h):
+ num_printed_lines += 1
+ print(f"Response {i}: {line!r}", flush=True)
+ else:
+ output = get_response(response)
+ print(f"Response: {output!r}", flush=True)
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/deployment/tensorrt_llm/api_server.py b/deployment/tensorrt_llm/api_server.py
new file mode 100644
index 0000000..c6e8a93
--- /dev/null
+++ b/deployment/tensorrt_llm/api_server.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python
+# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/apps/fastapi_server.py
+
+import asyncio
+import base64
+import io
+import logging
+import signal
+from http import HTTPStatus
+from typing import Optional
+
+import click
+import uvicorn
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, Response
+from PIL import Image
+
+from tensorrt_llm.executor import CppExecutorError, RequestError
+from dolphin_runner import DolphinRunner, InferenceConfig
+
+TIMEOUT_KEEP_ALIVE = 5 # seconds.
+
+
+async def decode_image(image_base64: str) -> Image.Image:
+ image_data = base64.b64decode(image_base64)
+ image = Image.open(io.BytesIO(image_data))
+ return image
+
+
+class LlmServer:
+ def __init__(self, runner: DolphinRunner):
+ self.runner = runner
+ self.app = FastAPI()
+ self.register_routes()
+
+ def register_routes(self):
+ self.app.add_api_route("/health", self.health, methods=["GET"])
+ self.app.add_api_route("/generate", self.generate, methods=["POST"])
+
+ async def health(self) -> Response:
+ return Response(status_code=200)
+
+ async def generate(self, request: Request) -> Response:
+ """ Generate completion for the request.
+
+ The request should be a JSON object with the following fields:
+ - prompt: the prompt to use for the generation.
+ - image_base64: the image to use for the generation.
+ """
+ request_dict = await request.json()
+
+ prompt = request_dict.pop("prompt", "")
+ logging.info(f"request prompt: {prompt}")
+ image_base64 = request_dict.pop("image_base64", "")
+ image = await decode_image(image_base64)
+
+ try:
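+            # NOTE: a fixed limit of 4024 new tokens is passed to the runner at this call site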
+ output_texts = self.runner.run([prompt], [image], 4024)
+ output_texts = [texts[0] for texts in output_texts]
+ return JSONResponse({"text": output_texts[0]})
+ except RequestError as e:
+ return JSONResponse(content=str(e),
+ status_code=HTTPStatus.BAD_REQUEST)
+ except CppExecutorError:
+ # If internal executor error is raised, shutdown the server
+ signal.raise_signal(signal.SIGINT)
+
+ async def __call__(self, host, port):
+ config = uvicorn.Config(self.app,
+ host=host,
+ port=port,
+ log_level="info",
+ timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
+ await uvicorn.Server(config).serve()
+
+
+@click.command()
+@click.option("--hf_model_dir", type=str, required=True)
+@click.option("--visual_engine_dir", type=str, required=True)
+@click.option("--llm_engine_dir", type=str, required=True)
+@click.option("--max_batch_size", type=int, default=16)
+@click.option("--max_new_tokens", type=int, default=4024)
+@click.option("--host", type=str, default=None)
+@click.option("--port", type=int, default=8000)
+def entrypoint(hf_model_dir: str,
+ visual_engine_dir: str,
+ llm_engine_dir: str,
+ max_batch_size: int,
+ max_new_tokens: int,
+ host: Optional[str] = None,
+ port: int = 8000):
+ host = host or "0.0.0.0"
+ port = port or 8000
+ logging.info(f"Starting server at {host}:{port}")
+
+ config = InferenceConfig(
+ max_new_tokens=max_new_tokens,
+ batch_size=max_batch_size,
+ log_level="info",
+ hf_model_dir=hf_model_dir,
+ visual_engine_dir=visual_engine_dir,
+ llm_engine_dir=llm_engine_dir,
+ )
+
+ dolphin_runner = DolphinRunner(config)
+ server = LlmServer(runner=dolphin_runner)
+
+ asyncio.run(server(host, port))
+
+
+if __name__ == "__main__":
+ entrypoint()
\ No newline at end of file
diff --git a/deployment/tensorrt_llm/convert/__init__.py b/deployment/tensorrt_llm/convert/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/deployment/tensorrt_llm/convert/build_visual_engine.py b/deployment/tensorrt_llm/convert/build_visual_engine.py
new file mode 100644
index 0000000..2c9ab63
--- /dev/null
+++ b/deployment/tensorrt_llm/convert/build_visual_engine.py
@@ -0,0 +1,14 @@
+# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.2/examples/multimodal/build_visual_engine.py
+
+import argparse
+
+from tensorrt_llm.tools.multimodal_builder import (VisionEngineBuilder,
+ add_multimodal_arguments)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser = add_multimodal_arguments(parser)
+ args = parser.parse_args()
+
+ builder = VisionEngineBuilder(args)
+ builder.build()
diff --git a/deployment/tensorrt_llm/convert/convert_checkpoint.py b/deployment/tensorrt_llm/convert/convert_checkpoint.py
new file mode 100644
index 0000000..c176d68
--- /dev/null
+++ b/deployment/tensorrt_llm/convert/convert_checkpoint.py
@@ -0,0 +1,1528 @@
+# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/enc_dec/convert_checkpoint.py
+
+import argparse
+import configparser
+import copy
+import json
+import logging
+import os
+import types
+from ast import literal_eval
+from datetime import datetime
+from pathlib import Path
+
+import safetensors
+from helper import convert_weight_to_dtype, fuse_qkv_one_layer, reshape, split
+from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration,
+ MBartForConditionalGeneration,
+ Pix2StructForConditionalGeneration,
+ T5ForConditionalGeneration, VisionEncoderDecoderModel)
+
+from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
+ MLPType)
+from tensorrt_llm.models import PretrainedConfig
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+LOGGER = logging.getLogger(__name__)
+
+layernorm_type_map = {i.name: i.value for i in LayerNormType}
+layernorm_position_map = {i.name: i.value for i in LayerNormPositionType}
+mlp_type_map = {i.name: i.value for i in MLPType}
+
+
+def copy_args_to_component_config(component_config, args):
+ for arg in vars(args):
+ setattr(component_config, arg, getattr(args, arg))
+ return component_config
+
+
+def parse_t5_config(args, hf_model):
+ config = configparser.ConfigParser()
+
+ config["encoder"] = {}
+ for key, val in hf_model.encoder.config.to_dict().items():
+ config["encoder"][key] = f"{val}"
+
+ # manually set q_scaling to offset attention scaling's effect.
+ # TODO: modify kernels to control whether to disable attention scaling
+ def get_offset_q_scaling(config):
+ scaling = 1 / config.head_size**.5
+ return scaling
+
+ config["decoder"] = {}
+ for key, val in hf_model.decoder.config.to_dict().items():
+ config["decoder"][key] = f"{val}"
+
+ config["structure"] = dict()
+ config["structure"]["t5_with_bias"] = "false"
+ config["structure"]["use_gated_activation"] = str(
+ hf_model.encoder.config.is_gated_act)
+ config["structure"]["position_embedding_type"] = "relative"
+ config["structure"]["model_type"] = args.model_type
+
+ def parse_t5_config_by_component(config, component, args):
+ component_config = types.SimpleNamespace()
+ component_config = copy_args_to_component_config(component_config, args)
+ component_config.n_head = config.getint(component, 'num_heads')
+ component_config.head_size = config.getint(component, 'd_kv')
+ component_config.hidden_size = config.getint(component, 'd_model')
+ component_config.ffn_hidden_size = config.getint(component, 'd_ff')
+ component_config.vocab_size = config.getint(component, 'vocab_size')
+ component_config.n_positions = config.getint(component,
+ 'n_positions',
+ fallback=512)
+ component_config.has_position_embedding = config.getboolean(
+ component, 'has_position_embedding',
+ fallback=False) # TODO: hardcoded here
+
+ component_config.has_token_type_embedding = config.getboolean(
+ component, 'has_token_type_embedding', fallback=False)
+ component_config.has_embedding_layernorm = config.getboolean(
+ component, 'has_embedding_layernorm', fallback=False)
+ component_config.has_embedding_scale = config.getboolean(
+ component, 'has_embedding_scale', fallback=False)
+ component_config.q_scaling = get_offset_q_scaling(component_config)
+ component_config.has_attention_qkvo_bias = config.getboolean(
+ component, 'has_attention_qkvo_bias',
+ fallback=False) # TODO: hardcoded here
+ component_config.has_mlp_bias = config.getboolean(component,
+ 'has_mlp_bias',
+ fallback=False)
+ component_config.has_model_final_layernorm = config.getboolean(
+ component, 'has_model_final_layernorm', fallback=True)
+ component_config.layernorm_eps = config.getfloat(
+ component, 'layer_norm_epsilon')
+ component_config.layernorm_position = layernorm_position_map[config.get(
+ component, 'layernorm_position',
+ fallback='pre_layernorm')] # TODO: hardcoded here
+ component_config.layernorm_type = layernorm_type_map[config.get(
+ component, 'layernorm_type', fallback='RmsNorm')]
+ component_config.hidden_act = config.get(component, 'dense_act_fn')
+ component_config.gated_act = config.getboolean(component,
+ 'is_gated_act')
+ component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.
+ gated_act else 'MLP']
+ component_config.num_buckets = config.getint(
+ component, 'relative_attention_num_buckets')
+ component_config.max_distance = config.getint(
+ component, 'relative_attention_max_distance')
+ component_config.position_embedding_type = config.get(
+ 'structure', 'position_embedding_type')
+ component_config.logits_dtype = config.get(component,
+ 'logits_dtype',
+ fallback='float32')
+
+ if component == 'encoder':
+ component_config.n_layer = config.getint(component, 'num_layers')
+
+ component_config.relative_attention = config.get(
+ 'structure', 'position_embedding_type') == 'relative'
+
+ elif component == 'decoder':
+ component_config.n_layer = config.getint(component,
+ 'num_decoder_layers')
+ component_config.has_lm_head_bias = config.getboolean(
+ component, # TODO: T5 with bias
+ 'has_lm_head_bias',
+ fallback=False)
+ component_config.relative_attention = config.getboolean(
+ component, 'relative_attention', fallback=True)
+ component_config.rescale_before_lm_head = config.getboolean(
+ component, 'tie_word_embeddings'
+ ) # default is True (for T5), but False for Flan-T5
+ component_config.encoder_hidden_size = config.getint(
+ 'encoder', 'd_model')
+ component_config.encoder_num_heads = config.getint(
+ 'encoder', 'num_heads')
+ component_config.encoder_head_size = config.getint(
+ 'encoder', 'd_kv')
+ component_config.decoder_start_token_id = config.getint(
+ 'decoder', 'decoder_start_token_id')
+ component_config.eos_token_id = config.getint(
+ 'decoder', 'eos_token_id')
+ bos_token_id = config.get('decoder', 'bos_token_id')
+ # T5 does not have bos_token_id
+ component_config.bos_token_id = int(
+ bos_token_id) if bos_token_id != "None" else None
+ component_config.pad_token_id = config.getint(
+ 'decoder', 'pad_token_id')
+
+ else:
+ assert False, 'Unsupported component!'
+
+ return component_config
+
+ encoder_config = parse_t5_config_by_component(config, "encoder", args)
+ decoder_config = parse_t5_config_by_component(config, "decoder", args)
+
+ return encoder_config, decoder_config
+
+
+def convert_t5_weights_to_tllm_safetensors(config, component, params):
+ weights = {}
+
+ mapping = config.mapping
+
+ convert_weight_to_dtype(params, config.dtype)
+ hidden_size = config.hidden_size
+ ffn_hidden_size = config.intermediate_size
+ num_layers = config.num_hidden_layers
+ n_head = config.num_attention_heads
+ head_size = config.head_size
+ attention_hidden_size = n_head * head_size # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5
+
+ hf_param_prefix = f'{component}'
+ trtllm_layer_name = f'{component}_layers'
+ trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
+ trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'
+ hf_component_idx = 1 if component == 'encoder' else 2
+
+ def get_attn_module_name(component, block, layer, attn_type):
+ return f'{component}.block.{int(block)}.layer.{int(layer)}.{attn_type}'
+
+ weights['embedding.vocab_embedding.weight'] = reshape(
+ params['shared.weight'].clone(), None)
+
+ layers_range = mapping.pp_layers(num_layers)
+ for layer_idx in layers_range:
+ local_layer_idx = layer_idx - layers_range[0]
+ trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
+ hf_layer_name_prefix = f'{hf_param_prefix}.block.{layer_idx}'
+
+ hidden_layer_name_split = {
+ f'{hf_layer_name_prefix}.layer.0.SelfAttention.o.weight': {
+ "name":
+ f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',
+ "shape":
+ (hidden_size, attention_hidden_size // mapping.tp_size),
+ "split_dim": -1
+ },
+ f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wo.weight':
+ {
+ "name": f'{trtllm_layer_name_prefix}.mlp.proj.weight',
+ "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
+ "split_dim": -1
+ },
+ f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi.weight':
+ {
+ "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
+ "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
+ "split_dim": 0
+ },
+ f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_0.weight':
+ {
+ "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
+ "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
+ "split_dim": 0
+ },
+ }
+
+ hidden_layer_name_no_split = {
+ f'{hf_layer_name_prefix}.layer.0.layer_norm.weight': {
+ "name":
+ f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',
+ "shape": None
+ },
+ f'{hf_layer_name_prefix}.layer.{hf_component_idx}.layer_norm.weight':
+ {
+ "name": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',
+ "shape": None
+ },
+ }
+
+ if config.gated_act:
+ hidden_layer_name_split.update({
+ f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi2.weight':
+ {
+ "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
+ "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
+ "split_dim": 0
+ },
+ f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_1.weight':
+ {
+ "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
+ "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
+ "split_dim": 0
+ },
+ })
+
+ if component == 'decoder':
+ hidden_layer_name_split.update({
+ f'{hf_layer_name_prefix}.layer.1.EncDecAttention.o.weight': {
+ "name":
+ f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',
+ "shape":
+ (hidden_size, attention_hidden_size // mapping.tp_size),
+ "split_dim": -1
+ },
+ })
+ hidden_layer_name_no_split.update({
+ f'{hf_layer_name_prefix}.layer.1.layer_norm.weight': {
+ "name":
+ f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',
+ "shape": None
+ },
+ })
+ self_attn_module_name = get_attn_module_name(
+ component, layer_idx, "1", 'EncDecAttention')
+ weights.update(
+ fuse_qkv_one_layer(
+ params, self_attn_module_name,
+ f'{trtllm_layer_name_prefix}.cross_attention',
+ mapping.tp_size, mapping.tp_rank, config.model_type,
+ (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
+ None))
+
+ self_attn_module_name = get_attn_module_name(component, layer_idx, "0",
+ 'SelfAttention')
+ weights.update(
+ fuse_qkv_one_layer(
+ params, self_attn_module_name,
+ f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
+ mapping.tp_size, mapping.tp_rank, config.model_type,
+ (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
+ None))
+
+ weights[
+ f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(
+ split(
+ params[
+ f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']
+ .T, mapping.tp_size, mapping.tp_rank, 0),
+ (n_head // mapping.tp_size, config.num_buckets))
+
+ for hf_weight_name, weight_info in hidden_layer_name_split.items():
+ if hf_weight_name in params.keys():
+ weights[weight_info["name"]] = reshape(
+ split(params[hf_weight_name],
+ mapping.tp_size,
+ mapping.tp_rank,
+ dim=weight_info["split_dim"]), weight_info["shape"])
+ for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
+ if hf_weight_name in params.keys():
+ weights[weight_info["name"]] = reshape(
+ params[hf_weight_name].clone(), shape=weight_info["shape"])
+
+ weights['final_layernorm.weight'] = reshape(
+ params[f'{component}.final_layer_norm.weight'].clone(), None)
+
+ if component == 'decoder':
+ weights['lm_head.weight'] = reshape(
+ split(params['lm_head.weight'],
+ mapping.tp_size,
+ mapping.tp_rank,
+ dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
+ if not config.use_implicit_relative_attention:
+ weights['rel_attn_table'] = reshape(
+ split(
+ params[
+ f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']
+ .T, mapping.tp_size, mapping.tp_rank, 0),
+ (n_head // mapping.tp_size, config.num_buckets))
+
+ return weights
+
+
+convert_blip2_weights_to_tllm_safetensors = convert_t5_weights_to_tllm_safetensors # func alias
+
+
+def parse_nmt_config(args, model):
+ config = configparser.ConfigParser()
+ fairseq_config = vars(model.cfg.model) # Namespace --> dict
+
+ config['encoder'] = dict()
+ for key, val in fairseq_config.items():
+ config["encoder"][key] = f"{val}"
+ config["encoder"]["q_scaling"] = '1'
+ # NMT has final layernorm for pre-norm model architecture.
+ config['encoder']['has_model_final_layernorm'] = config['encoder'][
+ 'encoder_normalize_before']
+ config['encoder']['vocab_size'] = str(len(model.src_dict)) # fairseq naming
+
+ config['decoder'] = dict()
+ for key, val in fairseq_config.items():
+ config["decoder"][key] = f"{val}"
+ config["decoder"]["q_scaling"] = '1'
+ config["decoder"]["rescale_before_lm_head"] = 'false'
+ config['decoder']['has_model_final_layernorm'] = str(
+ config['decoder'].getboolean('decoder_normalize_before', False)
+ and not config['decoder'].getboolean('no_decoder_final_norm', False))
+ config['decoder']['vocab_size'] = str(len(model.tgt_dict)) # fairseq naming
+
+ config["structure"] = dict()
+ config["structure"]["t5_with_bias"] = "true"
+ config["structure"]["use_gated_activation"] = "false"
+ config["structure"][
+ "position_embedding_type"] = "learned_absolute" # "sinusoid"
+ config["structure"]["model_type"] = args.model_type
+
+ def parse_nmt_config_by_component(config, component, args):
+ assert component in ('encoder', 'decoder'), 'Unsupported component!'
+ component_config = types.SimpleNamespace()
+ component_config = copy_args_to_component_config(component_config, args)
+ component_config.n_layer = config.getint(component,
+ f'{component}_layers')
+ component_config.n_head = config.getint(component,
+ f'{component}_attention_heads')
+ component_config.hidden_size = config.getint(
+ component, f'{component}_embed_dim') # fairseq naming
+ component_config.head_size = config.getint(
+ component,
+ 'd_kv',
+ fallback=component_config.hidden_size // component_config.n_head)
+ component_config.ffn_hidden_size = config.getint(
+ component, f'{component}_ffn_embed_dim') # fairseq naming
+ component_config.vocab_size = config.getint(component, 'vocab_size')
+ component_config.n_positions = config.getint(
+ component, 'max_source_positions') # fairseq naming
+ component_config.has_position_embedding = not config.getboolean(
+ component, 'no_token_positional_embeddings',
+ fallback=False) # fairseq naming
+ component_config.has_token_type_embedding = config.getboolean(
+ component, 'has_token_type_embedding', fallback=False)
+ component_config.has_embedding_layernorm = config.getboolean(
+ component, 'layernorm_embedding', fallback=True) # fairseq naming
+ component_config.has_embedding_scale = not config.getboolean(
+ component, 'no_scale_embedding') # fairseq naming
+ component_config.q_scaling = config.getfloat(component,
+ 'q_scaling',
+ fallback=1.0)
+ component_config.has_attention_qkvo_bias = config.getboolean(
+ 'structure', 't5_with_bias', fallback=True)
+ component_config.has_mlp_bias = config.getboolean('structure',
+ 't5_with_bias',
+ fallback=True)
+ component_config.has_model_final_layernorm = config.getboolean(
+ component, 'has_model_final_layernorm')
+ component_config.layernorm_eps = config.getfloat(
+ component, 'layer_norm_epsilon', fallback=1e-5) # fairseq naming
+
+ normalize_before = config.getboolean(
+ component, f'{component}_normalize_before') # fairseq naming
+ component_config.layernorm_position = layernorm_position_map[
+ 'pre_layernorm' if normalize_before else 'post_layernorm']
+
+ component_config.layernorm_type = layernorm_type_map[config.get(
+ component, 'layernorm_type', fallback='LayerNorm')]
+ component_config.hidden_act = config.get(
+ component, 'activation_fn') # fairseq naming
+ component_config.gated_act = config.getboolean(component,
+ 'is_gated_act',
+ fallback=False)
+ component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.
+ gated_act else 'MLP']
+ component_config.relative_attention = config.get(
+ 'structure', 'position_embedding_type') == 'relative'
+
+ component_config.num_buckets = config.getint(
+ component, 'relative_attention_num_buckets', fallback=0)
+ component_config.max_distance = config.getint(
+ component, 'relative_attention_max_distance', fallback=0)
+ component_config.position_embedding_type = config.get(
+ 'structure', 'position_embedding_type')
+ component_config.logits_dtype = config.get(component,
+ 'logits_dtype',
+ fallback='float32')
+ if component == 'decoder':
+ component_config.rescale_before_lm_head = config.getboolean(
+ component, 'rescale_before_lm_head')
+
+ component_config.encoder_hidden_size = config.getint(
+ 'encoder', 'encoder_embed_dim') # fairseq naming
+ component_config.encoder_num_heads = config.getint(
+ 'encoder', 'encoder_attention_heads')
+ component_config.encoder_head_size = config.getint(
+ 'encoder',
+ 'd_kv',
+ fallback=component_config.encoder_hidden_size //
+ component_config.encoder_num_heads)
+ component_config.decoder_start_token_id = None
+ component_config.eos_token_id = None
+ component_config.bos_token_id = None
+ component_config.pad_token_id = None
+
+ return component_config
+
+ encoder_config = parse_nmt_config_by_component(config, "encoder", args)
+ decoder_config = parse_nmt_config_by_component(config, "decoder", args)
+
+ return encoder_config, decoder_config
+
+
+def convert_nmt_weights_to_tllm_safetensors(config, component, params,
+ sin_pos_embedding):
+ weights = {}
+
+ mapping = config.mapping
+
+ hidden_size = config.hidden_size
+
+ convert_weight_to_dtype(params, config.dtype)
+ ffn_hidden_size = config.intermediate_size
+ vocab_size = config.vocab_size
+
+ hf_param_prefix = f'models.0.{component}'
+ trtllm_layer_name = f'{component}_layers'
+ trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
+ trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'
+
+ hidden_layer_name_split = {
+ 'self_attn.out_proj.weight': {
+ "name": f'{trtllm_attn_layer_name}.dense.weight',
+ "shape": (hidden_size, hidden_size // mapping.tp_size),
+ "split_dim": -1
+ },
+ 'fc1.weight': {
+ "name": 'mlp.fc.weight',
+ "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
+ "split_dim": 0
+ },
+ 'fc1.bias': {
+ "name": 'mlp.fc.bias',
+ "shape": (ffn_hidden_size // mapping.tp_size),
+ "split_dim": 0
+ },
+ 'fc2.weight': {
+ "name": 'mlp.proj.weight',
+ "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
+ "split_dim": -1
+ },
+ }
+
+ hidden_layer_name_no_split = {
+ 'self_attn.out_proj.bias': {
+ "name": f'{trtllm_attn_layer_name}.dense.bias',
+ "shape": (hidden_size)
+ },
+ 'self_attn_layer_norm.weight': {
+ "name": f'{trtllm_attn_layernorm_name}.weight',
+ "shape": None
+ },
+ 'self_attn_layer_norm.bias': {
+ "name": f'{trtllm_attn_layernorm_name}.bias',
+ "shape": None
+ },
+ 'fc2.bias': {
+ "name": 'mlp.proj.bias',
+ "shape": (hidden_size)
+ },
+ 'final_layer_norm.weight': {
+ "name": 'mlp_layernorm.weight',
+ "shape": None
+ },
+ 'final_layer_norm.bias': {
+ "name": 'mlp_layernorm.bias',
+ "shape": None
+ },
+ }
+
+ if component == "decoder":
+ hidden_layer_name_split.update({
+ 'encoder_attn.out_proj.weight': {
+ "name": 'cross_attention.dense.weight',
+ "shape": (hidden_size, hidden_size // mapping.tp_size),
+ "split_dim": -1
+ },
+ })
+ hidden_layer_name_no_split.update({
+ 'encoder_attn.out_proj.bias': {
+ "name": 'cross_attention.dense.bias',
+ "shape": (hidden_size)
+ },
+ 'encoder_attn_layer_norm.weight': {
+ "name": 'cross_attention_layernorm.weight',
+ "shape": None,
+ },
+ 'encoder_attn_layer_norm.bias': {
+ "name": 'cross_attention_layernorm.bias',
+ "shape": None
+ },
+ })
+
+ def get_attn_module_name(component, layer, attn_type):
+ return f'models.0.{component}.layers.{int(layer)}.{attn_type}'
+
+ weights["embedding.vocab_embedding.weight"] = reshape(
+ params[f'{hf_param_prefix}.embed_tokens.weight'].clone(),
+ (vocab_size, -1))
+ weights["embedding.position_embedding.weight"] = reshape(
+ sin_pos_embedding, (config.max_position_embeddings, hidden_size))
+
+ num_layers = config.num_hidden_layers
+
+ layers_range = mapping.pp_layers(num_layers)
+ for layer_idx in layers_range:
+ local_layer_idx = layer_idx - layers_range[0]
+ hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}'
+ trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
+
+ for hf_weight_name, weight_info in hidden_layer_name_split.items():
+ weights[
+ f'{trtllm_layer_name_prefix}.{weight_info["name"]}'] = reshape(
+ split(params[f'{hf_layer_name_prefix}.{hf_weight_name}'],
+ mapping.tp_size,
+ mapping.tp_rank,
+ dim=weight_info["split_dim"]), weight_info["shape"])
+
+ for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
+ trtllm_layer_fullname = f'{trtllm_layer_name_prefix}.{weight_info["name"]}'
+ hf_layer_fullname = f'{hf_layer_name_prefix}.{hf_weight_name}'
+ weights[trtllm_layer_fullname] = reshape(
+ params[hf_layer_fullname].clone(), shape=weight_info["shape"])
+
+ self_attn_module_name = get_attn_module_name(component, layer_idx,
+ 'self_attn')
+ weights.update(
+ fuse_qkv_one_layer(
+ params, self_attn_module_name,
+ f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
+ mapping.tp_size, mapping.tp_rank, config.model_type,
+ (hidden_size * 3 // mapping.tp_size, hidden_size),
+ (hidden_size * 3 // mapping.tp_size)))
+ if component == 'decoder':
+ cross_attn_module_name = get_attn_module_name(
+ component, layer_idx, 'encoder_attn')
+ weights.update(
+ fuse_qkv_one_layer(
+ params, cross_attn_module_name,
+ f'{trtllm_layer_name_prefix}.cross_attention',
+ mapping.tp_size, mapping.tp_rank, config.model_type,
+ (hidden_size * 3 // mapping.tp_size, hidden_size),
+ (hidden_size * 3 // mapping.tp_size)))
+
+ if component == 'decoder':
+ weights['lm_head.weight'] = reshape(
+ split(params[f'{hf_param_prefix}.output_projection.weight'],
+ mapping.tp_size,
+ mapping.tp_rank,
+ dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
+
+ if config.has_model_final_layernorm:
+ weights['final_layernorm.weight'] = params[
+ f'{hf_param_prefix}.layer_norm.weight'].clone()
+ weights['final_layernorm.bias'] = params[
+ f'{hf_param_prefix}.layer_norm.bias'].clone()
+
+ return weights
+
+
+def parse_bart_config(args, hf_model):
+
+ config = configparser.ConfigParser()
+
+ config['decoder'] = dict()
+ for key, val in hf_model.model.decoder.config.to_dict().items():
+ config["decoder"][key] = f"{val}"
+ config["decoder"]["q_scaling"] = '1'
+ config["decoder"]["rescale_before_lm_head"] = str(False)
+ config['decoder']['has_model_final_layernorm'] = str(
+ args.nougat or isinstance(hf_model, MBartForConditionalGeneration))
+
+ if args.nougat:
+ # These flags are true for mbart decoders, but missing in HF config
+ config['decoder']['normalize_before'] = str(True)
+ config['decoder']['normalize_embeddings'] = str(True)
+
+ config['encoder'] = dict()
+ # Init few encoder configs, needed by build, from decoder config
+ encoder_config_keys = [
+ "encoder_ffn_dim", "encoder_layers", "encoder_attention_heads",
+ "encoder_layerdrop", "d_model"
+ ]
+ for key in encoder_config_keys:
+ config['encoder'][key] = config['decoder'][key]
+ else:
+ config['encoder'] = dict()
+ for key, val in hf_model.model.encoder.config.to_dict().items():
+ config["encoder"][key] = f"{val}"
+ config["encoder"]["q_scaling"] = '1'
+
+ # mBART has final layernorm, BART does not
+ config['encoder']['has_model_final_layernorm'] = str(
+ isinstance(hf_model, MBartForConditionalGeneration))
+
+ config["structure"] = dict()
+ config["structure"]["t5_with_bias"] = "true"
+ config["structure"]["use_gated_activation"] = "false"
+ config["structure"]["position_embedding_type"] = "learned_absolute"
+ config["structure"]["model_type"] = args.model_type
+
+ def parse_bart_config_by_component(config, component, args):
+ assert component in ('encoder', 'decoder'), 'Unsupported component!'
+ component_config = types.SimpleNamespace()
+ component_config = copy_args_to_component_config(component_config, args)
+ component_config.n_layer = config.getint(component,
+ f'{component}_layers')
+ component_config.n_head = config.getint(component,
+ f'{component}_attention_heads')
+ component_config.hidden_size = config.getint(component, 'd_model')
+ component_config.head_size = config.getint(
+ component,
+ 'd_kv',
+ fallback=component_config.hidden_size // component_config.n_head)
+ component_config.ffn_hidden_size = config.getint(
+ component, f'{component}_ffn_dim')
+ component_config.vocab_size = config.getint(component, 'vocab_size')
+ component_config.n_positions = config.getint(component,
+ 'max_position_embeddings')
+ component_config.has_position_embedding = config.getboolean(
+ component, 'has_position_embedding',
+ fallback=True) # TODO: hardcoded here
+ component_config.has_token_type_embedding = config.getboolean(
+ component, 'has_token_type_embedding', fallback=False)
+ component_config.has_embedding_layernorm = config.getboolean(
+ component, 'has_embedding_layernorm', fallback=True)
+ component_config.has_embedding_scale = config.getboolean(
+ component, 'scale_embedding')
+ component_config.q_scaling = config.getfloat(component,
+ 'q_scaling',
+ fallback=1.0)
+ component_config.has_attention_qkvo_bias = config.getboolean(
+ 'structure', 't5_with_bias', fallback=True)
+ component_config.has_mlp_bias = config.getboolean('structure',
+ 't5_with_bias',
+ fallback=True)
+ component_config.has_model_final_layernorm = config.getboolean(
+ component, 'has_model_final_layernorm')
+ component_config.layernorm_eps = config.getfloat(component,
+ 'layer_norm_epsilon',
+ fallback=False)
+
+ normalize_before = config.getboolean(component, 'normalize_before')
+ component_config.layernorm_position = layernorm_position_map[
+ 'pre_layernorm' if normalize_before else 'post_layernorm']
+
+ component_config.layernorm_type = layernorm_type_map[config.get(
+ component, 'layernorm_type', fallback='LayerNorm')]
+ component_config.hidden_act = config.get(component,
+ 'activation_function')
+ component_config.gated_act = config.getboolean(component,
+ 'is_gated_act',
+ fallback=False)
+ component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.
+ gated_act else 'MLP']
+ component_config.relative_attention = config.get(
+ 'structure', 'position_embedding_type') == 'relative'
+
+ component_config.num_buckets = config.getint(
+ component, 'relative_attention_num_buckets', fallback=0)
+ component_config.max_distance = config.getint(
+ component, 'relative_attention_max_distance', fallback=0)
+ component_config.max_lora_rank = config.getint(component,
+ 'max_lora_rank',
+ fallback=0)
+ component_config.lora_target_modules = literal_eval(
+ config.get(component, 'lora_target_modules', fallback="[]"))
+ component_config.hf_modules_to_trtllm_modules = literal_eval(
+ config.get(component, 'hf_modules_to_trtllm_modules',
+ fallback="{}"))
+ component_config.trtllm_modules_to_hf_modules = literal_eval(
+ config.get(component, 'trtllm_modules_to_hf_modules',
+ fallback="{}"))
+ component_config.logits_dtype = config.get(component,
+ 'logits_dtype',
+ fallback='float32')
+ component_config.position_embedding_type = config.get(
+ 'structure', 'position_embedding_type')
+
+ if component == 'decoder':
+ component_config.rescale_before_lm_head = config.getboolean(
+ component, 'rescale_before_lm_head')
+
+ component_config.encoder_hidden_size = config.getint(
+ 'encoder', 'd_model')
+ component_config.encoder_num_heads = config.getint(
+ 'encoder', 'encoder_attention_heads')
+ component_config.encoder_head_size = config.getint(
+ 'encoder',
+ 'd_kv',
+ fallback=component_config.encoder_hidden_size //
+ component_config.encoder_num_heads)
+
+ # nougat has decoder_start_token_id = None, special handling
+ decoder_start_token_id = config.get('decoder',
+ 'decoder_start_token_id')
+ component_config.decoder_start_token_id = int(
+ decoder_start_token_id
+ ) if decoder_start_token_id != "None" else None
+ component_config.eos_token_id = config.getint(
+ 'decoder', 'eos_token_id')
+ component_config.bos_token_id = config.getint(
+ 'decoder', 'bos_token_id')
+ component_config.pad_token_id = config.getint(
+ 'decoder', 'pad_token_id')
+
+ return component_config
+
+ encoder_config = None
+ if not args.nougat:
+ encoder_config = parse_bart_config_by_component(config, "encoder", args)
+ decoder_config = parse_bart_config_by_component(config, "decoder", args)
+
+ return encoder_config, decoder_config
+
+
+def convert_bart_weights_to_tllm_safetensors(config, component, params):
+ weights = {}
+
+ mapping = config.mapping
+
+ hidden_size = config.hidden_size
+
+ convert_weight_to_dtype(params, config.dtype)
+ ffn_hidden_size = config.intermediate_size
+ vocab_size = config.vocab_size
+
+ hf_param_prefix = f'model.{component}'
+ trtllm_layer_name = f'{component}_layers'
+ trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
+ trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'
+ embedding_layer_names = {
+ 'embed_tokens.weight': {
+ "name": 'embedding.vocab_embedding.weight',
+ "shape": (vocab_size, -1)
+ },
+ 'embed_positions.weight': {
+ "name": 'embedding.position_embedding.weight',
+ "shape": (config.max_position_embeddings, hidden_size)
+ },
+ 'layernorm_embedding.weight': {
+ "name": 'embedding.embedding_layernorm.weight',
+ "shape": None
+ },
+ 'layernorm_embedding.bias': {
+ "name": 'embedding.embedding_layernorm.bias',
+ "shape": None
+ },
+ }
+
+ hidden_layer_name_split = {
+ 'self_attn.out_proj.weight': {
+ "name": f'{trtllm_attn_layer_name}.dense.weight',
+ "shape": (hidden_size, hidden_size // mapping.tp_size),
+ "split_dim": -1
+ },
+ 'fc1.weight': {
+ "name": 'mlp.fc.weight',
+ "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
+ "split_dim": 0
+ },
+ 'fc1.bias': {
+ "name": 'mlp.fc.bias',
+ "shape": (ffn_hidden_size // mapping.tp_size),
+ "split_dim": 0
+ },
+ 'fc2.weight': {
+ "name": 'mlp.proj.weight',
+ "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
+ "split_dim": -1
+ },
+ }
+
+ hidden_layer_name_no_split = {
+ 'self_attn.out_proj.bias': {
+ "name": f'{trtllm_attn_layer_name}.dense.bias',
+ "shape": (hidden_size)
+ },
+ 'self_attn_layer_norm.weight': {
+ "name": f'{trtllm_attn_layernorm_name}.weight',
+ "shape": None
+ },
+ 'self_attn_layer_norm.bias': {
+ "name": f'{trtllm_attn_layernorm_name}.bias',
+ "shape": None
+ },
+ 'fc2.bias': {
+ "name": 'mlp.proj.bias',
+ "shape": (hidden_size)
+ },
+ 'final_layer_norm.weight': {
+ "name": 'mlp_layernorm.weight',
+ "shape": None
+ },
+ 'final_layer_norm.bias': {
+ "name": 'mlp_layernorm.bias',
+ "shape": None
+ },
+ }
+
+ if config.model_type == 'mbart':
+ hidden_layer_name_split['layer_norm.weight'] = {
+ "name": 'final_layernorm.weight',
+ "shape": None,
+ "split_dim": 0
+ }
+ hidden_layer_name_no_split['layer_norm.bias'] = {
+ "name": 'final_layernorm.bias',
+ "shape": None,
+ "split_dim": 0
+ }
+
+ if component == "decoder":
+ hidden_layer_name_split.update({
+ 'encoder_attn.out_proj.weight': {
+ "name": 'cross_attention.dense.weight',
+ "shape": (hidden_size, hidden_size // mapping.tp_size),
+ "split_dim": -1
+ }
+ })
+ hidden_layer_name_no_split.update({
+ 'encoder_attn.out_proj.bias': {
+ "name": 'cross_attention.dense.bias',
+ "shape": (hidden_size)
+ },
+ 'encoder_attn_layer_norm.weight': {
+ "name": 'cross_attention_layernorm.weight',
+ "shape": None
+ },
+ 'encoder_attn_layer_norm.bias': {
+ "name": 'cross_attention_layernorm.bias',
+ "shape": None
+ },
+ })
+
+ def get_attn_module_name(component, layer, attn_type):
+ return f'model.{component}.layers.{int(layer)}.{attn_type}'
+
+ for hf_weight_name, weight_info in embedding_layer_names.items():
+ if 'position' in hf_weight_name:
+ weights[weight_info["name"]] = params[
+ f'{hf_param_prefix}.{hf_weight_name}'][2:].clone()
+ else:
+ weights[weight_info["name"]] = params[
+ f'{hf_param_prefix}.{hf_weight_name}'].clone()
+ weights[weight_info["name"]] = reshape(weights[weight_info["name"]],
+ weight_info["shape"])
+
+ num_layers = config.num_hidden_layers
+
+ layers_range = mapping.pp_layers(num_layers)
+ for layer_idx in layers_range:
+ local_layer_idx = layer_idx - layers_range[0]
+ hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}'
+ trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
+
+ for hf_weight_name, weight_info in hidden_layer_name_split.items():
+ weights[
+ f'{trtllm_layer_name_prefix}.{weight_info["name"]}'] = reshape(
+ split(params[f'{hf_layer_name_prefix}.{hf_weight_name}'],
+ mapping.tp_size,
+ mapping.tp_rank,
+ dim=weight_info["split_dim"]), weight_info["shape"])
+
+ for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
+ trtllm_layer_fullname = f'{trtllm_layer_name_prefix}.{weight_info["name"]}'
+ hf_layer_fullname = f'{hf_layer_name_prefix}.{hf_weight_name}'
+ weights[trtllm_layer_fullname] = reshape(
+ params[hf_layer_fullname].clone(), shape=weight_info["shape"])
+
+ self_attn_module_name = get_attn_module_name(component, layer_idx,
+ 'self_attn')
+ weights.update(
+ fuse_qkv_one_layer(
+ params, self_attn_module_name,
+ f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
+ mapping.tp_size, mapping.tp_rank, config.model_type,
+ (hidden_size * 3 // mapping.tp_size, hidden_size),
+ (hidden_size * 3 // mapping.tp_size)))
+ if component == 'decoder':
+ cross_attn_module_name = get_attn_module_name(
+ component, layer_idx, 'encoder_attn')
+ weights.update(
+ fuse_qkv_one_layer(
+ params, cross_attn_module_name,
+ f'{trtllm_layer_name_prefix}.cross_attention',
+ mapping.tp_size, mapping.tp_rank, config.model_type,
+ (hidden_size * 3 // mapping.tp_size, hidden_size),
+ (hidden_size * 3 // mapping.tp_size)))
+
+ if component == 'decoder':
+ weights['lm_head.weight'] = reshape(
+ split(params['lm_head.weight'],
+ mapping.tp_size,
+ mapping.tp_rank,
+ dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
+
+ if config.has_model_final_layernorm:
+ weights['final_layernorm.weight'] = params[
+ f'{hf_param_prefix}.layer_norm.weight'].clone()
+ weights['final_layernorm.bias'] = params[
+ f'{hf_param_prefix}.layer_norm.bias'].clone()
+
+ return weights
+
+
+def parse_pix2struct_config(args, hf_model):
+ # manually set q_scaling to offset attention scaling's effect.
+ # TODO: modify kernels to control whether to disable attention scaling
+ config = configparser.ConfigParser()
+
+ def get_offset_q_scaling(config) -> str:
+ d_model = config.hidden_size
+ num_heads = config.num_heads
+ head_size = d_model / num_heads
+ scaling = 1 / head_size**.5
+ return str(scaling)
+
+ config["decoder"] = {}
+ for key, val in hf_model.decoder.config.to_dict().items():
+ config["decoder"][key] = f"{val}"
+
+ config["decoder"]["q_scaling"] = get_offset_q_scaling(
+ hf_model.decoder.config)
+
+ config["structure"] = dict()
+ config["structure"]["pix2struct_with_bias"] = "false"
+ config["structure"]["use_gated_activation"] = "false"
+ config["structure"]["position_embedding_type"] = "relative"
+ config["structure"]["model_type"] = args.model_type
+
+ def parse_pix2struct_config_by_component(config, component, args):
+ if component == 'decoder':
+ args.n_layer = config.getint(component, 'num_layers')
+ args.n_head = config.getint(component, 'num_heads')
+ args.head_size = config.getint(component, 'd_kv')
+ args.hidden_size = config.getint(component, 'hidden_size')
+ args.ffn_hidden_size = config.getint(component, 'd_ff')
+ args.vocab_size = config.getint(component, 'vocab_size')
+ args.n_positions = config.getint(component,
+ 'n_positions',
+ fallback=512)
+ args.has_position_embedding = config.getboolean(
+ component, 'has_position_embedding',
+ fallback=False) # TODO: hardcoded here
+ args.has_token_type_embedding = config.getboolean(
+ component, 'has_token_type_embedding', fallback=False)
+ args.has_embedding_layernorm = config.getboolean(
+ component, 'has_embedding_layernorm', fallback=False)
+ args.has_embedding_scale = config.getboolean(component,
+ 'has_embedding_scale',
+ fallback=False)
+ args.q_scaling = config.getfloat(component,
+ 'q_scaling',
+ fallback=1.0)
+ args.has_attention_qkvo_bias = config.getboolean(
+ component, 'has_attention_qkvo_bias', fallback=False)
+ args.has_mlp_bias = config.getboolean(component,
+ 'has_mlp_bias',
+ fallback=False)
+ args.has_model_final_layernorm = config.getboolean(
+ component, 'has_model_final_layernorm', fallback=True)
+ args.layernorm_eps = config.getfloat(component,
+ 'layer_norm_epsilon')
+ args.layernorm_position = layernorm_position_map[config.get(
+ component, 'layernorm_position',
+ fallback='pre_layernorm')] # TODO: hardcoded here
+ args.layernorm_type = layernorm_type_map[config.get(
+ component, 'layernorm_type', fallback='RmsNorm')]
+ args.hidden_act = config.get(component, 'dense_act_fn')
+ args.gated_act = True
+ args.mlp_type = mlp_type_map['GatedMLP' if args.
+ gated_act else 'MLP']
+ args.has_lm_head_bias = config.getboolean(
+ component, # TODO: T5 with bias
+ 'has_lm_head_bias',
+ fallback=False)
+ args.relative_attention = config.getboolean(component,
+ 'relative_attention',
+ fallback=True)
+ args.num_buckets = config.getint(component,
+ 'relative_attention_num_buckets')
+ args.max_distance = config.getint(
+ component, 'relative_attention_max_distance')
+ args.logits_dtype = config.get(component,
+ 'logits_dtype',
+ fallback='float32')
+ args.rescale_before_lm_head = config.getboolean(
+ component, 'tie_word_embeddings'
+ ) # default is True (for T5), but False for Flan-T5
+ args.encoder_hidden_size = config.getint('decoder', 'hidden_size')
+ args.encoder_num_heads = config.getint('decoder', 'num_heads')
+ args.encoder_head_size = config.getint('decoder', 'd_kv')
+ args.position_embedding_type = config.get(
+ 'structure', 'position_embedding_type')
+ args.decoder_start_token_id = config.getint(
+ 'decoder', 'decoder_start_token_id')
+ args.eos_token_id = config.getint('decoder', 'eos_token_id')
+ bos_token_id = config.get('decoder', 'bos_token_id')
+ # pix2struct does not have bos_token_id
+ args.bos_token_id = int(
+ bos_token_id) if bos_token_id != "None" else None
+ args.pad_token_id = config.getint('decoder', 'pad_token_id')
+
+ else:
+ assert False, 'Unsupported component!'
+ return args
+
+ decoder_args = parse_pix2struct_config_by_component(config, "decoder", args)
+ return None, decoder_args
+
+
+def convert_pix2struct_weights_to_tllm_safetensors(config, component, params):
+ weights = {}
+
+ mapping = config.mapping
+
+ convert_weight_to_dtype(params, config.dtype)
+ hidden_size = config.hidden_size
+ ffn_hidden_size = config.intermediate_size
+ num_layers = config.num_hidden_layers
+ n_head = config.num_attention_heads
+ head_size = config.head_size
+ attention_hidden_size = n_head * head_size # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5
+
+ hf_param_prefix = f'{component}'
+ trtllm_layer_name = f'{component}_layers'
+ trtllm_attn_layer_name = 'self_attention'
+ trtllm_attn_layernorm_name = 'self_attention_layernorm'
+
+ def get_attn_module_name(component, layer, attn_type):
+ return f'{component}.layer.{int(layer)}.{attn_type}.attention'
+
+ weights['embedding.vocab_embedding.weight'] = reshape(
+ params[f'{hf_param_prefix}.embed_tokens.weight'].clone(), None)
+
+ layers_range = mapping.pp_layers(num_layers)
+ for layer_idx in layers_range:
+ local_layer_idx = layer_idx - layers_range[0]
+ trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
+ hf_layer_name_prefix = f'{hf_param_prefix}.layer.{layer_idx}'
+
+ hidden_layer_name_split = {
+ f'{hf_layer_name_prefix}.self_attention.attention.output.weight': {
+ "name":
+ f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',
+ "shape":
+ (hidden_size, attention_hidden_size // mapping.tp_size),
+ "split_dim": -1
+ },
+ f'{hf_layer_name_prefix}.mlp.DenseReluDense.wo.weight': {
+ "name": f'{trtllm_layer_name_prefix}.mlp.proj.weight',
+ "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
+ "split_dim": -1
+ },
+ f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_0.weight': {
+ "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
+ "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
+ "split_dim": 0
+ },
+ }
+
+ hidden_layer_name_no_split = {
+ f'{hf_layer_name_prefix}.self_attention.layer_norm.weight': {
+ "name":
+ f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',
+ "shape": None
+ },
+ f'{hf_layer_name_prefix}.mlp.layer_norm.weight': {
+ "name": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',
+ "shape": None
+ },
+ }
+
+ if config.gated_act:
+ hidden_layer_name_split.update({
+ f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_1.weight': {
+ "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
+ "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
+ "split_dim": 0
+ },
+ })
+
+ hidden_layer_name_split.update({
+ f'{hf_layer_name_prefix}.encoder_decoder_attention.attention.output.weight':
+ {
+ "name":
+ f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',
+ "shape":
+ (hidden_size, attention_hidden_size // mapping.tp_size),
+ "split_dim": -1
+ },
+ })
+ hidden_layer_name_no_split.update({
+ f'{hf_layer_name_prefix}.encoder_decoder_attention.layer_norm.weight':
+ {
+ "name":
+ f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',
+ "shape": None
+ },
+ })
+ self_attn_module_name = get_attn_module_name(
+ component, layer_idx, 'encoder_decoder_attention')
+ weights.update(
+ fuse_qkv_one_layer(
+ params, self_attn_module_name,
+ f'{trtllm_layer_name_prefix}.cross_attention', mapping.tp_size,
+ mapping.tp_rank, config.model_type,
+ (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
+ None))
+
+ self_attn_module_name = get_attn_module_name(component, layer_idx,
+ 'self_attention')
+ weights.update(
+ fuse_qkv_one_layer(
+ params, self_attn_module_name,
+ f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
+ mapping.tp_size, mapping.tp_rank, config.model_type,
+ (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
+ None))
+
+ weights[
+ f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(
+ split(
+ params[
+ f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']
+ .T, mapping.tp_size, mapping.tp_rank, 0),
+ (n_head // mapping.tp_size, config.num_buckets))
+
+ for hf_weight_name, weight_info in hidden_layer_name_split.items():
+ if hf_weight_name in params.keys():
+ weights[weight_info["name"]] = reshape(
+ split(params[hf_weight_name],
+ mapping.tp_size,
+ mapping.tp_rank,
+ dim=weight_info["split_dim"]), weight_info["shape"])
+ for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
+ if hf_weight_name in params.keys():
+ weights[weight_info["name"]] = reshape(
+ params[hf_weight_name].clone(), shape=weight_info["shape"])
+
+    weights['final_layernorm.weight'] = reshape(
+ params[f'{component}.final_layer_norm.weight'].clone(), None)
+
+ weights['lm_head.weight'] = reshape(
+ split(params[f'{component}.lm_head.weight'],
+ mapping.tp_size,
+ mapping.tp_rank,
+ dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
+ if not config.use_implicit_relative_attention:
+        weights['rel_attn_table'] = reshape(
+ split(
+ params[
+ f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']
+ .T, mapping.tp_size, mapping.tp_rank, 0),
+ (n_head // mapping.tp_size, config.num_buckets))
+
+ return weights
+
+
+def get_model(args):
+ if args.model_type == "t5":
+ model = T5ForConditionalGeneration.from_pretrained(args.model_dir)
+ elif args.model_type == "nmt":
+ from fairseq.models.transformer import TransformerModel
+ model = TransformerModel.from_pretrained(args.model_dir)
+ elif args.model_type == "bart":
+ if args.nougat:
+ model = VisionEncoderDecoderModel.from_pretrained(args.model_dir)
+ model = model.get_decoder()
+ else:
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.model_dir)
+ elif args.model_type == "pix2struct":
+ model = Pix2StructForConditionalGeneration.from_pretrained(
+ args.model_dir)
+ elif args.model_type == "blip2":
+ model = Blip2ForConditionalGeneration.from_pretrained(
+ args.model_dir).language_model
+ return model
+
+
+def convert_checkpoint(args):
+
+ model = get_model(args)
+
+ saved_dir = Path(args.output_dir)
+ saved_dir.mkdir(parents=True, exist_ok=True)
+
+ encoder_saved_dir = saved_dir / "encoder"
+ encoder_saved_dir.mkdir(parents=True, exist_ok=True)
+ decoder_saved_dir = saved_dir / "decoder"
+ decoder_saved_dir.mkdir(parents=True, exist_ok=True)
+
+ world_size = args.tp_size * args.pp_size
+
+ kv_cache_quant_algo = None
+ quant_algo = None
+
+ model_type = args.model_type if args.model_type != "blip2" else "t5"
+ encoder_config, decoder_config = globals()[f'parse_{model_type}_config'](
+ args, model)
+
+ additional_settings = ["gated_act"]
+
+ if not args.nougat and args.model_type != "pix2struct":
+ tllm_encoder_config = {
+ 'architecture': "EncoderModel",
+ 'dtype': args.dtype,
+ 'logits_dtype': encoder_config.logits_dtype,
+ 'num_hidden_layers': encoder_config.n_layer,
+ 'num_attention_heads': encoder_config.n_head,
+ 'hidden_size': encoder_config.hidden_size,
+ 'norm_epsilon': encoder_config.layernorm_eps,
+ 'vocab_size': encoder_config.vocab_size,
+ 'position_embedding_type': encoder_config.position_embedding_type,
+ 'hidden_act': encoder_config.hidden_act,
+ 'quantization': {
+ 'quant_algo': quant_algo,
+ 'kv_cache_quant_algo': kv_cache_quant_algo,
+ },
+ 'mapping': {
+ 'world_size': world_size,
+ 'tp_size': args.tp_size,
+ 'pp_size': args.pp_size,
+ },
+ 'use_parallel_embedding': args.use_parallel_embedding,
+ 'embedding_sharding_dim': args.embedding_sharding_dim,
+ 'max_position_embeddings': encoder_config.n_positions,
+ 'num_key_value_heads': encoder_config.n_head,
+ 'head_size': encoder_config.head_size,
+ 'has_position_embedding': encoder_config.has_position_embedding,
+ 'layernorm_type': encoder_config.layernorm_type,
+ 'has_attention_qkvo_bias': encoder_config.has_attention_qkvo_bias,
+ 'has_mlp_bias': encoder_config.has_mlp_bias,
+ 'has_model_final_layernorm':
+ encoder_config.has_model_final_layernorm,
+ 'has_embedding_layernorm': encoder_config.has_embedding_layernorm,
+ 'has_embedding_scale': encoder_config.has_embedding_scale,
+ 'intermediate_size': encoder_config.ffn_hidden_size,
+ 'q_scaling': encoder_config.q_scaling,
+ 'layernorm_position': encoder_config.layernorm_position,
+ 'mlp_type': encoder_config.mlp_type,
+ 'relative_attention': encoder_config.relative_attention,
+ 'max_distance': encoder_config.max_distance,
+ 'num_buckets': encoder_config.num_buckets,
+ 'model_type': encoder_config.model_type,
+ }
+
+ for additional_setting in additional_settings:
+ if hasattr(encoder_config, additional_setting):
+ tllm_encoder_config.update({
+ additional_setting:
+ getattr(encoder_config, additional_setting)
+ })
+
+ with (encoder_saved_dir / "config.json").open('w') as f:
+ json.dump(tllm_encoder_config, f, indent=4)
+
+ encoder_convert_args = dict(params=model.state_dict(),
+ component="encoder")
+ tllm_decoder_config = {
+ 'architecture': "DecoderModel",
+ 'dtype': args.dtype,
+ 'logits_dtype': decoder_config.logits_dtype,
+ 'num_hidden_layers': decoder_config.n_layer,
+ 'num_attention_heads': decoder_config.n_head,
+ 'hidden_size': decoder_config.hidden_size,
+ 'norm_epsilon': decoder_config.layernorm_eps,
+ 'vocab_size': decoder_config.vocab_size,
+ 'position_embedding_type': decoder_config.position_embedding_type,
+ 'hidden_act': decoder_config.hidden_act,
+ 'quantization': {
+ 'quant_algo': quant_algo,
+ 'kv_cache_quant_algo': kv_cache_quant_algo,
+ },
+ 'mapping': {
+ 'world_size': world_size,
+ 'tp_size': args.tp_size,
+ 'pp_size': args.pp_size,
+ },
+ 'use_parallel_embedding': args.use_parallel_embedding,
+ 'embedding_sharding_dim': args.embedding_sharding_dim,
+ 'max_position_embeddings': decoder_config.n_positions,
+ 'head_size': decoder_config.head_size,
+ 'has_position_embedding': decoder_config.has_position_embedding,
+ 'layernorm_type': decoder_config.layernorm_type,
+ 'has_attention_qkvo_bias': decoder_config.has_attention_qkvo_bias,
+ 'has_mlp_bias': decoder_config.has_mlp_bias,
+ 'has_model_final_layernorm': decoder_config.has_model_final_layernorm,
+ 'has_embedding_layernorm': decoder_config.has_embedding_layernorm,
+ 'has_embedding_scale': decoder_config.has_embedding_scale,
+ 'intermediate_size': decoder_config.ffn_hidden_size,
+ 'q_scaling': decoder_config.q_scaling,
+ 'layernorm_position': decoder_config.layernorm_position,
+ 'mlp_type': decoder_config.mlp_type,
+ 'relative_attention': decoder_config.relative_attention,
+ 'max_distance': decoder_config.max_distance,
+ 'num_buckets': decoder_config.num_buckets,
+ 'model_type': decoder_config.model_type,
+ 'rescale_before_lm_head': decoder_config.rescale_before_lm_head,
+ 'encoder_hidden_size': decoder_config.encoder_hidden_size,
+ 'encoder_num_heads': decoder_config.encoder_num_heads,
+ 'encoder_head_size': decoder_config.encoder_head_size,
+ 'skip_cross_kv': args.skip_cross_kv,
+ 'use_implicit_relative_attention': args.use_implicit_relative_attention,
+ 'decoder_start_token_id': decoder_config.decoder_start_token_id,
+ 'eos_token_id': decoder_config.eos_token_id,
+ 'bos_token_id': decoder_config.bos_token_id,
+ 'pad_token_id': decoder_config.pad_token_id,
+ }
+ for additional_setting in additional_settings:
+ if hasattr(decoder_config, additional_setting):
+ tllm_decoder_config.update({
+ additional_setting:
+ getattr(decoder_config, additional_setting)
+ })
+
+ with (decoder_saved_dir / "config.json").open('w') as f:
+ json.dump(tllm_decoder_config, f, indent=4)
+
+ decoder_convert_args = dict(params=model.state_dict(), component="decoder")
+
+ if args.model_type == "nmt":
+ fairseq_config = vars(model.cfg.model) # Namespace --> dict
+ num_embeddings = fairseq_config['max_source_positions']
+ embedding_dim = fairseq_config['encoder_embed_dim']
+ padding_idx = model.models[0].encoder.embed_tokens.padding_idx # 1
+
+ sin_pos_embedding = model.models[
+ 0].encoder.embed_positions.get_embedding(
+ padding_idx + 1 + num_embeddings,
+ embedding_dim,
+ padding_idx=padding_idx) # [2 + num_embeddings, embed_dim]
+ sin_pos_embedding = sin_pos_embedding[2:, :] # remove offset embeddings
+
+ encoder_convert_args["sin_pos_embedding"] = sin_pos_embedding
+ decoder_convert_args["sin_pos_embedding"] = sin_pos_embedding
+
+ if args.workers == 1:
+ if not args.nougat and args.model_type != "pix2struct":
+ convert(0, world_size, args, tllm_encoder_config,
+ encoder_convert_args, encoder_saved_dir)
+ convert(0, world_size, args, tllm_decoder_config, decoder_convert_args,
+ decoder_saved_dir)
+ else:
+ if args.workers > world_size:
+ args.workers = world_size
+ LOGGER.info(f'Converting checkpoint using {args.workers} workers.')
+ import torch.multiprocessing as mp
+ if not args.nougat and args.model_type != "pix2struct":
+ mp.spawn(convert,
+ nprocs=args.workers,
+ args=(world_size, args, tllm_encoder_config,
+ encoder_convert_args, encoder_saved_dir))
+ mp.spawn(convert,
+ nprocs=args.workers,
+ args=(world_size, args, tllm_decoder_config,
+ decoder_convert_args, decoder_saved_dir))
+
+
+def convert(worker_rank, world_size, args, model_config, convert_args,
+ saved_dir):
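+ # each worker converts every args.workers-th rank and saves it as rank{N}.safetensors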
+ for rank in range(worker_rank, world_size, args.workers):
+ rank_config = copy.deepcopy(PretrainedConfig.from_dict(model_config))
+ rank_config.set_rank(rank)
+ weights = globals(
+ )[f'convert_{rank_config.model_type}_weights_to_tllm_safetensors'](
+ config=rank_config, **convert_args)
+ safetensors.torch.save_file(weights,
+ f'{saved_dir}/rank{rank}.safetensors')
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.RawTextHelpFormatter)
+ parser.add_argument(
+ '--model_type',
+ type=str,
+ default='t5',
+ choices=['t5', 'nmt', 'bart', 'pix2struct', 'blip2'],
+ help=
+ 'Multimodal type when this script is used for multimodal conversion.')
+
+ parser.add_argument('--tp_size',
+ type=int,
+ default=1,
+ help='N-way tensor parallelism size')
+ parser.add_argument('--pp_size',
+ type=int,
+ default=1,
+ help='N-way pipeline parallelism size')
+ parser.add_argument("--model_dir",
+ "-i",
+ type=str,
+ help="Path to the framework checkpoint file",
+ required=True)
+ parser.add_argument("--output_dir",
+ "-o",
+ type=str,
+ help="Path to the converted TRT-LLM model weight file",
+ required=True)
+ parser.add_argument(
+ "--workers",
+ type=int,
+ help="How many workers to spawn for conversion (default: 4)",
+ default=4)
+ parser.add_argument("--nougat",
+ action="store_true",
+ help="Model which uses vision encoder + mbart decoder")
+ parser.add_argument("--verbose",
+ action="store_true",
+ help="Provide verbose messages")
+ parser.add_argument(
+ '--use_parallel_embedding',
+ action="store_true",
+ default=False,
+ help=
+ 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
+ )
+ parser.add_argument(
+ '--embedding_sharding_dim',
+ type=int,
+ default=0,
+ choices=[0, 1],
+ help=
+ 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
+ 'To shard it along the hidden dimension, set embedding_sharding_dim=1. '
+ 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0.'
+ )
+ parser.add_argument(
+ '--use_weight_only',
+ default=False,
+ action="store_true",
+ help='Quantize weights for the various GEMMs to INT4/INT8. '
+ 'See --weight_only_precision to set the precision.')
+ parser.add_argument(
+ '--weight_only_precision',
+ const='int8',
+ type=str,
+ nargs='?',
+ default='int8',
+ choices=['int8', 'int4'],
+ help=
+ 'Define the precision for the weights when using weight-only quantization. '
+ 'You must also pass --use_weight_only for this option to take effect.'
+ )
+ parser.add_argument(
+ '--dtype',
+ type=str,
+ default='float16',
+ choices=['float16', 'float32', 'bfloat16'],
+ help=
+ 'Target inference dtype. Weights and computation will be in this dtype, regardless of the original dtype of the checkpoint.'
+ )
+ parser.add_argument(
+ '--skip_cross_kv',
+ action='store_true',
+ help=
+ 'Skip redundant cross qkv computation by using TensorRT IfConditional switch (experimental).'
+ )
+ parser.add_argument(
+ '--use_implicit_relative_attention',
+ action='store_true',
+ help=
+ 'Compute relative attention bias on the fly instead of pre-computing a relative attention bias table.'
+ )
+ args = parser.parse_args()
+ log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s"
+ logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO,
+ format=log_format)
+ LOGGER.info("\n=============== Argument ===============")
+ for key in vars(args):
+ LOGGER.info(f"{key}: {vars(args)[key]}")
+ LOGGER.info("========================================")
+
+ start_time = datetime.now()
+ convert_checkpoint(args)
+ stop_time = datetime.now()
+ run_time = (stop_time - start_time)
+ LOGGER.info("Spend {} (h:m:s) to convert the model".format(run_time))
diff --git a/deployment/tensorrt_llm/convert/helper.py b/deployment/tensorrt_llm/convert/helper.py
new file mode 100644
index 0000000..0d65242
--- /dev/null
+++ b/deployment/tensorrt_llm/convert/helper.py
@@ -0,0 +1,95 @@
+# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/enc_dec/helper.py
+
+import typing
+from typing import Union
+
+import numpy as np
+import torch # pytype: disable=import-error
+
+from tensorrt_llm._utils import str_dtype_to_torch
+
+
+def split(v: Union[np.ndarray, torch.Tensor],
+ tp_size: int,
+ tp_rank: int,
+ dim=0):
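+ # return the tp_rank-th of tp_size equal chunks of v along dim (a plain copy when tp_size == 1)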
+ if tp_size == 1:
+ if isinstance(v, np.ndarray):
+ return np.ascontiguousarray(v.copy())
+ else:
+ return v.clone().detach()
+ assert len(v.shape) > 1 or dim == 0
+ if isinstance(v, np.ndarray):
+ return np.ascontiguousarray(
+ np.split(v, tp_size, axis=dim)[tp_rank].copy())
+ else:
+ assert v.shape[dim] % tp_size == 0, \
+ f'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.'
+ split_size = v.shape[dim] // tp_size
+ return v.split(split_size, dim=dim)[tp_rank].clone().detach()
+
+
+def reshape(v: torch.Tensor, shape=None):
+ if shape is None:
+ return v.contiguous()
+ else:
+ return v.reshape(shape).contiguous()
+
+
+def fuse_qkv_one_layer(params, attn_module_name, trtllm_layer_name, tp_size,
+ tp_rank, model_type, weight_shape, bias_shape):
+
+ qkv_module_names = get_qkv_module_name(model_type)
+
+ weight = {}
+
+ # gather the q, k, v projection weights
+ q_w = params[f'{attn_module_name}.{qkv_module_names["q"]}.weight']
+ k_w = params[f'{attn_module_name}.{qkv_module_names["k"]}.weight']
+ v_w = params[f'{attn_module_name}.{qkv_module_names["v"]}.weight']
+
+ # stack q/k/v into a single fused qkv tensor and slice it along the output dimension for this TP rank
+ shape = q_w.shape # (do, din)
+ qkv_w = torch.cat([q_w, k_w, v_w],
+ dim=0).reshape([3, shape[0], shape[1]]) # (3, do, din)
+ qkv_w = split(qkv_w, tp_size, tp_rank, dim=1)
+ weight[f'{trtllm_layer_name}.qkv.weight'] = reshape(qkv_w,
+ shape=weight_shape)
+
+ # fuse qkv biases if present
+ if f'{attn_module_name}.{qkv_module_names["q"]}.bias' in params.keys(
+ ) and params[f'{attn_module_name}.{qkv_module_names["q"]}.bias'] is not None:
+ q_b = params[f'{attn_module_name}.{qkv_module_names["q"]}.bias']
+ k_b = params[f'{attn_module_name}.{qkv_module_names["k"]}.bias']
+ v_b = params[f'{attn_module_name}.{qkv_module_names["v"]}.bias']
+ shape = q_b.shape[0] # (do,)
+ qkv_b = torch.cat([q_b, k_b, v_b], dim=0).reshape([3, shape]) # (3, do)
+ qkv_b = split(qkv_b, tp_size, tp_rank, dim=1)
+ weight[f'{trtllm_layer_name}.qkv.bias'] = reshape(qkv_b,
+ shape=bias_shape)
+ return weight
+
+
+def get_qkv_module_name(model_type):
+ if model_type in ["t5", "blip2"]:
+ q = "q"
+ k = "k"
+ v = "v"
+ elif model_type == "bart" or model_type == "nmt":
+ q = "q_proj"
+ k = "k_proj"
+ v = "v_proj"
+ elif model_type == "pix2struct":
+ q = "query"
+ k = "key"
+ v = "value"
+ return {"q": q, "k": k, "v": v}
+
+
+def convert_weight_to_dtype(params: typing.Dict[str, torch.Tensor],
+ dtype: typing.Optional[np.dtype] = None):
+ if dtype is not None:
+ assert isinstance(dtype,
+ str), f"dtype must be str, but get type {type(dtype)}"
+ for name in params.keys():
+ params[name] = params[name].to(str_dtype_to_torch(dtype))
diff --git a/deployment/tensorrt_llm/convert_dolphin.sh b/deployment/tensorrt_llm/convert_dolphin.sh
new file mode 100644
index 0000000..252822e
--- /dev/null
+++ b/deployment/tensorrt_llm/convert_dolphin.sh
@@ -0,0 +1,47 @@
+#!/usr/bin/env bash
+set -ex
+
+############################################################################################
+# Reference: https://github.com/NVIDIA/TensorRT-LLM/tree/v0.18.2/examples/multimodal#nougat
+############################################################################################
+
+export LD_LIBRARY_PATH=/usr/local/lib/python3.10/site-packages/tensorrt_libs/:/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:$LD_LIBRARY_PATH
+
+# 1. Download Huggingface weights
+export MODEL_NAME="Dolphin"
+git clone https://huggingface.co/ByteDance/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
+
+
+export MAX_BATCH_SIZE=16
+export MAX_SEQ_LEN=4096
+export MAX_INPUT_LEN=10
+export MAX_ENCODER_INPUT_LEN=784
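+# MAX_ENCODER_INPUT_LEN is the number of visual features per image (see the trtllm-build command below)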
+
+# 2. Convert Huggingface weights into TRT-LLM checkpoints and build TRT engines using scripts in examples/enc_dec
+python ./convert/convert_checkpoint.py --model_type bart \
+ --model_dir tmp/hf_models/${MODEL_NAME} \
+ --output_dir tmp/trt_models/${MODEL_NAME}/bfloat16 \
+ --tp_size 1 \
+ --pp_size 1 \
+ --dtype bfloat16 \
+ --nougat
+
+
+trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/decoder \
+ --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/decoder \
+ --paged_kv_cache disable \
+ --moe_plugin disable \
+ --gemm_plugin bfloat16 \
+ --bert_attention_plugin bfloat16 \
+ --gpt_attention_plugin bfloat16 \
+ --remove_input_padding enable \
+ --max_beam_width 1 \
+ --max_batch_size ${MAX_BATCH_SIZE} \
+ --max_seq_len ${MAX_SEQ_LEN} \
+ --max_input_len ${MAX_INPUT_LEN} \
+ --max_encoder_input_len $((${MAX_BATCH_SIZE} * ${MAX_ENCODER_INPUT_LEN})) # MAX_BATCH_SIZE (max_batch_size) * MAX_ENCODER_INPUT_LEN (num_visual_features)
+
+# 3. Generate TensorRT engines for visual components and combine everything into final pipeline.
+python ./convert/build_visual_engine.py --model_type nougat \
+ --model_path tmp/hf_models/${MODEL_NAME} \
+ --max_batch_size ${MAX_BATCH_SIZE}
\ No newline at end of file
diff --git a/deployment/tensorrt_llm/dolphin_runner.py b/deployment/tensorrt_llm/dolphin_runner.py
new file mode 100644
index 0000000..ec6da0e
--- /dev/null
+++ b/deployment/tensorrt_llm/dolphin_runner.py
@@ -0,0 +1,219 @@
+"""
+Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+SPDX-License-Identifier: MIT
+"""
+
+import json
+import os
+from typing import Optional
+
+import tensorrt_llm
+import tensorrt_llm.profiler as profiler
+import torch
+from PIL import Image
+from pydantic import BaseModel, Field
+from tensorrt_llm import logger
+from tensorrt_llm import mpi_rank
+from tensorrt_llm.runtime import MultimodalModelRunner
+from transformers import AutoTokenizer, DonutProcessor
+
+
+class InferenceConfig(BaseModel):
+ max_new_tokens: int = Field(128, description="Maximum new tokens to generate")
+ batch_size: int = Field(1, description="Batch size for inference")
+ log_level: str = Field("info", description="Logging level")
+ visual_engine_dir: Optional[str] = Field(None, description="Directory for visual engine files")
+ visual_engine_name: str = Field("model.engine", description="Visual engine filename")
+ llm_engine_dir: Optional[str] = Field(None, description="Directory for LLM engine files")
+ hf_model_dir: Optional[str] = Field(None, description="Hugging Face model directory")
+ input_text: Optional[str] = Field(None, description="Input text for inference")
+ num_beams: int = Field(1, description="Number of beams for beam search")
+ top_k: int = Field(1, description="Top-k sampling value")
+ top_p: float = Field(0.0, description="Top-p (nucleus) sampling value")
+ temperature: float = Field(1.0, description="Sampling temperature")
+ repetition_penalty: float = Field(1.0, description="Repetition penalty factor")
+ run_profiling: bool = Field(False, description="Enable profiling mode")
+ profiling_iterations: int = Field(20, description="Number of profiling iterations")
+ check_accuracy: bool = Field(False, description="Enable accuracy checking")
+ video_path: Optional[str] = Field(None, description="Path to input video file")
+ video_num_frames: Optional[int] = Field(None, description="Number of video frames to process")
+ image_path: Optional[str] = Field(None, description="Path to input image file")
+ path_sep: str = Field(",", description="Path separator character")
+ prompt_sep: str = Field(",", description="Prompt separator character")
+ enable_context_fmha_fp32_acc: Optional[bool] = Field(
+ None,
+ description="Enable FP32 accumulation for context FMHA"
+ )
+ enable_chunked_context: bool = Field(False, description="Enable chunked context processing")
+ use_py_session: bool = Field(False, description="Use Python session instead of C++")
+ kv_cache_free_gpu_memory_fraction: float = Field(
+ 0.9,
+ description="Fraction of GPU memory free for KV cache",
+ ge=0.0, le=1.0
+ )
+ cross_kv_cache_fraction: float = Field(
+ 0.5,
+ description="Fraction of cross-attention KV cache",
+ ge=0.0, le=1.0
+ )
+ multi_block_mode: bool = Field(True, description="Enable multi-block processing mode")
+
+
+class DolphinRunner(MultimodalModelRunner):
+ def __init__(self, args):
+ self.args = args
+
+ self.runtime_rank = mpi_rank()
+ device_id = self.runtime_rank % torch.cuda.device_count()
+ torch.cuda.set_device(device_id)
+ self.device = "cuda:%d" % (device_id)
+
+ self.stream = torch.cuda.Stream(torch.cuda.current_device())
+ torch.cuda.set_stream(self.stream)
+
+ # parse model type from visual engine config
+ with open(os.path.join(self.args.visual_engine_dir, "config.json"),
+ "r") as f:
+ config = json.load(f)
+ self.model_type = config['builder_config']['model_type']
+ self.vision_precision = config['builder_config']['precision']
+ self.decoder_llm = not (
+ 't5' in self.model_type
+ or self.model_type in ['nougat', 'pix2struct']
+ ) # BLIP2-T5, pix2struct, and Nougat use encoder-decoder models as LLMs
+
+ if self.model_type == "mllama":
+ self.vision_input_names = [
+ "pixel_values",
+ "aspect_ratio_ids",
+ "aspect_ratio_mask",
+ ]
+ self.vision_output_names = [
+ "output",
+ ]
+ else:
+ self.vision_input_names = ["input"]
+ self.vision_output_names = ["output"]
+
+ self.use_py_session = True
+
+ self.init_image_encoder()
+ self.init_tokenizer()
+ self.init_processor()
+ self.init_llm()
+
+ def init_tokenizer(self):
+ assert self.model_type == 'nougat'
+ self.tokenizer = AutoTokenizer.from_pretrained(self.args.hf_model_dir)
+ self.tokenizer.padding_side = "right"
+
+ def init_processor(self):
+ assert self.model_type == 'nougat'
+ self.processor = DonutProcessor.from_pretrained(self.args.hf_model_dir, use_fast=True)
+
+ def run(self, input_texts, input_images, max_new_tokens):
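+ # strip each prompt and append a trailing space, preprocess the images with the
+ # DonutProcessor, and tokenize the prompts without special tokens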
+ prompts = [f"{text.strip()} " for text in input_texts]
+ images = self.processor(input_images, return_tensors="pt")['pixel_values'].to("cuda")
+ prompt_ids = self.tokenizer(prompts, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda")
+ prompt_ids = prompt_ids.to(
+ torch.int32) # Important! If the type of prompt_ids is not int32, the output will be wrong.
+
+ logger.info("---------------------------------------------------------")
+ logger.info(f"images size: {images.size()}")
+ logger.info(f"prompt_ids: {prompt_ids}, size: {prompt_ids.size()}, dtype: {prompt_ids.dtype}")
+ logger.info("---------------------------------------------------------")
+
+ output_texts = self.generate(input_texts,
+ [None] * len(input_texts),
+ images,
+ prompt_ids,
+ max_new_tokens,
+ warmup=False,
+ )
+
+ return output_texts
+
+ def generate(self,
+ pre_prompt,
+ post_prompt,
+ image,
+ decoder_input_ids,
+ max_new_tokens,
+ warmup=False,
+ other_vision_inputs={},
+ other_decoder_inputs={}):
+ if not warmup:
+ profiler.start("Generate")
+ input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
+ warmup, pre_prompt, post_prompt, image, other_vision_inputs)
+
+ if warmup: return None
+
+ # use prompt tuning to pass multimodal features
+ # model.generate() expects the following params (see layers/embedding.py):
+ # args[0]: prompt embedding table, [batch_size, multimodal_len, hidden_size], later flattened to [batch_size * multimodal_len, hidden_size]
+ # args[1]: prompt task ids, [batch_size]. in the multimodal case, arange(batch_size), i.e. in VILA batching mode 2 each image is treated separately in the batch instead of concatenated together (although the prompt embedding table has to be concatenated)
+ # args[2]: prompt task vocab size, [1]. assuming all tables have the same length, which in the multimodal case equals multimodal_len
+ profiler.start("LLM")
+ if self.model_type in ['nougat', 'pix2struct']:
+ # Trim encoder input_ids to match visual features shape
+ ids_shape = (min(self.args.batch_size, len(pre_prompt)), visual_features.shape[1])
+ if self.model_type == 'nougat':
+ input_ids = torch.zeros(ids_shape, dtype=torch.int32)
+ elif self.model_type == 'pix2struct':
+ input_ids = torch.ones(ids_shape, dtype=torch.int32)
+
+ output_ids = self.model.generate(
+ input_ids,
+ decoder_input_ids,
+ max_new_tokens,
+ num_beams=self.args.num_beams,
+ bos_token_id=self.tokenizer.bos_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ debug_mode=False,
+ prompt_embedding_table=ptuning_args[0],
+ prompt_tasks=ptuning_args[1],
+ prompt_vocab_size=ptuning_args[2],
+ )
+ profiler.stop("LLM")
+
+ if mpi_rank() == 0:
+ # Extract a list of tensors of shape beam_width x output_ids.
+ output_beams_list = [
+ self.tokenizer.batch_decode(
+ output_ids[batch_idx, :, decoder_input_ids.shape[1]:],
+ skip_special_tokens=False) for batch_idx in range(
+ min(self.args.batch_size, decoder_input_ids.shape[0]))
+ ]
+
+ stripped_text = [[
+ output_beams_list[batch_idx][beam_idx].replace("<s>", "").replace("</s>", "").strip()
+ for beam_idx in range(self.args.num_beams)
+ ] for batch_idx in range(
+ min(self.args.batch_size, decoder_input_ids.shape[0]))]
+ profiler.stop("Generate")
+ return stripped_text
+ else:
+ profiler.stop("Generate")
+ return None
+
+
+if __name__ == "__main__":
+ config = InferenceConfig(
+ max_new_tokens=4096,
+ batch_size=16,
+ log_level="info",
+ hf_model_dir=f"./tmp/hf_models/Dolphin",
+ visual_engine_dir=f"./tmp/trt_engines/Dolphin/vision_encoder",
+ llm_engine_dir=f"./tmp/trt_engines/Dolphin/1-gpu/bfloat16",
+ )
+
+ model = DolphinRunner(config)
+
+ image_path = "../../demo/page_imgs/page_1.jpeg"
+ prompt = "Parse the reading order of this document."
+ image = Image.open(image_path).convert("RGB")
+ output_texts = model.run([prompt], [image], 4096)
+ output_texts = [texts[0] for texts in output_texts]
+ print(output_texts)
diff --git a/deployment/tensorrt_llm/run_dolphin.py b/deployment/tensorrt_llm/run_dolphin.py
new file mode 100644
index 0000000..947de52
--- /dev/null
+++ b/deployment/tensorrt_llm/run_dolphin.py
@@ -0,0 +1,106 @@
+"""
+Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+SPDX-License-Identifier: MIT
+"""
+
+import argparse
+import os
+
+import tensorrt_llm
+import tensorrt_llm.profiler as profiler
+from PIL import Image
+from tensorrt_llm import logger
+from tensorrt_llm import mpi_rank
+from tensorrt_llm.runtime import MultimodalModelRunner
+
+from dolphin_runner import DolphinRunner
+from utils import add_common_args
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+def print_result(model, input_text, output_text, args):
+ logger.info("---------------------------------------------------------")
+ logger.info(f"\n[Q] {input_text}")
+ for i in range(len(output_text)):
+ logger.info(f"\n[A]: {output_text[i]}")
+
+ if args.num_beams == 1:
+ output_ids = model.tokenizer(output_text[0][0],
+ add_special_tokens=False)['input_ids']
+ logger.info(f"Generated {len(output_ids)} tokens")
+
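+ # the accuracy checks below follow the upstream TensorRT-LLM multimodal example and
+ # target other model types; they are skipped for nougat-style models such as Dolphin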
+ if args.check_accuracy:
+ if model.model_type != 'nougat':
+ if model.model_type == "vila":
+ for i in range(len(args.image_path.split(args.path_sep))):
+ if i % 2 == 0:
+ assert output_text[i][0].lower(
+ ) == "the image captures a bustling city intersection teeming with life. from the perspective of a car's dashboard camera, we see"
+ else:
+ assert output_text[i][0].lower(
+ ) == "the image captures the iconic merlion statue in singapore, a renowned worldwide landmark. the merlion, a mythical"
+ elif model.model_type == "llava":
+ for i in range(len(args.image_path.split(args.path_sep))):
+ assert output_text[i][0].lower() == 'singapore'
+ elif model.model_type == 'fuyu':
+ assert output_text[0][0].lower() == '4'
+ elif model.model_type == "pix2struct":
+ assert "characteristic | cat food, day | cat food, wet | cat treats" in output_text[
+ 0][0].lower()
+ elif model.model_type in [
+ 'blip2', 'neva', 'phi-3-vision', 'llava_next'
+ ]:
+ assert 'singapore' in output_text[0][0].lower()
+ elif model.model_type == 'video-neva':
+ assert 'robot' in output_text[0][0].lower()
+ elif model.model_type == 'kosmos-2':
+ assert 'snowman' in output_text[0][0].lower()
+ elif model.model_type == "mllama":
+ if "If I had to write a haiku for this one" in input_text:
+ assert "it would be:.\\nPeter Rabbit is a rabbit.\\nHe lives in a" in output_text[
+ 0][0] or "Here is a haiku for the image:\n\n" in output_text[
+ 0][0], f"expected results: 'it would be:.\\nPeter Rabbit is a rabbit.\\nHe lives in a', generated results: '{output_text[0][0]}'"
+ elif "The key to life is" in input_text:
+ assert "to find your passion and pursue it with all your heart." in output_text[
+ 0][0] or "not to be found in the external world," in output_text[
+ 0][0], f"expected results: 'to find your passion and pursue it with all your heart.', generated results: '{output_text[0][0]}'"
+ elif model.model_type == 'llava_onevision':
+ if args.video_path is None:
+ assert 'singapore' in output_text[0][0].lower()
+ else:
+ assert 'the video is funny because the child\'s actions are' in output_text[
+ 0][0].lower()
+ elif model.model_type == "qwen2_vl":
+ assert 'dog' in output_text[0][0].lower()
+ else:
+ assert output_text[0][0].lower() == 'singapore'
+
+ if args.run_profiling:
+ msec_per_batch = lambda name: 1000 * profiler.elapsed_time_in_sec(
+ name) / args.profiling_iterations
+ logger.info('Latencies per batch (msec)')
+ logger.info('TRT vision encoder: %.1f' % (msec_per_batch('Vision')))
+ logger.info('TRTLLM LLM generate: %.1f' % (msec_per_batch('LLM')))
+ logger.info('Multimodal generate: %.1f' % (msec_per_batch('Generate')))
+
+ logger.info("---------------------------------------------------------")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser = add_common_args(parser)
+ args = parser.parse_args()
+ logger.set_level(args.log_level)
+
+ model = DolphinRunner(args)
+
+ input_image = Image.open(args.image_path[0]).convert('RGB')
+ num_iters = args.profiling_iterations if args.run_profiling else 1
+
+ for _ in range(num_iters):
+ output_texts = model.run(args.input_text, [input_image], args.max_new_tokens)
+
+ runtime_rank = tensorrt_llm.mpi_rank()
+ if runtime_rank == 0:
+ print_result(model, args.input_text, output_texts, args)
diff --git a/deployment/tensorrt_llm/run_dolphin.sh b/deployment/tensorrt_llm/run_dolphin.sh
new file mode 100644
index 0000000..affefa6
--- /dev/null
+++ b/deployment/tensorrt_llm/run_dolphin.sh
@@ -0,0 +1,47 @@
+#!/usr/bin/env bash
+set -ex
+
+export MODEL_NAME="Dolphin"
+
+python run_dolphin.py \
+ --batch_size 1 \
+ --hf_model_dir tmp/hf_models/${MODEL_NAME} \
+ --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
+ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
+ --max_new_tokens 4096 \
+ --repetition_penalty 1.0 \
+ --input_text "Parse the reading order of this document." \
+ --image_path "../../demo/page_imgs/page_1.jpeg"
+
+
+python run_dolphin.py \
+ --batch_size 1 \
+ --hf_model_dir tmp/hf_models/${MODEL_NAME} \
+ --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
+ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
+ --max_new_tokens 4096 \
+ --repetition_penalty 1.0 \
+ --input_text "Read text in the image." \
+ --image_path "../../demo/element_imgs/block_formula.jpeg"
+
+
+python run_dolphin.py \
+ --batch_size 1 \
+ --hf_model_dir tmp/hf_models/${MODEL_NAME} \
+ --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
+ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
+ --max_new_tokens 4096 \
+ --repetition_penalty 1.0 \
+ --input_text "Read text in the image." \
+ --image_path "../../demo/element_imgs/para_1.jpg"
+
+
+python run_dolphin.py \
+ --batch_size 1 \
+ --hf_model_dir tmp/hf_models/${MODEL_NAME} \
+ --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
+ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
+ --max_new_tokens 4096 \
+ --repetition_penalty 1.0 \
+ --input_text "Parse the table in the image." \
+ --image_path "../../demo/element_imgs/table_1.jpeg"
diff --git a/deployment/tensorrt_llm/start_dolphin_server.sh b/deployment/tensorrt_llm/start_dolphin_server.sh
new file mode 100644
index 0000000..128f0a3
--- /dev/null
+++ b/deployment/tensorrt_llm/start_dolphin_server.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+set -ex
+
+export MODEL_NAME="Dolphin"
+
+python api_server.py \
+ --hf_model_dir tmp/hf_models/${MODEL_NAME} \
+ --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
+ --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
+ --max_batch_size 16
\ No newline at end of file
diff --git a/deployment/tensorrt_llm/utils.py b/deployment/tensorrt_llm/utils.py
new file mode 100644
index 0000000..d7ac44c
--- /dev/null
+++ b/deployment/tensorrt_llm/utils.py
@@ -0,0 +1,95 @@
+def add_common_args(parser):
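+ # engine paths, sampling, and runtime session flags shared by the Dolphin TensorRT-LLM scripts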
+ parser.add_argument('--max_new_tokens', type=int, default=128)
+ parser.add_argument('--batch_size', type=int, default=1)
+ parser.add_argument('--log_level', type=str, default='info')
+ parser.add_argument('--visual_engine_dir',
+ type=str,
+ default=None,
+ help='Directory containing visual TRT engines')
+ parser.add_argument('--visual_engine_name',
+ type=str,
+ default='model.engine',
+ help='Name of visual TRT engine')
+ parser.add_argument('--llm_engine_dir',
+ type=str,
+ default=None,
+ help='Directory containing TRT-LLM engines')
+ parser.add_argument('--hf_model_dir',
+ type=str,
+ default=None,
+ help="Directory containing tokenizer")
+ parser.add_argument('--input_text',
+ type=str,
+ nargs='+',
+ default=None,
+ help='Text prompt to LLM')
+ parser.add_argument('--num_beams',
+ type=int,
+ help="Use beam search if num_beams >1",
+ default=1)
+ parser.add_argument('--top_k', type=int, default=1)
+ parser.add_argument('--top_p', type=float, default=0.0)
+ parser.add_argument('--temperature', type=float, default=1.0)
+ parser.add_argument('--repetition_penalty', type=float, default=1.0)
+ parser.add_argument('--run_profiling',
+ action='store_true',
+ help='Profile runtime over several iterations')
+ parser.add_argument('--profiling_iterations',
+ type=int,
+ help="Number of iterations to run profiling",
+ default=20)
+ parser.add_argument('--check_accuracy',
+ action='store_true',
+ help='Check correctness of text output')
+ parser.add_argument("--image_path",
+ type=str,
+ nargs='+',
+ default=None,
+ help='List of input image paths, separated by the --path_sep symbol')
+ parser.add_argument("--path_sep",
+ type=str,
+ default=",",
+ help='Path separator symbol')
+ parser.add_argument("--prompt_sep",
+ type=str,
+ default=",",
+ help="Prompt separator symbol")
+ parser.add_argument('--enable_context_fmha_fp32_acc',
+ action='store_true',
+ default=None,
+ help="Enable FMHA runner FP32 accumulation.")
+ parser.add_argument(
+ '--enable_chunked_context',
+ action='store_true',
+ help='Enables chunked context (only available with cpp session).',
+ )
+ parser.add_argument(
+ '--use_py_session',
+ default=False,
+ action='store_true',
+ help=
+ "Whether or not to use Python runtime session. By default C++ runtime session is used for the LLM."
+ )
+ parser.add_argument(
+ '--kv_cache_free_gpu_memory_fraction',
+ default=0.9,
+ type=float,
+ help='Specify the free gpu memory fraction.',
+ )
+ parser.add_argument(
+ '--cross_kv_cache_fraction',
+ default=0.5,
+ type=float,
+ help=
+ 'Specify the kv cache fraction reserved for cross attention. Only applicable for encoder-decoder models. By default 0.5 for self and 0.5 for cross.',
+ )
+ parser.add_argument(
+ '--multi_block_mode',
+ type=lambda s: s.lower() in
+ ("yes", "true", "t", "1"
+ ), # custom boolean function to convert input string to boolean
+ default=True,
+ help=
+ "Distribute the work across multiple CUDA thread-blocks on the GPU for masked MHA kernel."
+ )
+ return parser