Test Only fp4: Lluo/fp4 try out #3521

Draft · wants to merge 32 commits into base: main

Changes from all commits (32 commits)
1d172ce
Add fp4 support
lanluo-nvidia Apr 29, 2025
d2b1422
test
lanluo-nvidia Apr 30, 2025
38617b4
Merge branch 'main' into lluo/fp4
lanluo-nvidia May 1, 2025
d439d96
upgrade modelopt
lanluo-nvidia May 1, 2025
5a2213e
add constant fold
lanluo-nvidia May 2, 2025
fcf0c12
fix the input tensor type issue
lanluo-nvidia May 2, 2025
057f35a
test
lanluo-nvidia May 5, 2025
7b09862
test
lanluo-nvidia May 6, 2025
d9f2ad9
test
lanluo-nvidia May 6, 2025
6892a47
test
lanluo-nvidia May 6, 2025
5198f9a
Merge branch 'main' into lluo/fp4
lanluo-nvidia May 7, 2025
559ada5
restructure the dynamic double quantize and static double quantize code
lanluo-nvidia May 7, 2025
bba1d79
add test code
lanluo-nvidia May 14, 2025
f16e58a
test
lanluo-nvidia May 14, 2025
06c8126
test
lanluo-nvidia May 14, 2025
868949c
test
lanluo-nvidia May 15, 2025
5134a2c
test
lanluo-nvidia May 15, 2025
391c971
test
lanluo-nvidia May 15, 2025
38297bd
test
lanluo-nvidia May 15, 2025
095251f
add print graph
lanluo-nvidia May 15, 2025
5830211
test float16
lanluo-nvidia May 16, 2025
8f57c86
change to float16
lanluo-nvidia May 16, 2025
24d0602
upgrade to 10.10.0 for tensorrt
lanluo-nvidia May 18, 2025
58d158c
use strongly typed network
lanluo-nvidia May 18, 2025
5a622ae
test
lanluo-nvidia May 18, 2025
8880b14
print out internal weight scaling value
lanluo-nvidia May 18, 2025
7910e25
Merge branch 'main' into lluo/fp4_try_out
lanluo-nvidia May 20, 2025
bae9188
add disable flag
lanluo-nvidia May 21, 2025
7b3cd74
add transpose
lanluo-nvidia May 23, 2025
1e2fccb
try different axis
lanluo-nvidia May 23, 2025
c5e3498
add disable gemm option
lanluo-nvidia May 23, 2025
925a68d
test
lanluo-nvidia May 23, 2025
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
@@ -21,14 +21,14 @@ repos:
- id: clang-format
types_or: [c++, c, cuda]
- repo: https://github.com/keith/pre-commit-buildifier
rev: 6.4.0
rev: 8.0.3
hooks:
- id: buildifier
args:
- --warnings=all
- id: buildifier-lint
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.23
rev: v0.24.1
hooks:
- id: validate-pyproject
- repo: https://github.com/pycqa/isort
@@ -37,17 +37,17 @@ repos:
- id: isort
name: isort (python)
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.9.0"
rev: "v1.15.0"
hooks:
- id: mypy
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.3
rev: v0.11.7
hooks:
- id: ruff
- repo: https://github.com/psf/black
rev: 24.3.0
rev: 25.1.0
hooks:
- id: black
exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
@@ -57,7 +57,7 @@ repos:
- id: typos
- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
rev: 0.5.5
rev: 0.7.1
hooks:
# Update the uv lockfile
- id: uv-lock
4 changes: 2 additions & 2 deletions MODULE.bazel
@@ -94,9 +94,9 @@ http_archive(
http_archive(
name = "tensorrt",
build_file = "@//third_party/tensorrt/archive:BUILD",
strip_prefix = "TensorRT-10.9.0.34",
strip_prefix = "TensorRT-10.10.0.31",
urls = [
"https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.9.0/tars/TensorRT-10.9.0.34.Linux.x86_64-gnu.cuda-12.8.tar.gz",
"https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.10.0/tars/TensorRT-10.10.0.31.Linux.x86_64-gnu.cuda-12.9.tar.gz",
],
)

12 changes: 11 additions & 1 deletion examples/dynamo/vgg16_ptq.py
@@ -169,7 +169,6 @@ def vgg16(num_classes=1000, init_weights=False):

data = iter(training_dataloader)
images, _ = next(data)

crit = nn.CrossEntropyLoss()

# %%
@@ -200,8 +199,11 @@ def calibrate_loop(model):
quant_cfg = mtq.INT8_DEFAULT_CFG
elif args.quantize_type == "fp8":
quant_cfg = mtq.FP8_DEFAULT_CFG
elif args.quantize_type == "fp4":
quant_cfg = mtq.NVFP4_DEFAULT_CFG
# PTQ with in-place replacement to quantized modules
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)

# model has FP8 qdq nodes at this point

# %%
@@ -233,12 +235,20 @@ def calibrate_loop(model):
with export_torch_mode():
# Compile the model with Torch-TensorRT Dynamo backend
input_tensor = images.cuda()
torch.onnx.export(model, input_tensor, "mtq_vgg16_model.onnx")

exp_program = torch.export.export(model, (input_tensor,), strict=False)
if args.quantize_type == "int8":
enabled_precisions = {torch.int8}
elif args.quantize_type == "fp8":
enabled_precisions = {torch.float8_e4m3fn}
elif args.quantize_type == "fp4":
enabled_precisions = {
torch.float4_e2m1fn_x2,
torch.float8_e4m3fn,
torch.float16,
torch.float32,
}
trt_model = torchtrt.dynamo.compile(
exp_program,
inputs=[input_tensor],
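The vgg16_ptq.py change above adds an NVFP4 path next to the existing int8/fp8 ones. As a self-contained sketch of the same flow (assuming modelopt's NVFP4_DEFAULT_CFG and a PyTorch/TensorRT pair that exposes FP4 types; the function and variable names here are illustrative, not part of the PR):

import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt as torchtrt
from modelopt.torch.quantization.utils import export_torch_mode


def ptq_fp4_and_compile(model: torch.nn.Module, calib_images: torch.Tensor):
    # Calibration loop: run representative data through the model once
    def calibrate_loop(m):
        m(calib_images)

    # PTQ with in-place replacement of modules by their quantized versions
    mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=calibrate_loop)

    with export_torch_mode():
        exp_program = torch.export.export(model, (calib_images,), strict=False)
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[calib_images],
            # FP4 for weights plus FP8/FP16/FP32 for scales and activations
            enabled_precisions={
                torch.float4_e2m1fn_x2,
                torch.float8_e4m3fn,
                torch.float16,
                torch.float32,
            },
        )
    return trt_model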
19 changes: 19 additions & 0 deletions py/torch_tensorrt/_enums.py
@@ -80,6 +80,12 @@ class dtype(Enum):
:meta hide-value:
"""

f4 = auto()
"""4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4``

:meta hide-value:
"""

uint8 = u8
int8 = i8

@@ -91,6 +97,9 @@ class dtype(Enum):
float8 = f8
fp8 = f8

float4 = f4
fp4 = f4

half = f16
fp16 = f16
float16 = f16
@@ -162,6 +171,8 @@ def _from(
return dtype.i32
elif t == torch.float8_e4m3fn:
return dtype.f8
elif t == torch.float4_e2m1fn_x2:
return dtype.f4
elif t == torch.half:
return dtype.f16
elif t == torch.float:
@@ -188,6 +199,8 @@ def _from(
return dtype.i8
elif t == trt.DataType.FP8:
return dtype.f8
elif t == trt.DataType.FP4:
return dtype.fp4
elif t == trt.DataType.INT32:
return dtype.i32
elif t == trt.DataType.INT64:
@@ -357,6 +370,8 @@ def to(
return torch.long
elif self == dtype.f8:
return torch.float8_e4m3fn
elif self == dtype.f4:
return torch.float4_e2m1fn_x2
elif self == dtype.f16:
return torch.half
elif self == dtype.f32:
@@ -394,6 +409,8 @@ def to(
return trt.DataType.BOOL
elif self == dtype.bf16:
return trt.DataType.BF16
elif self == dtype.f4:
return trt.DataType.FP4
elif use_default:
return trt.DataType.FLOAT
else:
@@ -410,6 +427,8 @@ def to(
return np.int64
elif self == dtype.f16:
return np.float16
elif self == dtype.f4:
return np.float4_e2m1fn_x2
elif self == dtype.f32:
return np.float32
elif self == dtype.f64:
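The new f4 member mirrors the existing f8 plumbing in _enums.py. A quick sketch of the round-trips it enables, assuming a TensorRT build that exposes trt.DataType.FP4 and a PyTorch build with torch.float4_e2m1fn_x2:

import tensorrt as trt
import torch
from torch_tensorrt import dtype

# Torch dtype -> Torch-TensorRT enum (internal helper shown in the diff)
assert dtype._from(torch.float4_e2m1fn_x2) == dtype.f4

# Aliases added alongside the new member
assert dtype.fp4 is dtype.f4 and dtype.float4 is dtype.f4

# Torch-TensorRT enum -> TensorRT / Torch dtypes
assert dtype.f4.to(trt.DataType) == trt.DataType.FP4
assert dtype.f4.to(torch.dtype) == torch.float4_e2m1fn_x2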
14 changes: 7 additions & 7 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -581,13 +581,13 @@ def compile(
"\nThis feature is unimplemented in Torch-TRT Dynamo currently."
)

if use_explicit_typing:
if len(enabled_precisions) != 1 or not any(
x in enabled_precisions for x in {torch.float32, dtype.f32}
):
raise AssertionError(
f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
)
# if use_explicit_typing:
# if len(enabled_precisions) != 1 or not any(
# x in enabled_precisions for x in {torch.float32, dtype.f32}
# ):
# raise AssertionError(
# f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
# )

if use_fp32_acc:
logger.debug(
9 changes: 8 additions & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -29,7 +29,14 @@
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
SUPPORTED_KERNEL_PRECISIONS = {
dtype.f32,
dtype.f16,
dtype.bf16,
dtype.i8,
dtype.f8,
dtype.f4,
}
TIMING_CACHE_PATH = os.path.join(
tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
)
17 changes: 10 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -87,6 +87,7 @@ def __init__(

flag = 0
if compilation_settings.use_explicit_typing:
_LOGGER.info("Using strongly typed network definition")
STRONGLY_TYPED = 1 << (int)(
trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED
)
@@ -274,17 +275,19 @@ def _populate_trt_builder_config(
self.compilation_settings.dla_global_dram_size,
)

if dtype.float16 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.FP16)
# if dtype.float16 in self.compilation_settings.enabled_precisions:
# builder_config.set_flag(trt.BuilderFlag.FP16)

if dtype.int8 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.INT8)

if dtype.fp8 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.FP8)
# if dtype.fp8 in self.compilation_settings.enabled_precisions:
# builder_config.set_flag(trt.BuilderFlag.FP8)

if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.BF16)
# if dtype.fp4 in self.compilation_settings.enabled_precisions:
# builder_config.set_flag(trt.BuilderFlag.FP4)
# if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
# builder_config.set_flag(trt.BuilderFlag.BF16)

if self.compilation_settings.sparse_weights:
builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
@@ -351,7 +354,7 @@ def _populate_trt_builder_config(
builder_config.l2_limit_for_tiling = (
self.compilation_settings.l2_limit_for_tiling
)

print(f"lan added builder_config:{builder_config=}")
return builder_config

def _create_timing_cache(
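With a strongly typed network, per-tensor precisions come from the network definition itself (e.g. Q/DQ nodes) rather than from builder flags, which is consistent with the FP16/FP8/BF16 flag-setting being commented out above. A minimal sketch of strongly typed network creation with the public TensorRT API:

import tensorrt as trt

# Precision is carried by the tensors/layers themselves in a strongly typed
# network, so FP16/FP8/BF16 builder flags are not set separately.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
flag = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
network = builder.create_network(flag)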
33 changes: 33 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -617,6 +617,39 @@ def aten_ops_quantize_op(
)


try:
import modelopt.torch.quantization as mtq # noqa: F401

assert torch.ops.tensorrt.dynamic_block_quantize_op.default
except Exception as e:
_LOGGER.warning(
"Unable to import dynamic block quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling dynamic blockquantized models"
)
else:

@dynamo_tensorrt_converter(torch.ops.tensorrt.dynamic_block_quantize_op.default)
def aten_ops_dynamic_block_quantize_op(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.nvfp4_quantize.nvfp4_quantize(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
)


@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
def aten_ops_squeeze(
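The converter above is only registered when modelopt is importable and its custom op is present. A small, hypothetical helper using the same guard to check availability before attempting an FP4 compilation:

import torch


def fp4_converter_available() -> bool:
    # Mirrors the guard used for converter registration: modelopt must be
    # installed and its custom op registered with PyTorch.
    try:
        import modelopt.torch.quantization  # noqa: F401

        assert torch.ops.tensorrt.dynamic_block_quantize_op.default
        return True
    except Exception:
        return False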
28 changes: 26 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -361,12 +361,37 @@ def create_constant(
shape = list(torch_value.shape)

if torch_value is not None:
if torch_value.dtype == torch.float8_e4m3fn:
weights = trt.Weights(
type=trt.DataType.FP8,
ptr=torch_value.data_ptr(),
count=torch_value.numel(),
)
constant = ctx.net.add_constant(
shape,
weights,
)
constant.name = name
return constant.get_output(0)
# IConstantLayer does not support UINT8; it only supports FP4 data packed into uint8
if torch_value.dtype == torch.uint8:
weights = trt.Weights(
type=trt.DataType.FP4,
ptr=torch_value.data_ptr(),
count=torch_value.numel() * 2,
)
shape[-1] = shape[-1] * 2
constant = ctx.net.add_constant(
shape,
weights,
)
constant.name = name
return constant.get_output(0)
if torch_value.dtype == torch.bfloat16:
torch_value_fp32 = torch_value.to(torch.float32)
numpy_value = torch_value_fp32.numpy()
else:
numpy_value = torch_value.numpy()

ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
constant = ctx.net.add_constant(
shape,
@@ -381,7 +406,6 @@ def create_constant(
trt.DataType.BF16,
name + "_bf16_cast",
)

return constant.get_output(0)
else:
raise ValueError(
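The uint8 branch in create_constant relies on FP4 values being packed two per byte, which is why the weight count is numel() * 2 and the last dimension is doubled. A hypothetical helper (not part of the PR) that makes the packed-to-logical shape relationship explicit:

from typing import Tuple

import torch


def unpacked_fp4_shape(packed: torch.Tensor) -> Tuple[int, ...]:
    # Each uint8 byte stores two 4-bit values, so the logical FP4 tensor
    # has twice as many elements along the last dimension.
    assert packed.dtype == torch.uint8
    shape = list(packed.shape)
    shape[-1] *= 2
    return tuple(shape)


# Example: a (16, 8) uint8 buffer backs a (16, 16) FP4 weight.
packed = torch.zeros(16, 8, dtype=torch.uint8)
assert unpacked_fp4_shape(packed) == (16, 16)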
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -14,6 +14,7 @@
matmul,
nccl_ops,
normalization,
nvfp4_quantize,
pad,
permutation,
pool,
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/addmm.py
@@ -7,7 +7,7 @@
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.types import TRTTensor

import os

def addmm(
ctx: ConversionContext,
@@ -21,6 +21,10 @@ def addmm(
beta: Union[float, int],
alpha: Union[float, int],
) -> TRTTensor:
if os.getenv("DISABLE_GEMM", "false").lower() == "true":
print("lan added disable_gemm is set, skip addmm and returning mat2")
return mat2
print("lan added disable_gemm is not set, doing addmm")
mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2)
if alpha != 1:
mm = impl.elementwise.mul(
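The DISABLE_GEMM switch above is a debug-only escape hatch read from the environment: when set, the addmm conversion skips the matmul and returns mat2 (typically the weight) directly, isolating the FP4 quantize/dequantize path from GEMM behavior. A minimal usage sketch:

import os

# Set before compilation to bypass the GEMM in addmm conversion (debug only).
os.environ["DISABLE_GEMM"] = "true"

# ... build/compile the module as usual, then restore normal behavior:
os.environ.pop("DISABLE_GEMM", None)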