Dolphin/deployment/tensorrt_llm/convert/convert_checkpoint.py
2025-06-30 19:41:03 +08:00

1529 lines
67 KiB
Python

# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/enc_dec/convert_checkpoint.py
import argparse
import configparser
import copy
import json
import logging
import os
import types
from ast import literal_eval
from datetime import datetime
from pathlib import Path
import safetensors
from helper import convert_weight_to_dtype, fuse_qkv_one_layer, reshape, split
from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration,
MBartForConditionalGeneration,
Pix2StructForConditionalGeneration,
T5ForConditionalGeneration, VisionEncoderDecoderModel)
from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
MLPType)
from tensorrt_llm.models import PretrainedConfig
# Absolute directory containing this script, with symlinks resolved.
dir_path = os.path.dirname(os.path.realpath(__file__))
LOGGER = logging.getLogger(__name__)

# Enum-member-name -> enum-value lookup tables. The per-component config
# parsers index these with the string options they read (e.g. 'RmsNorm',
# 'pre_layernorm', 'GatedMLP').
layernorm_type_map = {member.name: member.value for member in LayerNormType}
layernorm_position_map = {
    member.name: member.value
    for member in LayerNormPositionType
}
mlp_type_map = {member.name: member.value for member in MLPType}
def copy_args_to_component_config(component_config, args):
    """Copy every attribute of ``args`` onto ``component_config``.

    Existing attributes of ``component_config`` with the same name are
    overwritten; other attributes are left untouched.

    Args:
        component_config: object receiving the attributes (mutated in place).
        args: namespace-like object whose ``vars()`` are copied.

    Returns:
        The same ``component_config`` object, for call chaining.
    """
    for name, value in vars(args).items():
        setattr(component_config, name, value)
    return component_config
def parse_t5_config(args, hf_model):
    """Derive TRT-LLM encoder and decoder configs from a HF T5-style model.

    The HF encoder/decoder sub-configs are flattened into a
    ``configparser.ConfigParser`` (each value stringified with ``f"{val}"``)
    plus a ``structure`` section. Each component is then read back into a
    ``types.SimpleNamespace`` that also carries every attribute of ``args``.

    Args:
        args: parsed CLI namespace; must provide ``model_type``.
        hf_model: model exposing ``encoder.config`` and ``decoder.config``.

    Returns:
        ``(encoder_config, decoder_config)`` namespaces.
    """
    config = configparser.ConfigParser()

    # Dump both HF sub-configs, stringified, into their own sections.
    for section, sub_model in (("encoder", hf_model.encoder),
                               ("decoder", hf_model.decoder)):
        config[section] = {}
        for key, val in sub_model.config.to_dict().items():
            config[section][key] = f"{val}"

    # Model-wide structure flags shared by both components.
    structure_fields = (
        ("t5_with_bias", "false"),
        ("use_gated_activation", str(hf_model.encoder.config.is_gated_act)),
        ("position_embedding_type", "relative"),
        ("model_type", args.model_type),
    )
    config["structure"] = dict()
    for key, val in structure_fields:
        config["structure"][key] = val

    def get_offset_q_scaling(component_config):
        # Manually set q_scaling to offset the attention scaling applied by
        # the kernels (presumably because HF T5 applies no 1/sqrt(head_size)
        # scaling — confirm).
        # TODO: modify kernels to control whether to disable attention scaling
        return 1 / component_config.head_size**.5

    def parse_t5_config_by_component(config, component, args):
        # Seed with every CLI argument, then layer the model fields on top.
        cfg = copy_args_to_component_config(types.SimpleNamespace(), args)
        get, get_int = config.get, config.getint
        get_bool, get_float = config.getboolean, config.getfloat

        # Widths and sizes.
        cfg.n_head = get_int(component, 'num_heads')
        cfg.head_size = get_int(component, 'd_kv')
        cfg.hidden_size = get_int(component, 'd_model')
        cfg.ffn_hidden_size = get_int(component, 'd_ff')
        cfg.vocab_size = get_int(component, 'vocab_size')
        cfg.n_positions = get_int(component, 'n_positions', fallback=512)

        # Embedding options. TODO: hardcoded here
        cfg.has_position_embedding = get_bool(component,
                                              'has_position_embedding',
                                              fallback=False)
        cfg.has_token_type_embedding = get_bool(component,
                                                'has_token_type_embedding',
                                                fallback=False)
        cfg.has_embedding_layernorm = get_bool(component,
                                               'has_embedding_layernorm',
                                               fallback=False)
        cfg.has_embedding_scale = get_bool(component,
                                           'has_embedding_scale',
                                           fallback=False)
        cfg.q_scaling = get_offset_q_scaling(cfg)

        # Bias and normalization options. TODO: hardcoded here
        cfg.has_attention_qkvo_bias = get_bool(component,
                                               'has_attention_qkvo_bias',
                                               fallback=False)
        cfg.has_mlp_bias = get_bool(component, 'has_mlp_bias', fallback=False)
        cfg.has_model_final_layernorm = get_bool(component,
                                                 'has_model_final_layernorm',
                                                 fallback=True)
        cfg.layernorm_eps = get_float(component, 'layer_norm_epsilon')
        cfg.layernorm_position = layernorm_position_map[get(
            component, 'layernorm_position', fallback='pre_layernorm')]
        cfg.layernorm_type = layernorm_type_map[get(component,
                                                    'layernorm_type',
                                                    fallback='RmsNorm')]

        # Feed-forward block.
        cfg.hidden_act = get(component, 'dense_act_fn')
        cfg.gated_act = get_bool(component, 'is_gated_act')
        cfg.mlp_type = mlp_type_map['GatedMLP' if cfg.gated_act else 'MLP']

        # Relative position bias.
        cfg.num_buckets = get_int(component, 'relative_attention_num_buckets')
        cfg.max_distance = get_int(component,
                                   'relative_attention_max_distance')
        cfg.position_embedding_type = get('structure',
                                          'position_embedding_type')
        cfg.logits_dtype = get(component, 'logits_dtype', fallback='float32')

        if component == 'encoder':
            cfg.n_layer = get_int(component, 'num_layers')
            cfg.relative_attention = get(
                'structure', 'position_embedding_type') == 'relative'
        elif component == 'decoder':
            cfg.n_layer = get_int(component, 'num_decoder_layers')
            # TODO: T5 with bias
            cfg.has_lm_head_bias = get_bool(component,
                                            'has_lm_head_bias',
                                            fallback=False)
            cfg.relative_attention = get_bool(component,
                                              'relative_attention',
                                              fallback=True)
            # default is True (for T5), but False for Flan-T5
            cfg.rescale_before_lm_head = get_bool(component,
                                                  'tie_word_embeddings')
            # Cross-attention geometry comes from the encoder section.
            cfg.encoder_hidden_size = get_int('encoder', 'd_model')
            cfg.encoder_num_heads = get_int('encoder', 'num_heads')
            cfg.encoder_head_size = get_int('encoder', 'd_kv')
            cfg.decoder_start_token_id = get_int('decoder',
                                                 'decoder_start_token_id')
            cfg.eos_token_id = get_int('decoder', 'eos_token_id')
            # T5 does not have bos_token_id (dumped as the string "None").
            raw_bos = get('decoder', 'bos_token_id')
            cfg.bos_token_id = None if raw_bos == "None" else int(raw_bos)
            cfg.pad_token_id = get_int('decoder', 'pad_token_id')
        else:
            assert False, 'Unsupported component!'
        return cfg

    return (parse_t5_config_by_component(config, "encoder", args),
            parse_t5_config_by_component(config, "decoder", args))
