### System Info
- CPU architecture: x86_64
- Host Memory: 1TB
- GPU: NVIDIA A100 80GB x8
- TensorRT-LLM version: v0.18.2
- Triton container: nvcr.io/nvidia/tritonserver:25.04-trtllm-python-py3
- CUDA: 12.2
- cuBLAS: 12.2
- NVIDIA driver: 550.54.14
- OS: Ubuntu 22.04
- Docker run flags: `--gpus all --shm-size=16GB`
### Who can help?
### Information

- The official example scripts
- My own modified scripts

### Tasks

- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
### Reproduction
- Convert the base model Qwen2.5-7B-Instruct using TensorRT-LLM (a quick check of the output layout follows the command):

```bash
python convert_checkpoint.py --model_dir Qwen2.5-7B-Instruct \
    --output_dir ./tllm_checkpoint_1gpu_fp16 \
    --dtype float16
```
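As a sanity check on this step, the converted checkpoint directory should contain a `config.json` plus one `rank*.safetensors` shard (a single `rank0` file here, since this is a 1-GPU conversion). A minimal sketch:

```python
# Sketch: list the converted checkpoint directory; for a 1-GPU float16
# conversion it should hold config.json and rank0.safetensors.
import os

print(sorted(os.listdir("./tllm_checkpoint_1gpu_fp16")))
```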
- Build the engine with LoRA support:

```bash
trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16 \
    --output_dir ./engine \
    --gemm_plugin float16 \
    --max_batch_size 4 \
    --gpt_attention_plugin float16 \
    --kv_cache_type paged \
    --remove_input_padding enable \
    --use_paged_context_fmha enable \
    --lora_plugin float16 \
    --lora_target_modules attn_q attn_k attn_v mlp_4h_to_h mlp_h_to_4h mlp_gate attn_dense \
    --max_lora_rank 20
```
- Download and convert a LoRA (example: doubleyyh/email-tuned-qwen2-lora); the converted files are inspected just below:

```bash
huggingface-cli download doubleyyh/email-tuned-qwen2-lora --local-dir loras/email-lora-0
python hf_lora_convert.py -i loras/email-lora-0 -o loras/email-lora-0-converted
# Do NOT run fix_loras.py
```
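The mismatch can be seen before ever hitting Triton by inspecting the converted artifacts. A minimal sketch; the adapter_config.json fields are the standard PEFT ones (`r`, `target_modules`), and the .npy names are what hf_lora_convert.py emits:

```python
# Sketch: inspect the downloaded adapter and the converted artifacts.
import json

import numpy as np

# Standard PEFT adapter config: the rank must be <= --max_lora_rank (20 above),
# and target_modules must map onto the --lora_target_modules list.
with open("loras/email-lora-0/adapter_config.json") as f:
    cfg = json.load(f)
print("rank:", cfg["r"], "| target modules:", cfg["target_modules"])

# Files produced by hf_lora_convert.py:
config = np.load("loras/email-lora-0-converted/model.lora_config.npy")
weights = np.load("loras/email-lora-0-converted/model.lora_weights.npy")
print("lora_config shape:", config.shape)   # observed here: (1, 196, 4)
print("lora_weights shape:", weights.shape)
```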
- Launch the Triton server with the generated engine (a readiness check follows):

```bash
python3 launch_triton_server.py --world_size=1 --model_repo=/triton_model_repo
```
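Optionally, confirm the server is up before querying. A minimal sketch using tritonclient, assuming the default HTTP endpoint on port 8000 and the model name `tensorrt_llm` from the error below:

```python
# Sketch: check that Triton and the tensorrt_llm model are ready.
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")
print("server ready:", client.is_server_ready())
print("model ready:", client.is_model_ready("tensorrt_llm"))
```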
- Send a query:

```bash
python inflight_batcher_llm_client.py \
    --text "What is the capital of France?" \
    --tokenizer-dir Qwen2.5-7B-Instruct \
    --lora-task-id 0 \
    --lora-path loras/email-lora-0-converted/
```
Output error:

```
[StatusCode.INVALID_ARGUMENT] [request id: <id_unknown>] unexpected shape for input 'lora_config' for model 'tensorrt_llm'. Expected [-1,-1,3], got [1,196,4].
NOTE: Setting a non-zero max_batch_size in the model config requires a batch dimension to be prepended to each input shape.
```
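For context: if I read the backend docs right, each row of `lora_config` is (module id, layer index, adapter size), with a leading batch dimension, hence the declared [-1, -1, 3]. A minimal sketch of a config the server would accept (values illustrative only):

```python
import numpy as np

# One batch dim, N rows of 3 fields: (module_id, layer_idx, adapter_size).
lora_config = np.array([[[0, 0, 16]]], dtype=np.int32)
print(lora_config.shape)  # (1, 1, 3) -> satisfies the expected [-1, -1, 3]
```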
### Expected behavior

The LoRA adapter should work out of the box after conversion with hf_lora_convert.py: the generated model.lora_config.npy should conform to the required [1, N, 3] shape and be accepted by the Triton server.
### actual behavior

Instead, hf_lora_convert.py produces a model.lora_config.npy with shape [1, N, 4] (here [1, 196, 4]), which the Triton server rejects with the INVALID_ARGUMENT error above. Only after manually running fix_loras.py (below) to truncate the extra column and restore the batch dimension does the config pass shape validation.
### additional notes
```python
#!/usr/bin/env python3
# fix_loras.py
import argparse
import os
import shutil

import numpy as np


def fix_lora_config(input_path):
    """Coerce a model.lora_config.npy to the (1, N, 3) shape the server expects."""
    config = np.load(input_path)
    print(f"Original config shape: {config.shape}")
    # Remove the batch dimension if present
    if config.ndim == 3 and config.shape[0] == 1:
        config = np.squeeze(config, axis=0)
    # Truncate the extra column(s) down to the 3 expected fields
    if config.shape[1] > 3:
        config = config[:, :3]
    if config.ndim != 2 or config.shape[1] != 3:
        raise ValueError(f"Invalid config shape after cleanup: {config.shape} (expected (N, 3))")
    # Add the batch dimension back
    config = np.expand_dims(config, axis=0)
    print(f"Fixed config shape: {config.shape}")
    return config


def main():
    parser = argparse.ArgumentParser(description="Fix LoRA .npy files for TensorRT-LLM compatibility.")
    parser.add_argument("-i", "--input", required=True, help="Input directory with LoRA files")
    parser.add_argument("-o", "--output", required=True, help="Output directory for fixed LoRA files")
    args = parser.parse_args()

    input_dir = os.path.abspath(args.input)
    output_dir = os.path.abspath(args.output)
    if not os.path.isdir(input_dir):
        raise FileNotFoundError(f"Input directory not found: {input_dir}")
    os.makedirs(output_dir, exist_ok=True)

    # Copy model.lora_weights.npy through unchanged
    weight_file = os.path.join(input_dir, "model.lora_weights.npy")
    if not os.path.isfile(weight_file):
        raise FileNotFoundError(f"{weight_file} not found")
    shutil.copy2(weight_file, os.path.join(output_dir, "model.lora_weights.npy"))
    print(f"Copied weights to: {output_dir}/model.lora_weights.npy")

    # Fix and save model.lora_config.npy
    config_file = os.path.join(input_dir, "model.lora_config.npy")
    if not os.path.isfile(config_file):
        raise FileNotFoundError(f"{config_file} not found")
    fixed_config = fix_lora_config(config_file)
    np.save(os.path.join(output_dir, "model.lora_config.npy"), fixed_config)
    print(f"Saved fixed config to: {output_dir}/model.lora_config.npy")
    print("LoRA conversion complete.")


if __name__ == "__main__":
    main()
```
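For completeness, the workaround is invoked as e.g. `python fix_loras.py -i loras/email-lora-0-converted -o loras/email-lora-0-fixed` (the output directory name is arbitrary), and the fixed directory is then passed to the client via `--lora-path`. This step should not be necessary: hf_lora_convert.py and the backend should agree on the lora_config layout.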