
lora_config shape mismatch when using converted LoRA at runtime #750

Open
Description

@paulhendricks

System Info

  • CPU architecture: x86_64
  • Host Memory: 1TB
  • GPU: NVIDIA A100 80GB x8
  • TensorRT-LLM version: v0.18.2
  • Triton container: nvcr.io/nvidia/tritonserver:25.04-trtllm-python-py3
  • CUDA: 12.2
  • cuBLAS: 12.2
  • NVIDIA driver: 550.54.14
  • OS: Ubuntu 22.04
  • Docker used with --gpus all, --shm-size=16GB

Who can help?

@byshiue @kaiyux

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Convert the base model Qwen2.5-7B-Instruct using TensorRT-LLM:
python convert_checkpoint.py --model_dir Qwen2.5-7B-Instruct \
                             --output_dir ./tllm_checkpoint_1gpu_fp16 \
                             --dtype float16
  2. Build the engine with LoRA support:
trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16 \
             --output_dir ./engine \
             --gemm_plugin float16 \
             --max_batch_size 4 \
             --gpt_attention_plugin float16 \
             --kv_cache_type paged \
             --remove_input_padding enable \
             --use_paged_context_fmha enable \
             --lora_plugin float16 \
             --lora_target_modules attn_q attn_k attn_v mlp_4h_to_h mlp_h_to_4h mlp_gate attn_dense \
             --max_lora_rank 20
  3. Download and convert a LoRA (example: doubleyyh/email-tuned-qwen2-lora):
huggingface-cli download doubleyyh/email-tuned-qwen2-lora --local-dir loras/email-lora-0

python hf_lora_convert.py -i loras/email-lora-0 -o loras/email-lora-0-converted
# Do NOT run fix_loras.py
  4. Launch the Triton server with the generated engine:
python3 launch_triton_server.py --world_size=1 --model_repo=/triton_model_repo
  5. Send a query:
python inflight_batcher_llm_client.py \
  --text "What is the capital of France?" \
  --tokenizer-dir Qwen2.5-7B-Instruct \
  --lora-task-id 0 \
  --lora-path loras/email-lora-0-converted/

Output error

[StatusCode.INVALID_ARGUMENT] [request id: <id_unknown>] unexpected shape for input 'lora_config' for model 'tensorrt_llm'. Expected [-1,-1,3], got [1,196,4].
NOTE: Setting a non-zero max_batch_size in the model config requires a batch dimension to be prepended to each input shape.

Expected behavior

The LoRA adapter converted with hf_lora_convert.py should work out of the box: the resulting model.lora_config.npy should have the [1, N, 3] shape the engine expects and be accepted by the Triton server.

Actual behavior

hf_lora_convert.py produces a model.lora_config.npy with shape [1, N, 4]. The Triton server therefore rejects the LoRA config with the INVALID_ARGUMENT error shown above. Only after manually running fix_loras.py to drop the extra column and restore the expected shape does the request pass shape validation.
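
For reference, the mismatch can be confirmed by inspecting the converted artifacts directly. A minimal check, assuming the output directory from the reproduction steps above:

import numpy as np

config = np.load("loras/email-lora-0-converted/model.lora_config.npy")
weights = np.load("loras/email-lora-0-converted/model.lora_weights.npy")

# The engine declares lora_config as [-1, -1, 3]; the converter emits a 4th column.
print(config.shape)   # (1, 196, 4) here, instead of the expected (1, N, 3)
print(weights.shape)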

Additional notes

# fix_loras.py
#!/usr/bin/env python3

import argparse
import os
import shutil
import numpy as np


def fix_lora_config(input_path):
    # Load the config produced by hf_lora_convert.py
    config = np.load(input_path)
    print(f"Original config shape: {config.shape}")

    # Remove the batch dimension if present
    if config.ndim == 3 and config.shape[0] == 1:
        config = np.squeeze(config, axis=0)

    # Keep only the first 3 columns, matching the engine's [-1, -1, 3] input
    if config.shape[1] > 3:
        config = config[:, :3]

    if config.ndim != 2 or config.shape[1] != 3:
        raise ValueError(f"Invalid config shape after cleanup: {config.shape} (expected (N, 3))")

    # Add back batch dimension
    config = np.expand_dims(config, axis=0)

    print(f"Fixed config shape: {config.shape}")
    return config


def main():
    parser = argparse.ArgumentParser(description="Fix LoRA .npy files for TensorRT-LLM compatibility.")
    parser.add_argument("-i", "--input", required=True, help="Input directory with LoRA files")
    parser.add_argument("-o", "--output", required=True, help="Output directory for fixed LoRA files")
    args = parser.parse_args()

    input_dir = os.path.abspath(args.input)
    output_dir = os.path.abspath(args.output)

    if not os.path.isdir(input_dir):
        raise FileNotFoundError(f"Input directory not found: {input_dir}")

    os.makedirs(output_dir, exist_ok=True)

    # Check for model.lora_weights.npy
    weight_file = os.path.join(input_dir, "model.lora_weights.npy")
    if not os.path.isfile(weight_file):
        raise FileNotFoundError(f"{weight_file} not found")
    shutil.copy2(weight_file, os.path.join(output_dir, "model.lora_weights.npy"))
    print(f"Copied weights to: {output_dir}/model.lora_weights.npy")

    # Fix and write model.lora_config.npy
    config_file = os.path.join(input_dir, "model.lora_config.npy")
    if not os.path.isfile(config_file):
        raise FileNotFoundError(f"{config_file} not found")

    fixed_config = fix_lora_config(config_file)
    np.save(os.path.join(output_dir, "model.lora_config.npy"), fixed_config)
    print(f"Saved fixed config to: {output_dir}/model.lora_config.npy")

    print("LoRA conversion complete.")


if __name__ == "__main__":
    main()
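
Usage, with directory names following the reproduction steps above:

python fix_loras.py -i loras/email-lora-0-converted -o loras/email-lora-0-fixed

Pointing the client's --lora-path at the fixed directory then passes Triton's shape validation.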