def convert_t5_weights_to_tllm_safetensors(config, component, params):
# Convert one component ('encoder' or 'decoder') of a HF T5-style state dict
# into a flat {TRT-LLM weight name: tensor} dict for this TP/PP rank.
#
# Args:
#   config: converted component config; reads mapping, dtype, hidden_size,
#     intermediate_size, num_hidden_layers, num_attention_heads, head_size,
#     gated_act, model_type, num_buckets, vocab_size and
#     use_implicit_relative_attention.
#   component: 'encoder' or 'decoder'; used as the HF parameter prefix and to
#     pick TRT-LLM module names.
#   params: HF state dict keyed by parameter name. It is handed to
#     convert_weight_to_dtype first (presumably an in-place cast — the helper
#     is not shown here).
# Returns:
#   dict of TRT-LLM weight name -> tensor, covering only the layers of this
#   pipeline rank, with sharded tensors sliced for this tensor-parallel rank.
# NOTE(review): this copy lost its indentation. The trailing
# `use_implicit_relative_attention` check is presumably nested inside the
# final decoder branch (as in upstream TensorRT-LLM) — confirm.
weights = {}
mapping = config.mapping
convert_weight_to_dtype(params, config.dtype)
hidden_size = config.hidden_size
ffn_hidden_size = config.intermediate_size
num_layers = config.num_hidden_layers
n_head = config.num_attention_heads
head_size = config.head_size
attention_hidden_size = n_head * head_size # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5
hf_param_prefix = f'{component}'
trtllm_layer_name = f'{component}_layers'
trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'
# HF sublayer index of the feed-forward block: layer.1 in encoder blocks,
# layer.2 in decoder blocks (decoder layer.1 is EncDecAttention).
hf_component_idx = 1 if component == 'encoder' else 2
# Builds an HF attention module path, e.g. 'decoder.block.3.layer.1.EncDecAttention'.
def get_attn_module_name(component, block, layer, attn_type):
return f'{component}.block.{int(block)}.layer.{int(layer)}.{attn_type}'
# The token embedding is taken from HF 'shared.weight' for either component.
weights['embedding.vocab_embedding.weight'] = reshape(
params['shared.weight'].clone(), None)
# Only layers owned by this pipeline rank are converted; TRT-LLM layer
# names are renumbered from 0 within the rank.
layers_range = mapping.pp_layers(num_layers)
for layer_idx in layers_range:
local_layer_idx = layer_idx - layers_range[0]
trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
hf_layer_name_prefix = f'{hf_param_prefix}.block.{layer_idx}'
# HF name -> TRT-LLM name, per-rank shape, and the dim along which the tensor
# is sharded across TP ranks. Keys absent from `params` are skipped below,
# so only one of wi / wi_0 has to exist in the checkpoint.
hidden_layer_name_split = {
f'{hf_layer_name_prefix}.layer.0.SelfAttention.o.weight': {
"name":
f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',
"shape":
(hidden_size, attention_hidden_size // mapping.tp_size),
"split_dim": -1
},
f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wo.weight':
{
"name": f'{trtllm_layer_name_prefix}.mlp.proj.weight',
"shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
"split_dim": -1
},
f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi.weight':
{
"name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
"shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
"split_dim": 0
},
f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_0.weight':
{
"name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
"shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
"split_dim": 0
},
}
# Layer-norm weights are copied whole to every TP rank (shape None).
hidden_layer_name_no_split = {
f'{hf_layer_name_prefix}.layer.0.layer_norm.weight': {
"name":
f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',
"shape": None
},
f'{hf_layer_name_prefix}.layer.{hf_component_idx}.layer_norm.weight':
{
"name": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',
"shape": None
},
}
# Gated-activation checkpoints have a second input projection (wi2 or
# wi_1), mapped to TRT-LLM mlp.gate.
if config.gated_act:
hidden_layer_name_split.update({
f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi2.weight':
{
"name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
"shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
"split_dim": 0
},
f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_1.weight':
{
"name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
"shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
"split_dim": 0
},
})
# Decoder blocks additionally carry cross-attention at HF layer index 1.
if component == 'decoder':
hidden_layer_name_split.update({
f'{hf_layer_name_prefix}.layer.1.EncDecAttention.o.weight': {
"name":
f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',
"shape":
(hidden_size, attention_hidden_size // mapping.tp_size),
"split_dim": -1
},
})
hidden_layer_name_no_split.update({
f'{hf_layer_name_prefix}.layer.1.layer_norm.weight': {
"name":
f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',
"shape": None
},
})
# Despite the variable name, this is the cross-attention module. Its
# q/k/v are combined by fuse_qkv_one_layer (helper not shown) into the
# TRT-LLM cross_attention weights; bias shape None (T5 has no attention bias).
self_attn_module_name = get_attn_module_name(
component, layer_idx, "1", 'EncDecAttention')
weights.update(
fuse_qkv_one_layer(
params, self_attn_module_name,
f'{trtllm_layer_name_prefix}.cross_attention',
mapping.tp_size, mapping.tp_rank, config.model_type,
(attention_hidden_size * 3 // mapping.tp_size, hidden_size),
None))
# Self-attention q/k/v fused the same way (no bias).
self_attn_module_name = get_attn_module_name(component, layer_idx, "0",
'SelfAttention')
weights.update(
fuse_qkv_one_layer(
params, self_attn_module_name,
f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
mapping.tp_size, mapping.tp_rank, config.model_type,
(attention_hidden_size * 3 // mapping.tp_size, hidden_size),
None))
# Every layer receives the relative-attention bias table read from block 0
# (the HF key is fixed to block.0), transposed and sharded over heads.
weights[
f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(
split(
params[
f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']
.T, mapping.tp_size, mapping.tp_rank, 0),
(n_head // mapping.tp_size, config.num_buckets))
for hf_weight_name, weight_info in hidden_layer_name_split.items():
if hf_weight_name in params.keys():
weights[weight_info["name"]] = reshape(
split(params[hf_weight_name],
mapping.tp_size,
mapping.tp_rank,
dim=weight_info["split_dim"]), weight_info["shape"])
for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
if hf_weight_name in params.keys():
weights[weight_info["name"]] = reshape(
params[hf_weight_name].clone(), shape=weight_info["shape"])
# Component-level final layer norm (always exported for T5).
weights['final_layernorm.weight'] = reshape(
params[f'{component}.final_layer_norm.weight'].clone(), None)
# Decoder-only: LM head sharded over the vocabulary dim, plus a model-level
# copy of the block-0 relative-attention table when implicit relative
# attention is disabled.
if component == 'decoder':
weights['lm_head.weight'] = reshape(
split(params['lm_head.weight'],
mapping.tp_size,
mapping.tp_rank,
dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
if not config.use_implicit_relative_attention:
weights['rel_attn_table'] = reshape(
split(
params[
f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']
.T, mapping.tp_size, mapping.tp_rank, 0),
(n_head // mapping.tp_size, config.num_buckets))
return weights
# BLIP-2 weights go through the T5 converter unchanged; this is an alias to
# the same function object, not a wrapper (presumably because the BLIP-2
# language model handled here is T5-based — confirm).
convert_blip2_weights_to_tllm_safetensors = convert_t5_weights_to_tllm_safetensors # func alias
def parse_nmt_config(args, model):
    """Derive TRT-LLM encoder/decoder configs from a fairseq NMT model.

    The fairseq model arguments (``vars(model.cfg.model)``) are copied,
    stringified, into both an ``encoder`` and a ``decoder`` section of a
    ``configparser.ConfigParser``. A few derived entries (q_scaling, final
    layernorm flag, vocab size, ...) are then set per section, and a
    ``structure`` section is added. Each component is read back into a
    ``types.SimpleNamespace`` seeded with every attribute of ``args``.

    Args:
        args: parsed CLI namespace; must provide ``model_type``.
        model: fairseq model exposing ``cfg.model``, ``src_dict`` and
            ``tgt_dict``.

    Returns:
        ``(encoder_config, decoder_config)`` namespaces.
    """
    config = configparser.ConfigParser()
    fairseq_config = vars(model.cfg.model)  # Namespace --> dict

    def _dump_fairseq_section(section):
        # Both components start from the same full fairseq argument set.
        config[section] = dict()
        for key, val in fairseq_config.items():
            config[section][key] = f"{val}"
        config[section]["q_scaling"] = '1'

    _dump_fairseq_section('encoder')
    enc_section = config['encoder']
    # NMT has final layernorm for pre-norm model architecture.
    enc_section['has_model_final_layernorm'] = enc_section[
        'encoder_normalize_before']
    enc_section['vocab_size'] = str(len(model.src_dict))  # fairseq naming

    _dump_fairseq_section('decoder')
    dec_section = config['decoder']
    dec_section["rescale_before_lm_head"] = 'false'
    # Final decoder norm only for pre-norm decoders that did not disable it.
    dec_section['has_model_final_layernorm'] = str(
        dec_section.getboolean('decoder_normalize_before', False)
        and not dec_section.getboolean('no_decoder_final_norm', False))
    dec_section['vocab_size'] = str(len(model.tgt_dict))  # fairseq naming

    structure_fields = (
        ("t5_with_bias", "true"),
        ("use_gated_activation", "false"),
        ("position_embedding_type", "learned_absolute"),  # "sinusoid"
        ("model_type", args.model_type),
    )
    config["structure"] = dict()
    for key, val in structure_fields:
        config["structure"][key] = val

    def parse_nmt_config_by_component(config, component, args):
        assert component in ('encoder', 'decoder'), 'Unsupported component!'
        cfg = copy_args_to_component_config(types.SimpleNamespace(), args)
        get, get_int = config.get, config.getint
        get_bool, get_float = config.getboolean, config.getfloat

        # Layer / width hyper-parameters (fairseq naming).
        cfg.n_layer = get_int(component, f'{component}_layers')
        cfg.n_head = get_int(component, f'{component}_attention_heads')
        cfg.hidden_size = get_int(component, f'{component}_embed_dim')
        cfg.head_size = get_int(component,
                                'd_kv',
                                fallback=cfg.hidden_size // cfg.n_head)
        cfg.ffn_hidden_size = get_int(component,
                                      f'{component}_ffn_embed_dim')
        cfg.vocab_size = get_int(component, 'vocab_size')
        # Both components read max_source_positions (fairseq naming).
        cfg.n_positions = get_int(component, 'max_source_positions')

        # Embedding options (fairseq naming).
        cfg.has_position_embedding = not get_bool(
            component, 'no_token_positional_embeddings', fallback=False)
        cfg.has_token_type_embedding = get_bool(component,
                                                'has_token_type_embedding',
                                                fallback=False)
        cfg.has_embedding_layernorm = get_bool(component,
                                               'layernorm_embedding',
                                               fallback=True)
        cfg.has_embedding_scale = not get_bool(component,
                                               'no_scale_embedding')
        cfg.q_scaling = get_float(component, 'q_scaling', fallback=1.0)

        # Bias and normalization options.
        with_bias = get_bool('structure', 't5_with_bias', fallback=True)
        cfg.has_attention_qkvo_bias = cfg.has_mlp_bias = with_bias
        cfg.has_model_final_layernorm = get_bool(component,
                                                 'has_model_final_layernorm')
        cfg.layernorm_eps = get_float(component,
                                      'layer_norm_epsilon',
                                      fallback=1e-5)  # fairseq naming
        pre_norm = get_bool(component, f'{component}_normalize_before')
        cfg.layernorm_position = layernorm_position_map[
            'pre_layernorm' if pre_norm else 'post_layernorm']
        cfg.layernorm_type = layernorm_type_map[get(component,
                                                    'layernorm_type',
                                                    fallback='LayerNorm')]

        # Feed-forward block.
        cfg.hidden_act = get(component, 'activation_fn')  # fairseq naming
        cfg.gated_act = get_bool(component, 'is_gated_act', fallback=False)
        cfg.mlp_type = mlp_type_map['GatedMLP' if cfg.gated_act else 'MLP']

        # Position encoding: the structure section above says
        # learned_absolute, so relative_attention is False and the bucket
        # fields fall back to 0 unless present.
        position_type = get('structure', 'position_embedding_type')
        cfg.relative_attention = position_type == 'relative'
        cfg.num_buckets = get_int(component,
                                  'relative_attention_num_buckets',
                                  fallback=0)
        cfg.max_distance = get_int(component,
                                   'relative_attention_max_distance',
                                   fallback=0)
        cfg.position_embedding_type = position_type
        cfg.logits_dtype = get(component, 'logits_dtype', fallback='float32')

        if component == 'decoder':
            cfg.rescale_before_lm_head = get_bool(component,
                                                  'rescale_before_lm_head')
            # Cross-attention geometry comes from the encoder section.
            cfg.encoder_hidden_size = get_int(
                'encoder', 'encoder_embed_dim')  # fairseq naming
            cfg.encoder_num_heads = get_int('encoder',
                                            'encoder_attention_heads')
            cfg.encoder_head_size = get_int(
                'encoder',
                'd_kv',
                fallback=cfg.encoder_hidden_size // cfg.encoder_num_heads)
            # No special-token ids are taken from the fairseq config.
            cfg.decoder_start_token_id = cfg.eos_token_id = \
                cfg.bos_token_id = cfg.pad_token_id = None
        return cfg

    return (parse_nmt_config_by_component(config, "encoder", args),
            parse_nmt_config_by_component(config, "decoder", args))
def convert_nmt_weights_to_tllm_safetensors(config, component, params,
                                            sin_pos_embedding):
    """Convert one fairseq NMT component's weights to TRT-LLM names.

    Args:
        config: component config; reads ``mapping``, ``dtype``,
            ``hidden_size``, ``intermediate_size``, ``vocab_size``,
            ``max_position_embeddings``, ``num_hidden_layers``,
            ``model_type`` and ``has_model_final_layernorm``.
        component: ``'encoder'`` or ``'decoder'``; HF parameters are read
            under ``models.0.{component}``.
        params: fairseq state dict; handed to ``convert_weight_to_dtype``
            before any tensor is read.
        sin_pos_embedding: position-embedding tensor supplied by the caller,
            reshaped to ``(max_position_embeddings, hidden_size)``.

    Returns:
        Dict of TRT-LLM weight name -> tensor for the layers owned by this
        pipeline rank, with sharded tensors sliced for this TP rank.
    """
    mapping = config.mapping
    hidden_size = config.hidden_size
    convert_weight_to_dtype(params, config.dtype)
    ffn_hidden_size = config.intermediate_size
    tp_size, tp_rank = mapping.tp_size, mapping.tp_rank
    is_decoder = component == 'decoder'

    hf_prefix = f'models.0.{component}'
    attn_name = 'attention' if component == 'encoder' else 'self_attention'
    attn_ln_name = ('self_attention_layernorm'
                    if is_decoder else 'attention_layernorm')

    # (HF suffix, TRT-LLM suffix, per-rank shape, split dim): TP-sharded.
    sharded = [
        ('self_attn.out_proj.weight', f'{attn_name}.dense.weight',
         (hidden_size, hidden_size // tp_size), -1),
        ('fc1.weight', 'mlp.fc.weight',
         (ffn_hidden_size // tp_size, hidden_size), 0),
        ('fc1.bias', 'mlp.fc.bias', ffn_hidden_size // tp_size, 0),
        ('fc2.weight', 'mlp.proj.weight',
         (hidden_size, ffn_hidden_size // tp_size), -1),
    ]
    # (HF suffix, TRT-LLM suffix, shape): copied whole to every rank.
    replicated = [
        ('self_attn.out_proj.bias', f'{attn_name}.dense.bias', hidden_size),
        ('self_attn_layer_norm.weight', f'{attn_ln_name}.weight', None),
        ('self_attn_layer_norm.bias', f'{attn_ln_name}.bias', None),
        ('fc2.bias', 'mlp.proj.bias', hidden_size),
        ('final_layer_norm.weight', 'mlp_layernorm.weight', None),
        ('final_layer_norm.bias', 'mlp_layernorm.bias', None),
    ]
    if is_decoder:
        sharded.append(('encoder_attn.out_proj.weight',
                        'cross_attention.dense.weight',
                        (hidden_size, hidden_size // tp_size), -1))
        replicated += [
            ('encoder_attn.out_proj.bias', 'cross_attention.dense.bias',
             hidden_size),
            ('encoder_attn_layer_norm.weight',
             'cross_attention_layernorm.weight', None),
            ('encoder_attn_layer_norm.bias', 'cross_attention_layernorm.bias',
             None),
        ]

    weights = {
        "embedding.vocab_embedding.weight":
        reshape(params[f'{hf_prefix}.embed_tokens.weight'].clone(),
                (config.vocab_size, -1)),
        "embedding.position_embedding.weight":
        reshape(sin_pos_embedding,
                (config.max_position_embeddings, hidden_size)),
    }

    # Per-rank rows of the fused QKV weight (and length of its bias).
    qkv_rows = hidden_size * 3 // tp_size
    layers_range = mapping.pp_layers(config.num_hidden_layers)
    for layer_idx in layers_range:
        hf_layer = f'{hf_prefix}.layers.{layer_idx}'
        # TRT-LLM layers are numbered from 0 within this pipeline rank.
        tllm_layer = f'{component}_layers.{layer_idx - layers_range[0]}'

        for hf_suffix, tllm_suffix, shape, dim in sharded:
            shard = split(params[f'{hf_layer}.{hf_suffix}'],
                          tp_size,
                          tp_rank,
                          dim=dim)
            weights[f'{tllm_layer}.{tllm_suffix}'] = reshape(shard, shape)

        for hf_suffix, tllm_suffix, shape in replicated:
            weights[f'{tllm_layer}.{tllm_suffix}'] = reshape(
                params[f'{hf_layer}.{hf_suffix}'].clone(), shape=shape)

        # Self-attention first, then (decoder only) cross-attention.
        attn_blocks = [('self_attn', f'{tllm_layer}.{attn_name}')]
        if is_decoder:
            attn_blocks.append(
                ('encoder_attn', f'{tllm_layer}.cross_attention'))
        for hf_attn, tllm_attn in attn_blocks:
            hf_module = f'models.0.{component}.layers.{int(layer_idx)}.{hf_attn}'
            weights.update(
                fuse_qkv_one_layer(params, hf_module, tllm_attn, tp_size,
                                   tp_rank, config.model_type,
                                   (qkv_rows, hidden_size), qkv_rows))

    if is_decoder:
        weights['lm_head.weight'] = reshape(
            split(params[f'{hf_prefix}.output_projection.weight'],
                  tp_size,
                  tp_rank,
                  dim=0), (config.vocab_size // tp_size, hidden_size))

    if config.has_model_final_layernorm:
        for suffix in ('weight', 'bias'):
            weights[f'final_layernorm.{suffix}'] = params[
                f'{hf_prefix}.layer_norm.{suffix}'].clone()

    return weights
def parse_bart_config(args, hf_model):
    """Derive TRT-LLM encoder/decoder configs from a HF BART/mBART model.

    The HF sub-configs are flattened into a ``configparser.ConfigParser``
    (each value stringified with ``f"{val}"``) plus a ``structure`` section.
    Each component is then read back into a ``types.SimpleNamespace``
    seeded with every attribute of ``args``.

    With ``args.nougat`` only the (mBART-style) decoder is parsed; a few
    encoder keys are mirrored from the decoder section, but no encoder
    config is returned.

    Args:
        args: parsed CLI namespace; reads ``nougat`` and ``model_type``.
        hf_model: model exposing ``model.decoder.config`` (and
            ``model.encoder.config`` when not in nougat mode).

    Returns:
        ``(encoder_config, decoder_config)``; ``encoder_config`` is ``None``
        when ``args.nougat`` is set.
    """
    config = configparser.ConfigParser()
    config['decoder'] = dict()
    for key, val in hf_model.model.decoder.config.to_dict().items():
        config["decoder"][key] = f"{val}"
    config["decoder"]["q_scaling"] = '1'
    config["decoder"]["rescale_before_lm_head"] = str(False)
    config['decoder']['has_model_final_layernorm'] = str(
        args.nougat or isinstance(hf_model, MBartForConditionalGeneration))

    if args.nougat:
        # These flags are true for mbart decoders, but missing in HF config
        config['decoder']['normalize_before'] = str(True)
        config['decoder']['normalize_embeddings'] = str(True)

        config['encoder'] = dict()
        # Init few encoder configs, needed by build, from decoder config
        encoder_config_keys = [
            "encoder_ffn_dim", "encoder_layers", "encoder_attention_heads",
            "encoder_layerdrop", "d_model"
        ]
        for key in encoder_config_keys:
            config['encoder'][key] = config['decoder'][key]
    else:
        config['encoder'] = dict()
        for key, val in hf_model.model.encoder.config.to_dict().items():
            config["encoder"][key] = f"{val}"
        config["encoder"]["q_scaling"] = '1'

        # mBART has final layernorm, BART does not
        config['encoder']['has_model_final_layernorm'] = str(
            isinstance(hf_model, MBartForConditionalGeneration))

    config["structure"] = dict()
    config["structure"]["t5_with_bias"] = "true"
    config["structure"]["use_gated_activation"] = "false"
    config["structure"]["position_embedding_type"] = "learned_absolute"
    config["structure"]["model_type"] = args.model_type

    def parse_bart_config_by_component(config, component, args):
        assert component in ('encoder', 'decoder'), 'Unsupported component!'

        def _optional_int(section, option):
            # Unset HF fields (e.g. Nougat's decoder_start_token_id) are
            # dumped above as the string "None"; map them back to None.
            raw = config.get(section, option)
            return int(raw) if raw != "None" else None

        component_config = types.SimpleNamespace()
        component_config = copy_args_to_component_config(
            component_config, args)
        component_config.n_layer = config.getint(component,
                                                 f'{component}_layers')
        component_config.n_head = config.getint(
            component, f'{component}_attention_heads')
        component_config.hidden_size = config.getint(component, 'd_model')
        component_config.head_size = config.getint(
            component,
            'd_kv',
            fallback=component_config.hidden_size // component_config.n_head)
        component_config.ffn_hidden_size = config.getint(
            component, f'{component}_ffn_dim')
        component_config.vocab_size = config.getint(component, 'vocab_size')
        component_config.n_positions = config.getint(
            component, 'max_position_embeddings')
        component_config.has_position_embedding = config.getboolean(
            component, 'has_position_embedding',
            fallback=True)  # TODO: hardcoded here
        component_config.has_token_type_embedding = config.getboolean(
            component, 'has_token_type_embedding', fallback=False)
        component_config.has_embedding_layernorm = config.getboolean(
            component, 'has_embedding_layernorm', fallback=True)
        component_config.has_embedding_scale = config.getboolean(
            component, 'scale_embedding')
        component_config.q_scaling = config.getfloat(component,
                                                     'q_scaling',
                                                     fallback=1.0)
        component_config.has_attention_qkvo_bias = config.getboolean(
            'structure', 't5_with_bias', fallback=True)
        component_config.has_mlp_bias = config.getboolean('structure',
                                                          't5_with_bias',
                                                          fallback=True)
        component_config.has_model_final_layernorm = config.getboolean(
            component, 'has_model_final_layernorm')
        # HF BART/mBART configs typically carry no layer_norm_epsilon. The
        # former fallback=False silently produced eps == 0; use the
        # torch.nn.LayerNorm default (1e-5), as the NMT parser does.
        component_config.layernorm_eps = config.getfloat(component,
                                                         'layer_norm_epsilon',
                                                         fallback=1e-5)
        normalize_before = config.getboolean(component, 'normalize_before')
        component_config.layernorm_position = layernorm_position_map[
            'pre_layernorm' if normalize_before else 'post_layernorm']

        component_config.layernorm_type = layernorm_type_map[config.get(
            component, 'layernorm_type', fallback='LayerNorm')]
        component_config.hidden_act = config.get(component,
                                                 'activation_function')
        component_config.gated_act = config.getboolean(component,
                                                       'is_gated_act',
                                                       fallback=False)
        component_config.mlp_type = mlp_type_map[
            'GatedMLP' if component_config.gated_act else 'MLP']
        component_config.relative_attention = config.get(
            'structure', 'position_embedding_type') == 'relative'

        component_config.num_buckets = config.getint(
            component, 'relative_attention_num_buckets', fallback=0)
        component_config.max_distance = config.getint(
            component, 'relative_attention_max_distance', fallback=0)
        component_config.max_lora_rank = config.getint(component,
                                                       'max_lora_rank',
                                                       fallback=0)
        # literal_eval (not eval) parses these list/dict literals safely.
        component_config.lora_target_modules = literal_eval(
            config.get(component, 'lora_target_modules', fallback="[]"))
        component_config.hf_modules_to_trtllm_modules = literal_eval(
            config.get(component,
                       'hf_modules_to_trtllm_modules',
                       fallback="{}"))
        component_config.trtllm_modules_to_hf_modules = literal_eval(
            config.get(component,
                       'trtllm_modules_to_hf_modules',
                       fallback="{}"))
        component_config.logits_dtype = config.get(component,
                                                   'logits_dtype',
                                                   fallback='float32')
        component_config.position_embedding_type = config.get(
            'structure', 'position_embedding_type')

        if component == 'decoder':
            component_config.rescale_before_lm_head = config.getboolean(
                component, 'rescale_before_lm_head')

            component_config.encoder_hidden_size = config.getint(
                'encoder', 'd_model')
            component_config.encoder_num_heads = config.getint(
                'encoder', 'encoder_attention_heads')
            component_config.encoder_head_size = config.getint(
                'encoder',
                'd_kv',
                fallback=component_config.encoder_hidden_size //
                component_config.encoder_num_heads)

            # nougat has decoder_start_token_id = None; other special-token
            # ids get the same "None"-tolerant handling.
            component_config.decoder_start_token_id = _optional_int(
                'decoder', 'decoder_start_token_id')
            component_config.eos_token_id = _optional_int(
                'decoder', 'eos_token_id')
            component_config.bos_token_id = _optional_int(
                'decoder', 'bos_token_id')
            component_config.pad_token_id = _optional_int(
                'decoder', 'pad_token_id')

        return component_config

    encoder_config = None
    if not args.nougat:
        encoder_config = parse_bart_config_by_component(
            config, "encoder", args)
    decoder_config = parse_bart_config_by_component(config, "decoder", args)

    return encoder_config, decoder_config
def convert_bart_weights_to_tllm_safetensors(config, component, params):
    """Convert one BART/mBART component's HF weights to TRT-LLM names.

    Args:
        config: component config; reads ``mapping``, ``dtype``,
            ``hidden_size``, ``intermediate_size``, ``vocab_size``,
            ``max_position_embeddings``, ``num_hidden_layers``,
            ``model_type`` and ``has_model_final_layernorm``.
        component: ``'encoder'`` or ``'decoder'``; HF parameters are read
            under ``model.{component}``.
        params: HF state dict (plus ``lm_head.weight`` for the decoder);
            handed to ``convert_weight_to_dtype`` before any tensor is read.

    Returns:
        Dict of TRT-LLM weight name -> tensor for the layers owned by this
        pipeline rank, with sharded tensors sliced for this TP rank.
    """
    weights = {}
    mapping = config.mapping
    hidden_size = config.hidden_size
    convert_weight_to_dtype(params, config.dtype)
    ffn_hidden_size = config.intermediate_size
    vocab_size = config.vocab_size

    hf_param_prefix = f'model.{component}'
    trtllm_layer_name = f'{component}_layers'
    trtllm_attn_layer_name = ('attention'
                              if component == 'encoder' else 'self_attention')
    trtllm_attn_layernorm_name = ('self_attention_layernorm'
                                  if component == 'decoder' else
                                  'attention_layernorm')

    # Component-level embedding tensors (converted once, not per layer).
    embedding_layer_names = {
        'embed_tokens.weight': {
            "name": 'embedding.vocab_embedding.weight',
            "shape": (vocab_size, -1)
        },
        'embed_positions.weight': {
            "name": 'embedding.position_embedding.weight',
            "shape": (config.max_position_embeddings, hidden_size)
        },
        'layernorm_embedding.weight': {
            "name": 'embedding.embedding_layernorm.weight',
            "shape": None
        },
        'layernorm_embedding.bias': {
            "name": 'embedding.embedding_layernorm.bias',
            "shape": None
        },
    }

    # Per-layer tensors sharded across TP ranks along "split_dim".
    hidden_layer_name_split = {
        'self_attn.out_proj.weight': {
            "name": f'{trtllm_attn_layer_name}.dense.weight',
            "shape": (hidden_size, hidden_size // mapping.tp_size),
            "split_dim": -1
        },
        'fc1.weight': {
            "name": 'mlp.fc.weight',
            "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
            "split_dim": 0
        },
        'fc1.bias': {
            "name": 'mlp.fc.bias',
            "shape": (ffn_hidden_size // mapping.tp_size),
            "split_dim": 0
        },
        'fc2.weight': {
            "name": 'mlp.proj.weight',
            "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
            "split_dim": -1
        },
    }

    # Per-layer tensors copied whole to every TP rank.
    hidden_layer_name_no_split = {
        'self_attn.out_proj.bias': {
            "name": f'{trtllm_attn_layer_name}.dense.bias',
            "shape": (hidden_size)
        },
        'self_attn_layer_norm.weight': {
            "name": f'{trtllm_attn_layernorm_name}.weight',
            "shape": None
        },
        'self_attn_layer_norm.bias': {
            "name": f'{trtllm_attn_layernorm_name}.bias',
            "shape": None
        },
        'fc2.bias': {
            "name": 'mlp.proj.bias',
            "shape": (hidden_size)
        },
        'final_layer_norm.weight': {
            "name": 'mlp_layernorm.weight',
            "shape": None
        },
        'final_layer_norm.bias': {
            "name": 'mlp_layernorm.bias',
            "shape": None
        },
    }

    # The model-level final layer norm (``model.{component}.layer_norm``,
    # present in mBART) has no per-layer counterpart. It is exported once at
    # the end of this function when ``config.has_model_final_layernorm`` is
    # set, so it must NOT be listed in the per-layer tables above: doing so
    # looked up the nonexistent ``...layers.{i}.layer_norm.*`` keys.

    if component == "decoder":
        hidden_layer_name_split.update({
            'encoder_attn.out_proj.weight': {
                "name": 'cross_attention.dense.weight',
                "shape": (hidden_size, hidden_size // mapping.tp_size),
                "split_dim": -1
            }
        })
        hidden_layer_name_no_split.update({
            'encoder_attn.out_proj.bias': {
                "name": 'cross_attention.dense.bias',
                "shape": (hidden_size)
            },
            'encoder_attn_layer_norm.weight': {
                "name": 'cross_attention_layernorm.weight',
                "shape": None
            },
            'encoder_attn_layer_norm.bias': {
                "name": 'cross_attention_layernorm.bias',
                "shape": None
            },
        })

    def get_attn_module_name(component, layer, attn_type):
        return f'model.{component}.layers.{int(layer)}.{attn_type}'

    for hf_weight_name, weight_info in embedding_layer_names.items():
        hf_name = f'{hf_param_prefix}.{hf_weight_name}'
        if 'position' in hf_weight_name:
            # Drop the first two rows of the learned position table
            # (presumably HF BART's position offset of 2 — confirm).
            tensor = params[hf_name][2:].clone()
        else:
            tensor = params[hf_name].clone()
        weights[weight_info["name"]] = reshape(tensor, weight_info["shape"])

    num_layers = config.num_hidden_layers

    layers_range = mapping.pp_layers(num_layers)
    for layer_idx in layers_range:
        # TRT-LLM layers are numbered from 0 within this pipeline rank.
        local_layer_idx = layer_idx - layers_range[0]
        hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}'
        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'

        for hf_weight_name, weight_info in hidden_layer_name_split.items():
            weights[
                f'{trtllm_layer_name_prefix}.{weight_info["name"]}'] = reshape(
                    split(params[f'{hf_layer_name_prefix}.{hf_weight_name}'],
                          mapping.tp_size,
                          mapping.tp_rank,
                          dim=weight_info["split_dim"]), weight_info["shape"])

        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
            trtllm_layer_fullname = f'{trtllm_layer_name_prefix}.{weight_info["name"]}'
            hf_layer_fullname = f'{hf_layer_name_prefix}.{hf_weight_name}'
            weights[trtllm_layer_fullname] = reshape(
                params[hf_layer_fullname].clone(), shape=weight_info["shape"])

        self_attn_module_name = get_attn_module_name(component, layer_idx,
                                                     'self_attn')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
                mapping.tp_size, mapping.tp_rank, config.model_type,
                (hidden_size * 3 // mapping.tp_size, hidden_size),
                (hidden_size * 3 // mapping.tp_size)))

        if component == 'decoder':
            cross_attn_module_name = get_attn_module_name(
                component, layer_idx, 'encoder_attn')
            weights.update(
                fuse_qkv_one_layer(
                    params, cross_attn_module_name,
                    f'{trtllm_layer_name_prefix}.cross_attention',
                    mapping.tp_size, mapping.tp_rank, config.model_type,
                    (hidden_size * 3 // mapping.tp_size, hidden_size),
                    (hidden_size * 3 // mapping.tp_size)))

    if component == 'decoder':
        weights['lm_head.weight'] = reshape(
            split(params['lm_head.weight'],
                  mapping.tp_size,
                  mapping.tp_rank,
                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))

    if config.has_model_final_layernorm:
        weights['final_layernorm.weight'] = params[
            f'{hf_param_prefix}.layer_norm.weight'].clone()
        weights['final_layernorm.bias'] = params[
            f'{hf_param_prefix}.layer_norm.bias'].clone()

    return weights
def parse_pix2struct_config(args, hf_model):
    """Build the converter configs for a Pix2Struct checkpoint.

    Only the text decoder is handled by this script, so the returned
    encoder config is always ``None``.

    Args:
        args: Parsed command-line namespace. It is not modified; its
            attributes are copied into the returned decoder config.
        hf_model: Loaded HF model exposing ``decoder.config`` with
            ``to_dict()``, ``hidden_size`` and ``num_heads``.

    Returns:
        ``(None, decoder_config)`` where ``decoder_config`` is a new
        ``types.SimpleNamespace`` holding every CLI argument plus the
        decoder hyper-parameters read from the HF config.

    Raises:
        ValueError: propagated from the per-component parser for an
            unsupported component name.
    """
    config = configparser.ConfigParser()

    # manually set q_scaling to offset attention scaling's effect.
    # TODO: modify kernels to control whether to disable attention scaling
    def get_offset_q_scaling(config) -> str:
        d_model = config.hidden_size
        num_heads = config.num_heads
        head_size = d_model / num_heads
        scaling = 1 / head_size**.5
        return str(scaling)

    # Mirror the HF decoder config as strings so configparser can re-parse
    # them with typed getters below.
    config["decoder"] = {}
    for key, val in hf_model.decoder.config.to_dict().items():
        config["decoder"][key] = f"{val}"
    config["decoder"]["q_scaling"] = get_offset_q_scaling(
        hf_model.decoder.config)

    config["structure"] = dict()
    config["structure"]["pix2struct_with_bias"] = "false"
    config["structure"]["use_gated_activation"] = "false"
    config["structure"]["position_embedding_type"] = "relative"
    config["structure"]["model_type"] = args.model_type

    def parse_pix2struct_config_by_component(config, component, args):
        """Read one component's settings out of ``config``.

        Only ``'decoder'`` is supported for Pix2Struct.

        Returns:
            A new namespace with all attributes of ``args`` plus the parsed
            decoder settings.

        Raises:
            ValueError: if ``component`` is not ``'decoder'``.
        """
        # Raise instead of ``assert``: asserts are stripped under ``python -O``.
        if component != 'decoder':
            raise ValueError(f'Unsupported component: {component!r}')

        # Work on a copy so the caller's CLI namespace is left untouched.
        component_config = copy_args_to_component_config(
            types.SimpleNamespace(), args)

        component_config.n_layer = config.getint(component, 'num_layers')
        component_config.n_head = config.getint(component, 'num_heads')
        component_config.head_size = config.getint(component, 'd_kv')
        component_config.hidden_size = config.getint(component, 'hidden_size')
        component_config.ffn_hidden_size = config.getint(component, 'd_ff')
        component_config.vocab_size = config.getint(component, 'vocab_size')
        component_config.n_positions = config.getint(component,
                                                     'n_positions',
                                                     fallback=512)
        component_config.has_position_embedding = config.getboolean(
            component, 'has_position_embedding',
            fallback=False)  # TODO: hardcoded here
        component_config.has_token_type_embedding = config.getboolean(
            component, 'has_token_type_embedding', fallback=False)
        component_config.has_embedding_layernorm = config.getboolean(
            component, 'has_embedding_layernorm', fallback=False)
        component_config.has_embedding_scale = config.getboolean(
            component, 'has_embedding_scale', fallback=False)
        component_config.q_scaling = config.getfloat(component,
                                                     'q_scaling',
                                                     fallback=1.0)
        component_config.has_attention_qkvo_bias = config.getboolean(
            component, 'has_attention_qkvo_bias', fallback=False)
        component_config.has_mlp_bias = config.getboolean(component,
                                                          'has_mlp_bias',
                                                          fallback=False)
        component_config.has_model_final_layernorm = config.getboolean(
            component, 'has_model_final_layernorm', fallback=True)
        component_config.layernorm_eps = config.getfloat(
            component, 'layer_norm_epsilon')
        component_config.layernorm_position = layernorm_position_map[
            config.get(component,
                       'layernorm_position',
                       fallback='pre_layernorm')]  # TODO: hardcoded here
        component_config.layernorm_type = layernorm_type_map[config.get(
            component, 'layernorm_type', fallback='RmsNorm')]
        component_config.hidden_act = config.get(component, 'dense_act_fn')
        # Pix2Struct's text MLP is always treated as gated here.
        component_config.gated_act = True
        component_config.mlp_type = mlp_type_map[
            'GatedMLP' if component_config.gated_act else 'MLP']
        component_config.has_lm_head_bias = config.getboolean(
            component,  # TODO: T5 with bias
            'has_lm_head_bias',
            fallback=False)
        component_config.relative_attention = config.getboolean(
            component, 'relative_attention', fallback=True)
        component_config.num_buckets = config.getint(
            component, 'relative_attention_num_buckets')
        component_config.max_distance = config.getint(
            component, 'relative_attention_max_distance')
        component_config.logits_dtype = config.get(component,
                                                   'logits_dtype',
                                                   fallback='float32')
        component_config.rescale_before_lm_head = config.getboolean(
            component, 'tie_word_embeddings'
        )  # default is True (for T5), but False for Flan-T5
        component_config.encoder_hidden_size = config.getint(
            'decoder', 'hidden_size')
        component_config.encoder_num_heads = config.getint(
            'decoder', 'num_heads')
        component_config.encoder_head_size = config.getint('decoder', 'd_kv')
        component_config.position_embedding_type = config.get(
            'structure', 'position_embedding_type')
        component_config.decoder_start_token_id = config.getint(
            'decoder', 'decoder_start_token_id')
        component_config.eos_token_id = config.getint('decoder',
                                                      'eos_token_id')
        # pix2struct does not have bos_token_id; the HF value is the string "None".
        bos_token_id = config.get('decoder', 'bos_token_id')
        component_config.bos_token_id = int(
            bos_token_id) if bos_token_id != "None" else None
        component_config.pad_token_id = config.getint('decoder',
                                                      'pad_token_id')
        return component_config

    decoder_args = parse_pix2struct_config_by_component(config, "decoder", args)
    return None, decoder_args
def convert_pix2struct_weights_to_tllm_safetensors(config, component, params):
    """Map Pix2Struct decoder weights onto TRT-LLM decoder weight names.

    Args:
        config: Per-rank TRT-LLM config. ``config.mapping`` supplies the
            tensor-parallel size/rank and the pipeline layer range; model
            hyper-parameters (hidden size, heads, buckets, ...) are read
            from it too.
        component: HF module prefix of the decoder (``"decoder"`` when
            called from ``convert_checkpoint``); it also prefixes the
            TRT-LLM layer names.
        params: HF state dict; handed to ``convert_weight_to_dtype`` with
            ``config.dtype`` before any tensor is read.

    Returns:
        dict of TRT-LLM weight name -> tensor for this rank. Per-layer HF
        tensors listed in the name tables below but absent from ``params``
        are skipped.
    """
    weights = {}
    mapping = config.mapping

    convert_weight_to_dtype(params, config.dtype)
    hidden_size = config.hidden_size
    ffn_hidden_size = config.intermediate_size
    num_layers = config.num_hidden_layers
    n_head = config.num_attention_heads
    head_size = config.head_size
    # head_size * num_heads does not necessarily equal hidden_size (e.g. Flan-T5).
    attention_hidden_size = n_head * head_size

    hf_param_prefix = f'{component}'
    trtllm_layer_name = f'{component}_layers'
    trtllm_attn_layer_name = 'self_attention'
    trtllm_attn_layernorm_name = 'self_attention_layernorm'
    # All relative attention tables are built from layer 0's bias.
    rel_attn_bias_name = (f'{component}.layer.0.self_attention.attention.'
                          'relative_attention_bias.weight')

    def get_attn_module_name(component, layer, attn_type):
        return f'{component}.layer.{int(layer)}.{attn_type}.attention'

    def build_rel_attn_table():
        # Transpose layer 0's bias, shard it over heads, and reshape to
        # (heads per rank, num_buckets). Rebuilt on every call, as before,
        # rather than reusing one tensor object under several keys.
        return reshape(
            split(params[rel_attn_bias_name].T, mapping.tp_size,
                  mapping.tp_rank, 0),
            (n_head // mapping.tp_size, config.num_buckets))

    weights['embedding.vocab_embedding.weight'] = reshape(
        params[f'{hf_param_prefix}.embed_tokens.weight'].clone(), None)

    layers_range = mapping.pp_layers(num_layers)
    for layer_idx in layers_range:
        # TRT-LLM layer indices are local to this pipeline stage.
        local_layer_idx = layer_idx - layers_range[0]
        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
        hf_layer_name_prefix = f'{hf_param_prefix}.layer.{layer_idx}'

        # Weights sharded across tensor-parallel ranks along "split_dim".
        hidden_layer_name_split = {
            f'{hf_layer_name_prefix}.self_attention.attention.output.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',
                "shape":
                (hidden_size, attention_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.mlp.DenseReluDense.wo.weight': {
                "name": f'{trtllm_layer_name_prefix}.mlp.proj.weight',
                "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_0.weight': {
                "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
                "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                "split_dim": 0
            },
        }

        # Weights replicated on every tensor-parallel rank.
        hidden_layer_name_no_split = {
            f'{hf_layer_name_prefix}.self_attention.layer_norm.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',
                "shape": None
            },
            f'{hf_layer_name_prefix}.mlp.layer_norm.weight': {
                "name": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',
                "shape": None
            },
        }

        if config.gated_act:
            hidden_layer_name_split.update({
                f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_1.weight': {
                    "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
                    "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                    "split_dim": 0
                },
            })

        hidden_layer_name_split.update({
            f'{hf_layer_name_prefix}.encoder_decoder_attention.attention.output.weight':
            {
                "name":
                f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',
                "shape":
                (hidden_size, attention_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
        })

        hidden_layer_name_no_split.update({
            f'{hf_layer_name_prefix}.encoder_decoder_attention.layer_norm.weight':
            {
                "name":
                f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',
                "shape": None
            },
        })

        cross_attn_module_name = get_attn_module_name(
            component, layer_idx, 'encoder_decoder_attention')
        weights.update(
            fuse_qkv_one_layer(
                params, cross_attn_module_name,
                f'{trtllm_layer_name_prefix}.cross_attention', mapping.tp_size,
                mapping.tp_rank, config.model_type,
                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                None))

        self_attn_module_name = get_attn_module_name(component, layer_idx,
                                                     'self_attention')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
                mapping.tp_size, mapping.tp_rank, config.model_type,
                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                None))

        weights[
            f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = build_rel_attn_table(
            )

        for hf_weight_name, weight_info in hidden_layer_name_split.items():
            if hf_weight_name in params:
                weights[weight_info["name"]] = reshape(
                    split(params[hf_weight_name],
                          mapping.tp_size,
                          mapping.tp_rank,
                          dim=weight_info["split_dim"]), weight_info["shape"])
        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
            if hf_weight_name in params:
                weights[weight_info["name"]] = reshape(
                    params[hf_weight_name].clone(), shape=weight_info["shape"])

    weights['final_layernorm.weight'] = reshape(
        params[f'{component}.final_layer_norm.weight'].clone(), None)
    weights['lm_head.weight'] = reshape(
        split(params[f'{component}.lm_head.weight'],
              mapping.tp_size,
              mapping.tp_rank,
              dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
    if not config.use_implicit_relative_attention:
        weights['rel_attn_table'] = build_rel_attn_table()

    return weights
def get_model(args):
    """Load the source model selected by ``args.model_type``.

    Args:
        args: Namespace with ``model_type`` and ``model_dir``; ``nougat`` is
            read only when ``model_type == "bart"``.

    Returns:
        The loaded model. For ``bart`` with ``nougat`` set this is only the
        decoder of the ``VisionEncoderDecoderModel``; for ``blip2`` it is
        the ``language_model`` sub-module.

    Raises:
        ValueError: if ``args.model_type`` is not one of the supported types
            (previously this surfaced as an ``UnboundLocalError``).
    """
    model_type = args.model_type
    if model_type == "t5":
        return T5ForConditionalGeneration.from_pretrained(args.model_dir)
    if model_type == "nmt":
        # Imported lazily so fairseq is only required for nmt conversion.
        from fairseq.models.transformer import TransformerModel
        return TransformerModel.from_pretrained(args.model_dir)
    if model_type == "bart":
        if args.nougat:
            model = VisionEncoderDecoderModel.from_pretrained(args.model_dir)
            return model.get_decoder()
        return AutoModelForSeq2SeqLM.from_pretrained(args.model_dir)
    if model_type == "pix2struct":
        return Pix2StructForConditionalGeneration.from_pretrained(
            args.model_dir)
    if model_type == "blip2":
        return Blip2ForConditionalGeneration.from_pretrained(
            args.model_dir).language_model
    raise ValueError(f'Unsupported model_type: {model_type!r}')
def convert_checkpoint(args):
    """Convert an encoder-decoder checkpoint into TRT-LLM format.

    Writes ``config.json`` plus one ``rank{N}.safetensors`` per rank into
    ``<output_dir>/decoder`` and, for models whose encoder is converted by
    this script, into ``<output_dir>/encoder`` as well. The ``encoder``
    directory is always created.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` parser).
            ``args.workers`` is lowered to ``tp_size * pp_size`` when larger.

    Raises:
        ValueError: if ``args.workers`` is smaller than 1.
    """
    # ``convert`` strides over ranks by ``args.workers``, so it must be
    # positive. Check before the (potentially expensive) model load.
    if args.workers < 1:
        raise ValueError(
            f'--workers must be a positive integer, got {args.workers}')

    model = get_model(args)

    saved_dir = Path(args.output_dir)
    saved_dir.mkdir(parents=True, exist_ok=True)
    encoder_saved_dir = saved_dir / "encoder"
    encoder_saved_dir.mkdir(parents=True, exist_ok=True)
    decoder_saved_dir = saved_dir / "decoder"
    decoder_saved_dir.mkdir(parents=True, exist_ok=True)

    world_size = args.tp_size * args.pp_size

    kv_cache_quant_algo = None
    quant_algo = None

    # blip2 checkpoints are parsed with the T5 config parser.
    model_type = args.model_type if args.model_type != "blip2" else "t5"
    encoder_config, decoder_config = globals()[f'parse_{model_type}_config'](
        args, model)

    # Nougat and Pix2Struct are converted decoder-only by this script.
    has_encoder = not args.nougat and args.model_type != "pix2struct"

    # Optional component attributes copied into the TRT-LLM config if present.
    additional_settings = ["gated_act"]

    if has_encoder:
        tllm_encoder_config = {
            'architecture': "EncoderModel",
            'dtype': args.dtype,
            'logits_dtype': encoder_config.logits_dtype,
            'num_hidden_layers': encoder_config.n_layer,
            'num_attention_heads': encoder_config.n_head,
            'hidden_size': encoder_config.hidden_size,
            'norm_epsilon': encoder_config.layernorm_eps,
            'vocab_size': encoder_config.vocab_size,
            'position_embedding_type': encoder_config.position_embedding_type,
            'hidden_act': encoder_config.hidden_act,
            'quantization': {
                'quant_algo': quant_algo,
                'kv_cache_quant_algo': kv_cache_quant_algo,
            },
            'mapping': {
                'world_size': world_size,
                'tp_size': args.tp_size,
                'pp_size': args.pp_size,
            },
            'use_parallel_embedding': args.use_parallel_embedding,
            'embedding_sharding_dim': args.embedding_sharding_dim,
            'max_position_embeddings': encoder_config.n_positions,
            'num_key_value_heads': encoder_config.n_head,
            'head_size': encoder_config.head_size,
            'has_position_embedding': encoder_config.has_position_embedding,
            'layernorm_type': encoder_config.layernorm_type,
            'has_attention_qkvo_bias': encoder_config.has_attention_qkvo_bias,
            'has_mlp_bias': encoder_config.has_mlp_bias,
            'has_model_final_layernorm':
            encoder_config.has_model_final_layernorm,
            'has_embedding_layernorm': encoder_config.has_embedding_layernorm,
            'has_embedding_scale': encoder_config.has_embedding_scale,
            'intermediate_size': encoder_config.ffn_hidden_size,
            'q_scaling': encoder_config.q_scaling,
            'layernorm_position': encoder_config.layernorm_position,
            'mlp_type': encoder_config.mlp_type,
            'relative_attention': encoder_config.relative_attention,
            'max_distance': encoder_config.max_distance,
            'num_buckets': encoder_config.num_buckets,
            'model_type': encoder_config.model_type,
        }

        for additional_setting in additional_settings:
            if hasattr(encoder_config, additional_setting):
                tllm_encoder_config[additional_setting] = getattr(
                    encoder_config, additional_setting)

        with (encoder_saved_dir / "config.json").open('w') as f:
            json.dump(tllm_encoder_config, f, indent=4)

        encoder_convert_args = dict(params=model.state_dict(),
                                    component="encoder")

    tllm_decoder_config = {
        'architecture': "DecoderModel",
        'dtype': args.dtype,
        'logits_dtype': decoder_config.logits_dtype,
        'num_hidden_layers': decoder_config.n_layer,
        'num_attention_heads': decoder_config.n_head,
        'hidden_size': decoder_config.hidden_size,
        'norm_epsilon': decoder_config.layernorm_eps,
        'vocab_size': decoder_config.vocab_size,
        'position_embedding_type': decoder_config.position_embedding_type,
        'hidden_act': decoder_config.hidden_act,
        'quantization': {
            'quant_algo': quant_algo,
            'kv_cache_quant_algo': kv_cache_quant_algo,
        },
        'mapping': {
            'world_size': world_size,
            'tp_size': args.tp_size,
            'pp_size': args.pp_size,
        },
        'use_parallel_embedding': args.use_parallel_embedding,
        'embedding_sharding_dim': args.embedding_sharding_dim,
        'max_position_embeddings': decoder_config.n_positions,
        'head_size': decoder_config.head_size,
        'has_position_embedding': decoder_config.has_position_embedding,
        'layernorm_type': decoder_config.layernorm_type,
        'has_attention_qkvo_bias': decoder_config.has_attention_qkvo_bias,
        'has_mlp_bias': decoder_config.has_mlp_bias,
        'has_model_final_layernorm': decoder_config.has_model_final_layernorm,
        'has_embedding_layernorm': decoder_config.has_embedding_layernorm,
        'has_embedding_scale': decoder_config.has_embedding_scale,
        'intermediate_size': decoder_config.ffn_hidden_size,
        'q_scaling': decoder_config.q_scaling,
        'layernorm_position': decoder_config.layernorm_position,
        'mlp_type': decoder_config.mlp_type,
        'relative_attention': decoder_config.relative_attention,
        'max_distance': decoder_config.max_distance,
        'num_buckets': decoder_config.num_buckets,
        'model_type': decoder_config.model_type,
        'rescale_before_lm_head': decoder_config.rescale_before_lm_head,
        'encoder_hidden_size': decoder_config.encoder_hidden_size,
        'encoder_num_heads': decoder_config.encoder_num_heads,
        'encoder_head_size': decoder_config.encoder_head_size,
        'skip_cross_kv': args.skip_cross_kv,
        'use_implicit_relative_attention': args.use_implicit_relative_attention,
        'decoder_start_token_id': decoder_config.decoder_start_token_id,
        'eos_token_id': decoder_config.eos_token_id,
        'bos_token_id': decoder_config.bos_token_id,
        'pad_token_id': decoder_config.pad_token_id,
    }
    for additional_setting in additional_settings:
        if hasattr(decoder_config, additional_setting):
            tllm_decoder_config[additional_setting] = getattr(
                decoder_config, additional_setting)

    with (decoder_saved_dir / "config.json").open('w') as f:
        json.dump(tllm_decoder_config, f, indent=4)

    decoder_convert_args = dict(params=model.state_dict(), component="decoder")

    if args.model_type == "nmt":
        fairseq_config = vars(model.cfg.model)  # Namespace --> dict
        num_embeddings = fairseq_config['max_source_positions']
        embedding_dim = fairseq_config['encoder_embed_dim']
        padding_idx = model.models[0].encoder.embed_tokens.padding_idx  # 1
        sin_pos_embedding = model.models[
            0].encoder.embed_positions.get_embedding(
                padding_idx + 1 + num_embeddings,
                embedding_dim,
                padding_idx=padding_idx)  # [2 + num_embeddings, embed_dim]
        sin_pos_embedding = sin_pos_embedding[2:, :]  # remove offset embeddings
        if has_encoder:
            encoder_convert_args["sin_pos_embedding"] = sin_pos_embedding
        decoder_convert_args["sin_pos_embedding"] = sin_pos_embedding

    # ``convert`` also reads ``args.workers`` as its rank stride, so clamp it
    # in place before choosing between in-process and spawned conversion.
    if args.workers > world_size:
        args.workers = world_size

    if args.workers == 1:
        if has_encoder:
            convert(0, world_size, args, tllm_encoder_config,
                    encoder_convert_args, encoder_saved_dir)
        convert(0, world_size, args, tllm_decoder_config, decoder_convert_args,
                decoder_saved_dir)
    else:
        LOGGER.info(f'Convert checkpoint using {args.workers} workers.')
        import torch.multiprocessing as mp
        if has_encoder:
            mp.spawn(convert,
                     nprocs=args.workers,
                     args=(world_size, args, tllm_encoder_config,
                           encoder_convert_args, encoder_saved_dir))
        mp.spawn(convert,
                 nprocs=args.workers,
                 args=(world_size, args, tllm_decoder_config,
                       decoder_convert_args, decoder_saved_dir))
def convert(worker_rank, world_size, args, model_config, convert_args,
            saved_dir):
    """Convert and save every rank shard assigned to one worker.

    Worker ``worker_rank`` handles ranks ``worker_rank``,
    ``worker_rank + args.workers``, ... below ``world_size``. The weight
    converter is looked up by name in this module as
    ``convert_<model_type>_weights_to_tllm_safetensors``.

    Args:
        worker_rank: Index of this worker (the process index when run via
            ``mp.spawn``).
        world_size: Total number of ranks (``tp_size * pp_size``).
        args: CLI namespace; only ``args.workers`` (the rank stride) is read.
        model_config: TRT-LLM config dict, as written to ``config.json``.
        convert_args: Keyword arguments forwarded to the weight converter.
        saved_dir: Directory receiving ``rank{N}.safetensors`` files.
    """
    # Importing a package does not by itself import its submodules, so load
    # ``safetensors.torch`` explicitly rather than relying on another library
    # having imported it as a side effect.
    import safetensors.torch

    for rank in range(worker_rank, world_size, args.workers):
        rank_config = copy.deepcopy(PretrainedConfig.from_dict(model_config))
        rank_config.set_rank(rank)
        convert_fn = globals(
        )[f'convert_{rank_config.model_type}_weights_to_tllm_safetensors']
        weights = convert_fn(config=rank_config, **convert_args)
        safetensors.torch.save_file(weights,
                                    f'{saved_dir}/rank{rank}.safetensors')
        # Release this rank's tensors before the next rank is converted so
        # two ranks' weights are not alive at the same time.
        del weights
if __name__ == "__main__":
    # Command-line entry point: parse options, configure logging, and run
    # the conversion while timing it.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--model_type',
        type=str,
        default='t5',
        choices=['t5', 'nmt', 'bart', 'pix2struct', 'blip2'],
        help=
        'Multimodal type when this script is used for multimodal conversion.')
    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='N-way tensor parallelism size')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='N-way pipeline parallelism size')
    parser.add_argument("--model_dir",
                        "-i",
                        type=str,
                        help="Path to the framework checkpoint file",
                        required=True)
    parser.add_argument("--output_dir",
                        "-o",
                        type=str,
                        help="Path to the converted TRT-LLM model weight file",
                        required=True)
    parser.add_argument(
        "--workers",
        type=int,
        help="How many workers to spawn for conversion (default: 4)",
        default=4)
    parser.add_argument("--nougat",
                        action="store_true",
                        help="Model which uses vision encoder + mbart decoder")
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Provide verbose messages")
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
    )
    # Adjacent string literals are concatenated, so each fragment ends with
    # the space that separates it from the next sentence.
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
        'To shard it along hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharding is only enabled when embedding_sharding_dim = 0'
    )
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument(
        '--dtype',
        type=str,
        default='float16',
        choices=['float16', 'float32', 'bfloat16'],
        help=
        'Target inference dtype. Weights and Computation will be in this dtype, no matter what original dtype the weight checkpoint has.'
    )
    parser.add_argument(
        '--skip_cross_kv',
        action='store_true',
        help=
        'Skip redundant cross qkv computation by using TensorRT IfConditional switch (experimental).'
    )
    parser.add_argument(
        '--use_implicit_relative_attention',
        action='store_true',
        help=
        'Compute relative attention bias on the fly instead of pre-compute a relative attention bias table.'
    )
    args = parser.parse_args()

    log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s"
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO,
                        format=log_format)
    LOGGER.info("\n=============== Argument ===============")
    for key, value in vars(args).items():
        LOGGER.info(f"{key}: {value}")
    LOGGER.info("========================================")

    start_time = datetime.now()
    convert_checkpoint(args)
    stop_time = datetime.now()
    run_time = (stop_time - start_time)
    LOGGER.info("Spend {} (h:m:s) to convert the model".format(run_time))