# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/enc_dec/convert_checkpoint.py
import argparse
import configparser
import copy
import json
import logging
import os
import types
from ast import literal_eval
from datetime import datetime
from pathlib import Path

import safetensors
from helper import convert_weight_to_dtype, fuse_qkv_one_layer, reshape, split
from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration,
                          MBartForConditionalGeneration,
                          Pix2StructForConditionalGeneration,
                          T5ForConditionalGeneration, VisionEncoderDecoderModel)

from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
                                     MLPType)
from tensorrt_llm.models import PretrainedConfig

dir_path = os.path.dirname(os.path.realpath(__file__))
LOGGER = logging.getLogger(__name__)

# Enum-name -> enum-value lookup tables used when translating config strings.
layernorm_type_map = {i.name: i.value for i in LayerNormType}
layernorm_position_map = {i.name: i.value for i in LayerNormPositionType}
mlp_type_map = {i.name: i.value for i in MLPType}


def copy_args_to_component_config(component_config, args):
    """Copy every attribute of ``args`` onto ``component_config``.

    Args:
        component_config: Object that receives the attributes.
        args: Namespace-like object; ``vars(args)`` is iterated.

    Returns:
        The same ``component_config`` object, mutated in place.
    """
    for arg in vars(args):
        setattr(component_config, arg, getattr(args, arg))
    return component_config


def parse_t5_config(args, hf_model):
    """Build encoder and decoder component configs for a T5-style model.

    The encoder/decoder configs of ``hf_model`` are stringified into a
    ``ConfigParser`` and then read back, per component, into typed attributes.

    Args:
        args: Namespace whose attributes are copied onto each component
            config; ``args.model_type`` is recorded in the "structure" section.
        hf_model: Model exposing ``encoder.config`` and ``decoder.config``,
            each providing ``to_dict()``.

    Returns:
        Tuple ``(encoder_config, decoder_config)`` of
        ``types.SimpleNamespace`` objects.
    """
    config = configparser.ConfigParser()

    config["encoder"] = {}
    for key, val in hf_model.encoder.config.to_dict().items():
        config["encoder"][key] = f"{val}"

    # manually set q_scaling to offset attention scaling's effect.
    # TODO: modify kernels to control whether to disable attention scaling
    def get_offset_q_scaling(config):
        scaling = 1 / config.head_size**.5
        return scaling

    config["decoder"] = {}
    for key, val in hf_model.decoder.config.to_dict().items():
        config["decoder"][key] = f"{val}"

    config["structure"] = dict()
    config["structure"]["t5_with_bias"] = "false"
    config["structure"]["use_gated_activation"] = str(
        hf_model.encoder.config.is_gated_act)
    config["structure"]["position_embedding_type"] = "relative"
    config["structure"]["model_type"] = args.model_type

    def parse_t5_config_by_component(config, component, args):
        """Read one section ('encoder' or 'decoder') into a namespace."""
        component_config = types.SimpleNamespace()
        component_config = copy_args_to_component_config(
            component_config, args)
        component_config.n_head = config.getint(component, 'num_heads')
        component_config.head_size = config.getint(component, 'd_kv')
        component_config.hidden_size = config.getint(component, 'd_model')
        component_config.ffn_hidden_size = config.getint(component, 'd_ff')
        component_config.vocab_size = config.getint(component, 'vocab_size')
        component_config.n_positions = config.getint(component,
                                                     'n_positions',
                                                     fallback=512)
        component_config.has_position_embedding = config.getboolean(
            component, 'has_position_embedding',
            fallback=False)  # TODO: hardcoded here
        component_config.has_token_type_embedding = config.getboolean(
            component, 'has_token_type_embedding', fallback=False)
        component_config.has_embedding_layernorm = config.getboolean(
            component, 'has_embedding_layernorm', fallback=False)
        component_config.has_embedding_scale = config.getboolean(
            component, 'has_embedding_scale', fallback=False)
        # Depends on head_size, which was populated above.
        component_config.q_scaling = get_offset_q_scaling(component_config)
        component_config.has_attention_qkvo_bias = config.getboolean(
            component, 'has_attention_qkvo_bias',
            fallback=False)  # TODO: hardcoded here
        component_config.has_mlp_bias = config.getboolean(component,
                                                          'has_mlp_bias',
                                                          fallback=False)
        component_config.has_model_final_layernorm = config.getboolean(
            component, 'has_model_final_layernorm', fallback=True)
        component_config.layernorm_eps = config.getfloat(
            component, 'layer_norm_epsilon')
        component_config.layernorm_position = layernorm_position_map[
            config.get(component,
                       'layernorm_position',
                       fallback='pre_layernorm')]  # TODO: hardcoded here
        component_config.layernorm_type = layernorm_type_map[config.get(
            component, 'layernorm_type', fallback='RmsNorm')]
        component_config.hidden_act = config.get(component, 'dense_act_fn')
        component_config.gated_act = config.getboolean(component,
                                                       'is_gated_act')
        component_config.mlp_type = mlp_type_map[
            'GatedMLP' if component_config.gated_act else 'MLP']
        component_config.num_buckets = config.getint(
            component, 'relative_attention_num_buckets')
        component_config.max_distance = config.getint(
            component, 'relative_attention_max_distance')
        component_config.position_embedding_type = config.get(
            'structure', 'position_embedding_type')
        component_config.logits_dtype = config.get(component,
                                                   'logits_dtype',
                                                   fallback='float32')

        if component == 'encoder':
            component_config.n_layer = config.getint(component, 'num_layers')

            component_config.relative_attention = config.get(
                'structure', 'position_embedding_type') == 'relative'

        elif component == 'decoder':
            component_config.n_layer = config.getint(component,
                                                     'num_decoder_layers')
            component_config.has_lm_head_bias = config.getboolean(
                component,  # TODO: T5 with bias
                'has_lm_head_bias',
                fallback=False)
            component_config.relative_attention = config.getboolean(
                component, 'relative_attention', fallback=True)
            component_config.rescale_before_lm_head = config.getboolean(
                component, 'tie_word_embeddings'
            )  # default is True (for T5), but False for Flan-T5
            component_config.encoder_hidden_size = config.getint(
                'encoder', 'd_model')
            component_config.encoder_num_heads = config.getint(
                'encoder', 'num_heads')
            component_config.encoder_head_size = config.getint(
                'encoder', 'd_kv')
            component_config.decoder_start_token_id = config.getint(
                'decoder', 'decoder_start_token_id')
            component_config.eos_token_id = config.getint(
                'decoder', 'eos_token_id')
            bos_token_id = config.get('decoder', 'bos_token_id')
            # T5 does not have bos_token_id
            component_config.bos_token_id = int(
                bos_token_id) if bos_token_id != "None" else None
            component_config.pad_token_id = config.getint(
                'decoder', 'pad_token_id')

        else:
            # Explicit raise (not `assert False`) so the check survives
            # `python -O`; the exception type is unchanged.
            raise AssertionError('Unsupported component!')

        return component_config

    encoder_config = parse_t5_config_by_component(config, "encoder", args)
    decoder_config = parse_t5_config_by_component(config, "decoder", args)

    return encoder_config, decoder_config


def convert_t5_weights_to_tllm_safetensors(config, component, params):
    """Rename and partition T5 weights for one component.

    Args:
        config: Object providing ``mapping`` (with ``tp_size``, ``tp_rank``,
            ``pp_layers``), ``dtype``, ``hidden_size``, ``intermediate_size``,
            ``num_hidden_layers``, ``num_attention_heads``, ``head_size``,
            ``gated_act``, ``num_buckets``, ``model_type`` and, for the
            decoder, ``vocab_size`` and ``use_implicit_relative_attention``.
        component: ``'encoder'`` or ``'decoder'``; used as the source key
            prefix and to pick target layer names.
        params: Mapping from source parameter name to tensor. It is passed to
            ``convert_weight_to_dtype`` first (presumably converted in place,
            since the return value is not used — confirm against ``helper``).

    Returns:
        Dict mapping target weight names to tensors for the layers returned
        by ``mapping.pp_layers``.
    """
    weights = {}

    mapping = config.mapping

    convert_weight_to_dtype(params, config.dtype)
    hidden_size = config.hidden_size
    ffn_hidden_size = config.intermediate_size
    num_layers = config.num_hidden_layers
    n_head = config.num_attention_heads
    head_size = config.head_size
    attention_hidden_size = n_head * head_size  # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5

    hf_param_prefix = f'{component}'
    trtllm_layer_name = f'{component}_layers'
    trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
    trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'
    # Index of the feed-forward sub-layer inside a block: 1 for the encoder,
    # 2 for the decoder (whose index 1 is cross-attention, see below).
    hf_component_idx = 1 if component == 'encoder' else 2

    def get_attn_module_name(component, block, layer, attn_type):
        return f'{component}.block.{int(block)}.layer.{int(layer)}.{attn_type}'

    weights['embedding.vocab_embedding.weight'] = reshape(
        params['shared.weight'].clone(), None)

    layers_range = mapping.pp_layers(num_layers)
    for layer_idx in layers_range:
        # Target layers are numbered from 0 within this pipeline stage.
        local_layer_idx = layer_idx - layers_range[0]
        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
        hf_layer_name_prefix = f'{hf_param_prefix}.block.{layer_idx}'

        # Both `wi` and `wi_0` map to mlp.fc; only names present in `params`
        # are converted by the loops at the end of this iteration.
        hidden_layer_name_split = {
            f'{hf_layer_name_prefix}.layer.0.SelfAttention.o.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',
                "shape":
                (hidden_size, attention_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wo.weight':
            {
                "name": f'{trtllm_layer_name_prefix}.mlp.proj.weight',
                "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi.weight':
            {
                "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
                "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                "split_dim": 0
            },
            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_0.weight':
            {
                "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
                "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                "split_dim": 0
            },
        }

        hidden_layer_name_no_split = {
            f'{hf_layer_name_prefix}.layer.0.layer_norm.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',
                "shape": None
            },
            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.layer_norm.weight':
            {
                "name": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',
                "shape": None
            },
        }

        if config.gated_act:
            hidden_layer_name_split.update({
                f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi2.weight':
                {
                    "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
                    "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                    "split_dim": 0
                },
                f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_1.weight':
                {
                    "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
                    "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                    "split_dim": 0
                },
            })

        if component == 'decoder':
            hidden_layer_name_split.update({
                f'{hf_layer_name_prefix}.layer.1.EncDecAttention.o.weight': {
                    "name":
                    f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',
                    "shape":
                    (hidden_size, attention_hidden_size // mapping.tp_size),
                    "split_dim": -1
                },
            })
            hidden_layer_name_no_split.update({
                f'{hf_layer_name_prefix}.layer.1.layer_norm.weight': {
                    "name":
                    f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',
                    "shape": None
                },
            })
            cross_attn_module_name = get_attn_module_name(
                component, layer_idx, "1", 'EncDecAttention')
            weights.update(
                fuse_qkv_one_layer(
                    params, cross_attn_module_name,
                    f'{trtllm_layer_name_prefix}.cross_attention',
                    mapping.tp_size, mapping.tp_rank, config.model_type,
                    (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                    None))

        self_attn_module_name = get_attn_module_name(component, layer_idx, "0",
                                                     'SelfAttention')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
                mapping.tp_size, mapping.tp_rank, config.model_type,
                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                None))

        # Every layer reads the relative-attention bias stored under block 0.
        # Kept as a fresh split/reshape per layer rather than hoisted, so the
        # per-layer entries are produced exactly as before.
        weights[
            f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(
                split(
                    params[
                        f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']
                    .T, mapping.tp_size, mapping.tp_rank, 0),
                (n_head // mapping.tp_size, config.num_buckets))

        for hf_weight_name, weight_info in hidden_layer_name_split.items():
            if hf_weight_name in params:
                weights[weight_info["name"]] = reshape(
                    split(params[hf_weight_name],
                          mapping.tp_size,
                          mapping.tp_rank,
                          dim=weight_info["split_dim"]), weight_info["shape"])
        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
            if hf_weight_name in params:
                weights[weight_info["name"]] = reshape(
                    params[hf_weight_name].clone(), shape=weight_info["shape"])

    weights['final_layernorm.weight'] = reshape(
        params[f'{component}.final_layer_norm.weight'].clone(), None)

    if component == 'decoder':
        weights['lm_head.weight'] = reshape(
            split(params['lm_head.weight'],
                  mapping.tp_size,
                  mapping.tp_rank,
                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
        if not config.use_implicit_relative_attention:
            weights['rel_attn_table'] = reshape(
                split(
                    params[
                        f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']
                    .T, mapping.tp_size, mapping.tp_rank, 0),
                (n_head // mapping.tp_size, config.num_buckets))

    return weights
convert_blip2_weights_to_tllm_safetensors = convert_t5_weights_to_tllm_safetensors # func alias def parse_nmt_config(args, model): config = configparser.ConfigParser() fairseq_config = vars(model.cfg.model) # Namespace --> dict config['encoder'] = dict() for key, val in fairseq_config.items(): config["encoder"][key] = f"{val}" config["encoder"]["q_scaling"] = '1' # NMT has final layernorm for pre-norm model architecture. config['encoder']['has_model_final_layernorm'] = config['encoder'][ 'encoder_normalize_before'] config['encoder']['vocab_size'] = str(len(model.src_dict)) # fairseq naming config['decoder'] = dict() for key, val in fairseq_config.items(): config["decoder"][key] = f"{val}" config["decoder"]["q_scaling"] = '1' config["decoder"]["rescale_before_lm_head"] = 'false' config['decoder']['has_model_final_layernorm'] = str( config['decoder'].getboolean('decoder_normalize_before', False) and not config['decoder'].getboolean('no_decoder_final_norm', False)) config['decoder']['vocab_size'] = str(len(model.tgt_dict)) # fairseq naming config["structure"] = dict() config["structure"]["t5_with_bias"] = "true" config["structure"]["use_gated_activation"] = "false" config["structure"][ "position_embedding_type"] = "learned_absolute" # "sinusoid" config["structure"]["model_type"] = args.model_type def parse_nmt_config_by_component(config, component, args): assert component in ('encoder', 'decoder'), 'Unsupported component!' 
component_config = types.SimpleNamespace() component_config = copy_args_to_component_config(component_config, args) component_config.n_layer = config.getint(component, f'{component}_layers') component_config.n_head = config.getint(component, f'{component}_attention_heads') component_config.hidden_size = config.getint( component, f'{component}_embed_dim') # fairseq naming component_config.head_size = config.getint( component, 'd_kv', fallback=component_config.hidden_size // component_config.n_head) component_config.ffn_hidden_size = config.getint( component, f'{component}_ffn_embed_dim') # fairseq naming component_config.vocab_size = config.getint(component, 'vocab_size') component_config.n_positions = config.getint( component, 'max_source_positions') # fairseq naming component_config.has_position_embedding = not config.getboolean( component, 'no_token_positional_embeddings', fallback=False) # fairseq naming component_config.has_token_type_embedding = config.getboolean( component, 'has_token_type_embedding', fallback=False) component_config.has_embedding_layernorm = config.getboolean( component, 'layernorm_embedding', fallback=True) # fairseq naming component_config.has_embedding_scale = not config.getboolean( component, 'no_scale_embedding') # fairseq naming component_config.q_scaling = config.getfloat(component, 'q_scaling', fallback=1.0) component_config.has_attention_qkvo_bias = config.getboolean( 'structure', 't5_with_bias', fallback=True) component_config.has_mlp_bias = config.getboolean('structure', 't5_with_bias', fallback=True) component_config.has_model_final_layernorm = config.getboolean( component, 'has_model_final_layernorm') component_config.layernorm_eps = config.getfloat( component, 'layer_norm_epsilon', fallback=1e-5) # fairseq naming normalize_before = config.getboolean( component, f'{component}_normalize_before') # fairseq naming component_config.layernorm_position = layernorm_position_map[ 'pre_layernorm' if normalize_before else 
'post_layernorm'] component_config.layernorm_type = layernorm_type_map[config.get( component, 'layernorm_type', fallback='LayerNorm')] component_config.hidden_act = config.get( component, 'activation_fn') # fairseq naming component_config.gated_act = config.getboolean(component, 'is_gated_act', fallback=False) component_config.mlp_type = mlp_type_map['GatedMLP' if component_config. gated_act else 'MLP'] component_config.relative_attention = config.get( 'structure', 'position_embedding_type') == 'relative' component_config.num_buckets = config.getint( component, 'relative_attention_num_buckets', fallback=0) component_config.max_distance = config.getint( component, 'relative_attention_max_distance', fallback=0) component_config.position_embedding_type = config.get( 'structure', 'position_embedding_type') component_config.logits_dtype = config.get(component, 'logits_dtype', fallback='float32') if component == 'decoder': component_config.rescale_before_lm_head = config.getboolean( component, 'rescale_before_lm_head') component_config.encoder_hidden_size = config.getint( 'encoder', 'encoder_embed_dim') # fairseq naming component_config.encoder_num_heads = config.getint( 'encoder', 'encoder_attention_heads') component_config.encoder_head_size = config.getint( 'encoder', 'd_kv', fallback=component_config.encoder_hidden_size // component_config.encoder_num_heads) component_config.decoder_start_token_id = None component_config.eos_token_id = None component_config.bos_token_id = None component_config.pad_token_id = None return component_config encoder_config = parse_nmt_config_by_component(config, "encoder", args) decoder_config = parse_nmt_config_by_component(config, "decoder", args) return encoder_config, decoder_config def convert_nmt_weights_to_tllm_safetensors(config, component, params, sin_pos_embedding): weights = {} mapping = config.mapping hidden_size = config.hidden_size convert_weight_to_dtype(params, config.dtype) ffn_hidden_size = config.intermediate_size 
vocab_size = config.vocab_size hf_param_prefix = f'models.0.{component}' trtllm_layer_name = f'{component}_layers' trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention' trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm' hidden_layer_name_split = { 'self_attn.out_proj.weight': { "name": f'{trtllm_attn_layer_name}.dense.weight', "shape": (hidden_size, hidden_size // mapping.tp_size), "split_dim": -1 }, 'fc1.weight': { "name": 'mlp.fc.weight', "shape": (ffn_hidden_size // mapping.tp_size, hidden_size), "split_dim": 0 }, 'fc1.bias': { "name": 'mlp.fc.bias', "shape": (ffn_hidden_size // mapping.tp_size), "split_dim": 0 }, 'fc2.weight': { "name": 'mlp.proj.weight', "shape": (hidden_size, ffn_hidden_size // mapping.tp_size), "split_dim": -1 }, } hidden_layer_name_no_split = { 'self_attn.out_proj.bias': { "name": f'{trtllm_attn_layer_name}.dense.bias', "shape": (hidden_size) }, 'self_attn_layer_norm.weight': { "name": f'{trtllm_attn_layernorm_name}.weight', "shape": None }, 'self_attn_layer_norm.bias': { "name": f'{trtllm_attn_layernorm_name}.bias', "shape": None }, 'fc2.bias': { "name": 'mlp.proj.bias', "shape": (hidden_size) }, 'final_layer_norm.weight': { "name": 'mlp_layernorm.weight', "shape": None }, 'final_layer_norm.bias': { "name": 'mlp_layernorm.bias', "shape": None }, } if component == "decoder": hidden_layer_name_split.update({ 'encoder_attn.out_proj.weight': { "name": 'cross_attention.dense.weight', "shape": (hidden_size, hidden_size // mapping.tp_size), "split_dim": -1 }, }) hidden_layer_name_no_split.update({ 'encoder_attn.out_proj.bias': { "name": 'cross_attention.dense.bias', "shape": (hidden_size) }, 'encoder_attn_layer_norm.weight': { "name": 'cross_attention_layernorm.weight', "shape": None, }, 'encoder_attn_layer_norm.bias': { "name": 'cross_attention_layernorm.bias', "shape": None }, }) def get_attn_module_name(component, layer, attn_type): return 
f'models.0.{component}.layers.{int(layer)}.{attn_type}' weights["embedding.vocab_embedding.weight"] = reshape( params[f'{hf_param_prefix}.embed_tokens.weight'].clone(), (vocab_size, -1)) weights["embedding.position_embedding.weight"] = reshape( sin_pos_embedding, (config.max_position_embeddings, hidden_size)) num_layers = config.num_hidden_layers layers_range = mapping.pp_layers(num_layers) for layer_idx in layers_range: local_layer_idx = layer_idx - layers_range[0] hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}' trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}' for hf_weight_name, weight_info in hidden_layer_name_split.items(): weights[ f'{trtllm_layer_name_prefix}.{weight_info["name"]}'] = reshape( split(params[f'{hf_layer_name_prefix}.{hf_weight_name}'], mapping.tp_size, mapping.tp_rank, dim=weight_info["split_dim"]), weight_info["shape"]) for hf_weight_name, weight_info in hidden_layer_name_no_split.items(): trtllm_layer_fullname = f'{trtllm_layer_name_prefix}.{weight_info["name"]}' hf_layer_fullname = f'{hf_layer_name_prefix}.{hf_weight_name}' weights[trtllm_layer_fullname] = reshape( params[hf_layer_fullname].clone(), shape=weight_info["shape"]) self_attn_module_name = get_attn_module_name(component, layer_idx, 'self_attn') weights.update( fuse_qkv_one_layer( params, self_attn_module_name, f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}', mapping.tp_size, mapping.tp_rank, config.model_type, (hidden_size * 3 // mapping.tp_size, hidden_size), (hidden_size * 3 // mapping.tp_size))) if component == 'decoder': cross_attn_module_name = get_attn_module_name( component, layer_idx, 'encoder_attn') weights.update( fuse_qkv_one_layer( params, cross_attn_module_name, f'{trtllm_layer_name_prefix}.cross_attention', mapping.tp_size, mapping.tp_rank, config.model_type, (hidden_size * 3 // mapping.tp_size, hidden_size), (hidden_size * 3 // mapping.tp_size))) if component == 'decoder': weights['lm_head.weight'] = reshape( 
split(params[f'{hf_param_prefix}.output_projection.weight'], mapping.tp_size, mapping.tp_rank, dim=0), (config.vocab_size // mapping.tp_size, hidden_size)) if config.has_model_final_layernorm: weights['final_layernorm.weight'] = params[ f'{hf_param_prefix}.layer_norm.weight'].clone() weights['final_layernorm.bias'] = params[ f'{hf_param_prefix}.layer_norm.bias'].clone() return weights def parse_bart_config(args, hf_model): config = configparser.ConfigParser() config['decoder'] = dict() for key, val in hf_model.model.decoder.config.to_dict().items(): config["decoder"][key] = f"{val}" config["decoder"]["q_scaling"] = '1' config["decoder"]["rescale_before_lm_head"] = str(False) config['decoder']['has_model_final_layernorm'] = str( args.nougat or isinstance(hf_model, MBartForConditionalGeneration)) if args.nougat: # These flags are true for mbart decoders, but missing in HF config config['decoder']['normalize_before'] = str(True) config['decoder']['normalize_embeddings'] = str(True) config['encoder'] = dict() # Init few encoder configs, needed by build, from decoder config encoder_config_keys = [ "encoder_ffn_dim", "encoder_layers", "encoder_attention_heads", "encoder_layerdrop", "d_model" ] for key in encoder_config_keys: config['encoder'][key] = config['decoder'][key] else: config['encoder'] = dict() for key, val in hf_model.model.encoder.config.to_dict().items(): config["encoder"][key] = f"{val}" config["encoder"]["q_scaling"] = '1' # mBART has final layernorm, BART does not config['encoder']['has_model_final_layernorm'] = str( isinstance(hf_model, MBartForConditionalGeneration)) config["structure"] = dict() config["structure"]["t5_with_bias"] = "true" config["structure"]["use_gated_activation"] = "false" config["structure"]["position_embedding_type"] = "learned_absolute" config["structure"]["model_type"] = args.model_type def parse_bart_config_by_component(config, component, args): assert component in ('encoder', 'decoder'), 'Unsupported component!' 
component_config = types.SimpleNamespace() component_config = copy_args_to_component_config(component_config, args) component_config.n_layer = config.getint(component, f'{component}_layers') component_config.n_head = config.getint(component, f'{component}_attention_heads') component_config.hidden_size = config.getint(component, 'd_model') component_config.head_size = config.getint( component, 'd_kv', fallback=component_config.hidden_size // component_config.n_head) component_config.ffn_hidden_size = config.getint( component, f'{component}_ffn_dim') component_config.vocab_size = config.getint(component, 'vocab_size') component_config.n_positions = config.getint(component, 'max_position_embeddings') component_config.has_position_embedding = config.getboolean( component, 'has_position_embedding', fallback=True) # TODO: hardcoded here component_config.has_token_type_embedding = config.getboolean( component, 'has_token_type_embedding', fallback=False) component_config.has_embedding_layernorm = config.getboolean( component, 'has_embedding_layernorm', fallback=True) component_config.has_embedding_scale = config.getboolean( component, 'scale_embedding') component_config.q_scaling = config.getfloat(component, 'q_scaling', fallback=1.0) component_config.has_attention_qkvo_bias = config.getboolean( 'structure', 't5_with_bias', fallback=True) component_config.has_mlp_bias = config.getboolean('structure', 't5_with_bias', fallback=True) component_config.has_model_final_layernorm = config.getboolean( component, 'has_model_final_layernorm') component_config.layernorm_eps = config.getfloat(component, 'layer_norm_epsilon', fallback=False) normalize_before = config.getboolean(component, 'normalize_before') component_config.layernorm_position = layernorm_position_map[ 'pre_layernorm' if normalize_before else 'post_layernorm'] component_config.layernorm_type = layernorm_type_map[config.get( component, 'layernorm_type', fallback='LayerNorm')] component_config.hidden_act = 
config.get(component, 'activation_function') component_config.gated_act = config.getboolean(component, 'is_gated_act', fallback=False) component_config.mlp_type = mlp_type_map['GatedMLP' if component_config. gated_act else 'MLP'] component_config.relative_attention = config.get( 'structure', 'position_embedding_type') == 'relative' component_config.num_buckets = config.getint( component, 'relative_attention_num_buckets', fallback=0) component_config.max_distance = config.getint( component, 'relative_attention_max_distance', fallback=0) component_config.max_lora_rank = config.getint(component, 'max_lora_rank', fallback=0) component_config.lora_target_modules = literal_eval( config.get(component, 'lora_target_modules', fallback="[]")) component_config.hf_modules_to_trtllm_modules = literal_eval( config.get(component, 'hf_modules_to_trtllm_modules', fallback="{}")) component_config.trtllm_modules_to_hf_modules = literal_eval( config.get(component, 'trtllm_modules_to_hf_modules', fallback="{}")) component_config.logits_dtype = config.get(component, 'logits_dtype', fallback='float32') component_config.position_embedding_type = config.get( 'structure', 'position_embedding_type') if component == 'decoder': component_config.rescale_before_lm_head = config.getboolean( component, 'rescale_before_lm_head') component_config.encoder_hidden_size = config.getint( 'encoder', 'd_model') component_config.encoder_num_heads = config.getint( 'encoder', 'encoder_attention_heads') component_config.encoder_head_size = config.getint( 'encoder', 'd_kv', fallback=component_config.encoder_hidden_size // component_config.encoder_num_heads) # nougat has decoder_start_token_id = None, special handling decoder_start_token_id = config.get('decoder', 'decoder_start_token_id') component_config.decoder_start_token_id = int( decoder_start_token_id ) if decoder_start_token_id != "None" else None component_config.eos_token_id = config.getint( 'decoder', 'eos_token_id') component_config.bos_token_id = 
config.getint( 'decoder', 'bos_token_id') component_config.pad_token_id = config.getint( 'decoder', 'pad_token_id') return component_config encoder_config = None if not args.nougat: encoder_config = parse_bart_config_by_component(config, "encoder", args) decoder_config = parse_bart_config_by_component(config, "decoder", args) return encoder_config, decoder_config def convert_bart_weights_to_tllm_safetensors(config, component, params): weights = {} mapping = config.mapping hidden_size = config.hidden_size convert_weight_to_dtype(params, config.dtype) ffn_hidden_size = config.intermediate_size vocab_size = config.vocab_size hf_param_prefix = f'model.{component}' trtllm_layer_name = f'{component}_layers' trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention' trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm' embedding_layer_names = { 'embed_tokens.weight': { "name": 'embedding.vocab_embedding.weight', "shape": (vocab_size, -1) }, 'embed_positions.weight': { "name": 'embedding.position_embedding.weight', "shape": (config.max_position_embeddings, hidden_size) }, 'layernorm_embedding.weight': { "name": 'embedding.embedding_layernorm.weight', "shape": None }, 'layernorm_embedding.bias': { "name": 'embedding.embedding_layernorm.bias', "shape": None }, } hidden_layer_name_split = { 'self_attn.out_proj.weight': { "name": f'{trtllm_attn_layer_name}.dense.weight', "shape": (hidden_size, hidden_size // mapping.tp_size), "split_dim": -1 }, 'fc1.weight': { "name": 'mlp.fc.weight', "shape": (ffn_hidden_size // mapping.tp_size, hidden_size), "split_dim": 0 }, 'fc1.bias': { "name": 'mlp.fc.bias', "shape": (ffn_hidden_size // mapping.tp_size), "split_dim": 0 }, 'fc2.weight': { "name": 'mlp.proj.weight', "shape": (hidden_size, ffn_hidden_size // mapping.tp_size), "split_dim": -1 }, } hidden_layer_name_no_split = { 'self_attn.out_proj.bias': { "name": f'{trtllm_attn_layer_name}.dense.bias', "shape": 
(hidden_size) }, 'self_attn_layer_norm.weight': { "name": f'{trtllm_attn_layernorm_name}.weight', "shape": None }, 'self_attn_layer_norm.bias': { "name": f'{trtllm_attn_layernorm_name}.bias', "shape": None }, 'fc2.bias': { "name": 'mlp.proj.bias', "shape": (hidden_size) }, 'final_layer_norm.weight': { "name": 'mlp_layernorm.weight', "shape": None }, 'final_layer_norm.bias': { "name": 'mlp_layernorm.bias', "shape": None }, } if config.model_type == 'mbart': hidden_layer_name_split['layer_norm.weight'] = { "name": 'final_layernorm.weight', "shape": None, "split_dim": 0 } hidden_layer_name_no_split['layer_norm.bias'] = { "name": 'final_layernorm.bias', "shape": None, "split_dim": 0 } if component == "decoder": hidden_layer_name_split.update({ 'encoder_attn.out_proj.weight': { "name": 'cross_attention.dense.weight', "shape": (hidden_size, hidden_size // mapping.tp_size), "split_dim": -1 } }) hidden_layer_name_no_split.update({ 'encoder_attn.out_proj.bias': { "name": 'cross_attention.dense.bias', "shape": (hidden_size) }, 'encoder_attn_layer_norm.weight': { "name": 'cross_attention_layernorm.weight', "shape": None }, 'encoder_attn_layer_norm.bias': { "name": 'cross_attention_layernorm.bias', "shape": None }, }) def get_attn_module_name(component, layer, attn_type): return f'model.{component}.layers.{int(layer)}.{attn_type}' for hf_weight_name, weight_info in embedding_layer_names.items(): if 'position' in hf_weight_name: weights[weight_info["name"]] = params[ f'{hf_param_prefix}.{hf_weight_name}'][2:].clone() else: weights[weight_info["name"]] = params[ f'{hf_param_prefix}.{hf_weight_name}'].clone() weights[weight_info["name"]] = reshape(weights[weight_info["name"]], weight_info["shape"]) num_layers = config.num_hidden_layers layers_range = mapping.pp_layers(num_layers) for layer_idx in layers_range: local_layer_idx = layer_idx - layers_range[0] hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}' trtllm_layer_name_prefix = 
def parse_pix2struct_config(args, hf_model):
    """Derive the TRT-LLM decoder settings for a Pix2Struct checkpoint.

    The decoder config of ``hf_model`` is stringified into a ``ConfigParser``
    section named ``decoder`` (next to a hand-written ``structure`` section)
    and then read back with typed getters. Every parsed value is stored as an
    attribute on ``args``.

    Args:
        args: Parsed command-line namespace. It is mutated in place and is
            also the object returned as the decoder configuration.
        hf_model: Model object exposing ``decoder.config`` with ``to_dict()``,
            ``hidden_size`` and ``num_heads``.

    Returns:
        A ``(None, decoder_args)`` tuple. The first slot is always ``None``
        because only the decoder component is parsed here.

    Raises:
        ValueError: If the inner parser is asked for a component other than
            ``'decoder'``.
    """
    # Interpolation is disabled: the values are plain stringified config
    # entries, so a literal '%' in any of them must be stored verbatim rather
    # than being treated as ConfigParser substitution syntax.
    config = configparser.ConfigParser(interpolation=None)

    # manually set q_scaling to offset attention scaling's effect.
    # TODO: modify kernels to control whether to disable attention scaling
    def get_offset_q_scaling(config) -> str:
        head_size = config.hidden_size / config.num_heads
        return str(1 / head_size**.5)

    config["decoder"] = {}
    for key, val in hf_model.decoder.config.to_dict().items():
        config["decoder"][key] = f"{val}"
    config["decoder"]["q_scaling"] = get_offset_q_scaling(
        hf_model.decoder.config)

    config["structure"] = {
        "pix2struct_with_bias": "false",
        "use_gated_activation": "false",
        "position_embedding_type": "relative",
        "model_type": args.model_type,
    }

    def parse_pix2struct_config_by_component(config, component, args):
        if component != 'decoder':
            # Raise explicitly instead of `assert False`: assertions are
            # removed under `python -O`, which would let an unsupported
            # component silently return an unpopulated namespace.
            raise ValueError(f'Unsupported component: {component!r}')

        args.n_layer = config.getint(component, 'num_layers')
        args.n_head = config.getint(component, 'num_heads')
        args.head_size = config.getint(component, 'd_kv')
        args.hidden_size = config.getint(component, 'hidden_size')
        args.ffn_hidden_size = config.getint(component, 'd_ff')
        args.vocab_size = config.getint(component, 'vocab_size')
        args.n_positions = config.getint(component,
                                         'n_positions',
                                         fallback=512)
        args.has_position_embedding = config.getboolean(
            component, 'has_position_embedding',
            fallback=False)  # TODO: hardcoded here
        args.has_token_type_embedding = config.getboolean(
            component, 'has_token_type_embedding', fallback=False)
        args.has_embedding_layernorm = config.getboolean(
            component, 'has_embedding_layernorm', fallback=False)
        args.has_embedding_scale = config.getboolean(component,
                                                     'has_embedding_scale',
                                                     fallback=False)
        args.q_scaling = config.getfloat(component, 'q_scaling', fallback=1.0)
        args.has_attention_qkvo_bias = config.getboolean(
            component, 'has_attention_qkvo_bias', fallback=False)
        args.has_mlp_bias = config.getboolean(component,
                                              'has_mlp_bias',
                                              fallback=False)
        args.has_model_final_layernorm = config.getboolean(
            component, 'has_model_final_layernorm', fallback=True)
        args.layernorm_eps = config.getfloat(component, 'layer_norm_epsilon')
        args.layernorm_position = layernorm_position_map[config.get(
            component, 'layernorm_position',
            fallback='pre_layernorm')]  # TODO: hardcoded here
        args.layernorm_type = layernorm_type_map[config.get(
            component, 'layernorm_type', fallback='RmsNorm')]
        args.hidden_act = config.get(component, 'dense_act_fn')
        # The gated MLP variant is always selected for this model type.
        args.gated_act = True
        args.mlp_type = mlp_type_map['GatedMLP' if args.gated_act else 'MLP']
        args.has_lm_head_bias = config.getboolean(
            component,  # TODO: T5 with bias
            'has_lm_head_bias',
            fallback=False)
        args.relative_attention = config.getboolean(component,
                                                    'relative_attention',
                                                    fallback=True)
        args.num_buckets = config.getint(component,
                                         'relative_attention_num_buckets')
        args.max_distance = config.getint(component,
                                          'relative_attention_max_distance')
        args.logits_dtype = config.get(component,
                                       'logits_dtype',
                                       fallback='float32')
        args.rescale_before_lm_head = config.getboolean(
            component, 'tie_word_embeddings'
        )  # default is True (for T5), but False for Flan-T5
        # The "encoder_*" sizes are read from the decoder section as well.
        args.encoder_hidden_size = config.getint('decoder', 'hidden_size')
        args.encoder_num_heads = config.getint('decoder', 'num_heads')
        args.encoder_head_size = config.getint('decoder', 'd_kv')
        args.position_embedding_type = config.get('structure',
                                                  'position_embedding_type')
        args.decoder_start_token_id = config.getint('decoder',
                                                    'decoder_start_token_id')
        args.eos_token_id = config.getint('decoder', 'eos_token_id')
        bos_token_id = config.get('decoder', 'bos_token_id')
        # pix2struct does not have bos_token_id
        args.bos_token_id = int(
            bos_token_id) if bos_token_id != "None" else None
        args.pad_token_id = config.getint('decoder', 'pad_token_id')
        return args

    decoder_args = parse_pix2struct_config_by_component(
        config, "decoder", args)
    return None, decoder_args
def convert_pix2struct_weights_to_tllm_safetensors(config, component, params):
    """Map Pix2Struct checkpoint tensors onto TRT-LLM weight names.

    Args:
        config: Per-rank configuration; this function reads ``mapping``,
            ``dtype``, ``hidden_size``, ``intermediate_size``,
            ``num_hidden_layers``, ``num_attention_heads``, ``head_size``,
            ``gated_act``, ``model_type``, ``num_buckets``, ``vocab_size``
            and ``use_implicit_relative_attention`` from it.
        component: Prefix of the source parameter names (also used to build
            the ``<component>_layers`` prefix of the output names).
        params: Mapping from source parameter name to tensor. It is passed to
            ``convert_weight_to_dtype`` before any tensor is read.

    Returns:
        dict mapping TRT-LLM weight names to the converted tensors for the
        tensor-parallel rank and pipeline layers given by ``config.mapping``.
    """
    weights = {}
    mapping = config.mapping
    convert_weight_to_dtype(params, config.dtype)

    hidden_size = config.hidden_size
    ffn_hidden_size = config.intermediate_size
    n_head = config.num_attention_heads
    # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5
    attention_hidden_size = n_head * config.head_size
    tp_size = mapping.tp_size
    tp_rank = mapping.tp_rank
    hf_root = f'{component}'

    # Target shapes of the tensor-parallel shards.
    attn_out_shape = (hidden_size, attention_hidden_size // tp_size)
    fused_qkv_shape = (attention_hidden_size * 3 // tp_size, hidden_size)
    mlp_in_shape = (ffn_hidden_size // tp_size, hidden_size)
    mlp_out_shape = (hidden_size, ffn_hidden_size // tp_size)

    def relative_attention_table():
        # The table is always taken from layer 0 of the checkpoint. A new
        # tensor expression is evaluated on every call, exactly as when the
        # expression is written out at each use site.
        bias = params[
            f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']
        return reshape(split(bias.T, tp_size, tp_rank, 0),
                       (n_head // tp_size, config.num_buckets))

    weights['embedding.vocab_embedding.weight'] = reshape(
        params[f'{hf_root}.embed_tokens.weight'].clone(), None)

    layers_range = mapping.pp_layers(config.num_hidden_layers)
    for layer_idx in layers_range:
        # Output layers are numbered relative to the first layer of this
        # pipeline stage.
        trt_prefix = f'{component}_layers.{layer_idx - layers_range[0]}'
        hf_prefix = f'{hf_root}.layer.{layer_idx}'
        self_attn_trt = f'{trt_prefix}.self_attention'

        # (source name, TRT-LLM name, shard shape, dimension to split along)
        split_specs = [
            (f'{hf_prefix}.self_attention.attention.output.weight',
             f'{self_attn_trt}.dense.weight', attn_out_shape, -1),
            (f'{hf_prefix}.mlp.DenseReluDense.wo.weight',
             f'{trt_prefix}.mlp.proj.weight', mlp_out_shape, -1),
            (f'{hf_prefix}.mlp.DenseReluDense.wi_0.weight',
             f'{trt_prefix}.mlp.fc.weight', mlp_in_shape, 0),
        ]
        if config.gated_act:
            split_specs.append(
                (f'{hf_prefix}.mlp.DenseReluDense.wi_1.weight',
                 f'{trt_prefix}.mlp.gate.weight', mlp_in_shape, 0))
        split_specs.append(
            (f'{hf_prefix}.encoder_decoder_attention.attention.output.weight',
             f'{trt_prefix}.cross_attention.dense.weight', attn_out_shape,
             -1))

        # (source name, TRT-LLM name) of tensors copied whole to every rank.
        replicated_specs = [
            (f'{hf_prefix}.self_attention.layer_norm.weight',
             f'{trt_prefix}.self_attention_layernorm.weight'),
            (f'{hf_prefix}.mlp.layer_norm.weight',
             f'{trt_prefix}.mlp_layernorm.weight'),
            (f'{hf_prefix}.encoder_decoder_attention.layer_norm.weight',
             f'{trt_prefix}.cross_attention_layernorm.weight'),
        ]

        # Fused QKV: cross attention first, then self attention.
        fused_targets = (
            ('encoder_decoder_attention', f'{trt_prefix}.cross_attention'),
            ('self_attention', self_attn_trt),
        )
        for attn_type, trt_attn_name in fused_targets:
            hf_attn_module = f'{component}.layer.{int(layer_idx)}.{attn_type}.attention'
            weights.update(
                fuse_qkv_one_layer(params, hf_attn_module, trt_attn_name,
                                   tp_size, tp_rank, config.model_type,
                                   fused_qkv_shape, None))

        weights[f'{self_attn_trt}.rel_attn_table'] = relative_attention_table()

        # Entries whose source tensor is absent from `params` are skipped.
        for hf_name, trt_name, shard_shape, split_dim in split_specs:
            if hf_name in params:
                weights[trt_name] = reshape(
                    split(params[hf_name], tp_size, tp_rank, dim=split_dim),
                    shard_shape)
        for hf_name, trt_name in replicated_specs:
            if hf_name in params:
                weights[trt_name] = reshape(params[hf_name].clone(),
                                            shape=None)

    weights['final_layernorm.weight'] = reshape(
        params[f'{component}.final_layer_norm.weight'].clone(), None)

    weights['lm_head.weight'] = reshape(
        split(params[f'{component}.lm_head.weight'], tp_size, tp_rank, dim=0),
        (config.vocab_size // tp_size, hidden_size))

    if not config.use_implicit_relative_attention:
        weights['rel_attn_table'] = relative_attention_table()

    return weights
def get_model(args):
    """Load the source checkpoint selected by ``args.model_type``.

    Args:
        args: Namespace providing ``model_type`` and ``model_dir``; ``nougat``
            is additionally read when ``model_type`` is ``"bart"``.

    Returns:
        The loaded model. For ``"bart"`` with ``args.nougat`` set, the decoder
        of the vision-encoder-decoder model is returned; for ``"blip2"``, the
        ``language_model`` attribute of the loaded model is returned.

    Raises:
        ValueError: If ``args.model_type`` is not one of the handled types
            (previously this surfaced as an ``UnboundLocalError``).
    """
    model_type = args.model_type

    if model_type == "t5":
        return T5ForConditionalGeneration.from_pretrained(args.model_dir)

    if model_type == "nmt":
        # fairseq is imported only on this path, so the other model types do
        # not need it to be installed.
        from fairseq.models.transformer import TransformerModel
        return TransformerModel.from_pretrained(args.model_dir)

    if model_type == "bart":
        if args.nougat:
            vision_model = VisionEncoderDecoderModel.from_pretrained(
                args.model_dir)
            return vision_model.get_decoder()
        return AutoModelForSeq2SeqLM.from_pretrained(args.model_dir)

    if model_type == "pix2struct":
        return Pix2StructForConditionalGeneration.from_pretrained(
            args.model_dir)

    if model_type == "blip2":
        return Blip2ForConditionalGeneration.from_pretrained(
            args.model_dir).language_model

    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected one of "
        "'t5', 'nmt', 'bart', 'pix2struct', 'blip2'.")
def convert_checkpoint(args):
    """Convert the checkpoint named by ``args`` into per-rank TRT-LLM files.

    Loads the model with ``get_model``, parses its configuration with the
    ``parse_<model_type>_config`` function found in this module (``blip2``
    uses the ``t5`` parser), writes a ``config.json`` into the ``encoder`` and
    ``decoder`` sub-directories of ``args.output_dir`` and then runs
    ``convert`` for each component, either inline (``args.workers == 1``) or
    through ``torch.multiprocessing.spawn``.

    The encoder is skipped when ``args.nougat`` is set or when the model type
    is ``pix2struct``; both output directories are created regardless.

    Args:
        args: Parsed command-line namespace. ``args.workers`` is lowered to
            the world size when it exceeds it.
    """
    model = get_model(args)

    output_root = Path(args.output_dir)
    output_root.mkdir(parents=True, exist_ok=True)
    encoder_saved_dir = output_root / "encoder"
    encoder_saved_dir.mkdir(parents=True, exist_ok=True)
    decoder_saved_dir = output_root / "decoder"
    decoder_saved_dir.mkdir(parents=True, exist_ok=True)

    world_size = args.tp_size * args.pp_size

    kv_cache_quant_algo = None
    quant_algo = None

    parser_key = "t5" if args.model_type == "blip2" else args.model_type
    parse_config = globals()[f'parse_{parser_key}_config']
    encoder_config, decoder_config = parse_config(args, model)

    additional_settings = ["gated_act"]
    has_encoder = not args.nougat and args.model_type != "pix2struct"

    def shared_settings(component_config, architecture, with_kv_heads):
        # Fields common to both components, in the order they are serialized.
        settings = {
            'architecture': architecture,
            'dtype': args.dtype,
            'logits_dtype': component_config.logits_dtype,
            'num_hidden_layers': component_config.n_layer,
            'num_attention_heads': component_config.n_head,
            'hidden_size': component_config.hidden_size,
            'norm_epsilon': component_config.layernorm_eps,
            'vocab_size': component_config.vocab_size,
            'position_embedding_type':
            component_config.position_embedding_type,
            'hidden_act': component_config.hidden_act,
            'quantization': {
                'quant_algo': quant_algo,
                'kv_cache_quant_algo': kv_cache_quant_algo,
            },
            'mapping': {
                'world_size': world_size,
                'tp_size': args.tp_size,
                'pp_size': args.pp_size,
            },
            'use_parallel_embedding': args.use_parallel_embedding,
            'embedding_sharding_dim': args.embedding_sharding_dim,
            'max_position_embeddings': component_config.n_positions,
        }
        if with_kv_heads:
            settings['num_key_value_heads'] = component_config.n_head
        settings.update({
            'head_size': component_config.head_size,
            'has_position_embedding':
            component_config.has_position_embedding,
            'layernorm_type': component_config.layernorm_type,
            'has_attention_qkvo_bias':
            component_config.has_attention_qkvo_bias,
            'has_mlp_bias': component_config.has_mlp_bias,
            'has_model_final_layernorm':
            component_config.has_model_final_layernorm,
            'has_embedding_layernorm':
            component_config.has_embedding_layernorm,
            'has_embedding_scale': component_config.has_embedding_scale,
            'intermediate_size': component_config.ffn_hidden_size,
            'q_scaling': component_config.q_scaling,
            'layernorm_position': component_config.layernorm_position,
            'mlp_type': component_config.mlp_type,
            'relative_attention': component_config.relative_attention,
            'max_distance': component_config.max_distance,
            'num_buckets': component_config.num_buckets,
            'model_type': component_config.model_type,
        })
        return settings

    def finish_and_dump(settings, component_config, target_dir):
        # Optional settings are copied only when the parsed config has them.
        for optional_key in additional_settings:
            if hasattr(component_config, optional_key):
                settings[optional_key] = getattr(component_config,
                                                 optional_key)
        with (target_dir / "config.json").open('w') as config_file:
            json.dump(settings, config_file, indent=4)

    if has_encoder:
        tllm_encoder_config = shared_settings(encoder_config,
                                              "EncoderModel",
                                              with_kv_heads=True)
        finish_and_dump(tllm_encoder_config, encoder_config,
                        encoder_saved_dir)
        encoder_convert_args = dict(params=model.state_dict(),
                                    component="encoder")

    tllm_decoder_config = shared_settings(decoder_config,
                                          "DecoderModel",
                                          with_kv_heads=False)
    tllm_decoder_config.update({
        'rescale_before_lm_head': decoder_config.rescale_before_lm_head,
        'encoder_hidden_size': decoder_config.encoder_hidden_size,
        'encoder_num_heads': decoder_config.encoder_num_heads,
        'encoder_head_size': decoder_config.encoder_head_size,
        'skip_cross_kv': args.skip_cross_kv,
        'use_implicit_relative_attention':
        args.use_implicit_relative_attention,
        'decoder_start_token_id': decoder_config.decoder_start_token_id,
        'eos_token_id': decoder_config.eos_token_id,
        'bos_token_id': decoder_config.bos_token_id,
        'pad_token_id': decoder_config.pad_token_id,
    })
    finish_and_dump(tllm_decoder_config, decoder_config, decoder_saved_dir)
    decoder_convert_args = dict(params=model.state_dict(),
                                component="decoder")

    if args.model_type == "nmt":
        fairseq_config = vars(model.cfg.model)  # Namespace --> dict
        num_embeddings = fairseq_config['max_source_positions']
        embedding_dim = fairseq_config['encoder_embed_dim']
        first_model = model.models[0]
        padding_idx = first_model.encoder.embed_tokens.padding_idx  # 1

        # [2 + num_embeddings, embed_dim]
        sin_pos_embedding = first_model.encoder.embed_positions.get_embedding(
            padding_idx + 1 + num_embeddings,
            embedding_dim,
            padding_idx=padding_idx)
        sin_pos_embedding = sin_pos_embedding[
            2:, :]  # remove offset embeddings

        # The same table is handed to both components.
        encoder_convert_args["sin_pos_embedding"] = sin_pos_embedding
        decoder_convert_args["sin_pos_embedding"] = sin_pos_embedding

    # One entry per component to convert; the encoder (if any) goes first.
    jobs = []
    if has_encoder:
        jobs.append(
            (tllm_encoder_config, encoder_convert_args, encoder_saved_dir))
    jobs.append((tllm_decoder_config, decoder_convert_args,
                 decoder_saved_dir))

    if args.workers == 1:
        for component_config, component_args, component_dir in jobs:
            convert(0, world_size, args, component_config, component_args,
                    component_dir)
        return

    args.workers = min(args.workers, world_size)
    LOGGER.info(f'Convert checkpoint using {args.workers} workers.')
    import torch.multiprocessing as mp
    for component_config, component_args, component_dir in jobs:
        mp.spawn(convert,
                 nprocs=args.workers,
                 args=(world_size, args, component_config, component_args,
                       component_dir))
def convert(worker_rank, world_size, args, model_config, convert_args,
            saved_dir):
    """Convert and save the weights of every rank owned by one worker.

    A worker handles ranks ``worker_rank``, ``worker_rank + args.workers``,
    ... below ``world_size``. For each of them a ``PretrainedConfig`` is built
    from ``model_config``, the converter named
    ``convert_<model_type>_weights_to_tllm_safetensors`` is looked up in this
    module's globals, and its result is written to
    ``<saved_dir>/rank<rank>.safetensors``.

    Args:
        worker_rank: Index of this worker; also the first rank it handles.
        world_size: Total number of ranks.
        args: Namespace providing ``workers``, used as the rank stride.
        model_config: Dict accepted by ``PretrainedConfig.from_dict``.
        convert_args: Extra keyword arguments forwarded to the converter.
        saved_dir: Directory receiving the ``rank<N>.safetensors`` files.
    """
    assigned_ranks = range(worker_rank, world_size, args.workers)
    for rank in assigned_ranks:
        rank_config = copy.deepcopy(PretrainedConfig.from_dict(model_config))
        rank_config.set_rank(rank)

        converter_name = f'convert_{rank_config.model_type}_weights_to_tllm_safetensors'
        converter = globals()[converter_name]
        weights = converter(config=rank_config, **convert_args)

        output_file = f'{saved_dir}/rank{rank}.safetensors'
        safetensors.torch.save_file(weights, output_file)
if __name__ == "__main__":
    # Command-line entry point: parse the options, configure logging, echo
    # the parsed arguments and time the conversion.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--model_type',
        type=str,
        default='t5',
        choices=['t5', 'nmt', 'bart', 'pix2struct', 'blip2'],
        help=
        'Multimodal type when this script is used for multimodal conversion.')
    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='N-way tensor parallelism size')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='N-way pipeline parallelism size')
    parser.add_argument("--model_dir",
                        "-i",
                        type=str,
                        help="Path to the framework checkpoint file",
                        required=True)
    parser.add_argument("--output_dir",
                        "-o",
                        type=str,
                        help="Path to the converted TRT-LLM model weight file",
                        required=True)
    parser.add_argument(
        "--workers",
        type=int,
        help="How many workers to spawn for conversion (default: 4)",
        default=4)
    parser.add_argument("--nougat",
                        action="store_true",
                        help="Model which uses vision encoder + mbart decoder")
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Provide verbose messages")
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
    )
    # NOTE: in the three multi-part help strings below, adjacent literals are
    # concatenated verbatim, so each part must carry its own separator.
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
        'To shard it along hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharding is only enabled when embedding_sharding_dim = 0'
    )
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument(
        '--dtype',
        type=str,
        default='float16',
        choices=['float16', 'float32', 'bfloat16'],
        help=
        'Target inference dtype. Weights and Computation will be in this dtype, no matter what original dtype the weight checkpoint has.'
    )
    parser.add_argument(
        '--skip_cross_kv',
        action='store_true',
        help=
        'Skip redundant cross qkv computation by using TensorRT IfConditional switch (experimental).'
    )
    parser.add_argument(
        '--use_implicit_relative_attention',
        action='store_true',
        help=
        'Compute relative attention bias on the fly instead of pre-compute a relative attention bias table.'
    )
    args = parser.parse_args()

    log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s"
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO,
                        format=log_format)

    # Echo every parsed option so a run can be reproduced from its log.
    LOGGER.info("\n=============== Argument ===============")
    for key, value in vars(args).items():
        LOGGER.info("%s: %s", key, value)
    LOGGER.info("========================================")

    start_time = datetime.now()
    convert_checkpoint(args)
    run_time = datetime.now() - start_time
    LOGGER.info("Spend %s (h:m:s) to convert the model", run_time)