diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f31305568d..a7b91eec34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,14 +21,14 @@ repos: - id: clang-format types_or: [c++, c, cuda] - repo: https://github.com/keith/pre-commit-buildifier - rev: 6.4.0 + rev: 8.0.3 hooks: - id: buildifier args: - --warnings=all - id: buildifier-lint - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.23 + rev: v0.24.1 hooks: - id: validate-pyproject - repo: https://github.com/pycqa/isort @@ -37,17 +37,17 @@ repos: - id: isort name: isort (python) - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.9.0" + rev: "v1.15.0" hooks: - id: mypy exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py" - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.3.3 + rev: v0.11.7 hooks: - id: ruff - repo: https://github.com/psf/black - rev: 24.3.0 + rev: 25.1.0 hooks: - id: black exclude: ^examples/custom_converters/elu_converter/setup.py|^docs @@ -57,7 +57,7 @@ repos: - id: typos - repo: https://github.com/astral-sh/uv-pre-commit # uv version. - rev: 0.5.5 + rev: 0.7.1 hooks: # Update the uv lockfile - id: uv-lock diff --git a/examples/dynamo/vgg16_ptq.py b/examples/dynamo/vgg16_ptq.py index 7fa943040e..0ed8772a44 100644 --- a/examples/dynamo/vgg16_ptq.py +++ b/examples/dynamo/vgg16_ptq.py @@ -200,6 +200,8 @@ def calibrate_loop(model): quant_cfg = mtq.INT8_DEFAULT_CFG elif args.quantize_type == "fp8": quant_cfg = mtq.FP8_DEFAULT_CFG +elif args.quantize_type == "fp4": + quant_cfg = mtq.NVFP4_DEFAULT_CFG # PTQ with in-place replacement to quantized modules mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has FP8 qdq nodes at this point @@ -239,6 +241,8 @@ def calibrate_loop(model): enabled_precisions = {torch.int8} elif args.quantize_type == "fp8": enabled_precisions = {torch.float8_e4m3fn} + elif args.quantize_type == "fp4": + enabled_precisions = {torch.float4_e2m1fn_x2} trt_model = torchtrt.dynamo.compile( exp_program, inputs=[input_tensor], diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index c706c345d6..e0a78e1a0b 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -80,6 +80,12 @@ class dtype(Enum): :meta hide-value: """ + f4 = auto() + """4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4`` + + :meta hide-value: + """ + uint8 = u8 int8 = i8 @@ -91,6 +97,9 @@ class dtype(Enum): float8 = f8 fp8 = f8 + float4 = f4 + fp4 = f4 + half = f16 fp16 = f16 float16 = f16 @@ -162,6 +171,8 @@ def _from( return dtype.i32 elif t == torch.float8_e4m3fn: return dtype.f8 + elif t == torch.float4_e2m1fn_x2: + return dtype.f4 elif t == torch.half: return dtype.f16 elif t == torch.float: @@ -188,6 +199,8 @@ def _from( return dtype.i8 elif t == trt.DataType.FP8: return dtype.f8 + elif t == trt.DataType.FP4: + return dtype.fp4 elif t == trt.DataType.INT32: return dtype.i32 elif t == trt.DataType.INT64: @@ -357,6 +370,8 @@ def to( return torch.long elif self == dtype.f8: return torch.float8_e4m3fn + elif self == dtype.f4: + return torch.float4_e2m1fn_x2 elif self == dtype.f16: return torch.half elif self == dtype.f32: @@ -394,6 +409,8 @@ def to( return trt.DataType.BOOL elif self == dtype.bf16: return trt.DataType.BF16 + elif self == dtype.f4: + return trt.DataType.FP4 elif use_default: return trt.DataType.FLOAT else: @@ -410,6 +427,8 @@ def to( return np.int64 elif self == 
dtype.f16: return np.float16 + elif self == dtype.f4: + return np.float4_e2m1fn_x2 elif self == dtype.f32: return np.float32 elif self == dtype.f64: diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 379a196e2e..ded8adfb01 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -29,7 +29,14 @@ REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False -SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8} +SUPPORTED_KERNEL_PRECISIONS = { + dtype.f32, + dtype.f16, + dtype.bf16, + dtype.i8, + dtype.f8, + dtype.f4, +} TIMING_CACHE_PATH = os.path.join( tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin" ) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1fed1f9a1f..08a1bb4ea4 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -617,6 +617,39 @@ def aten_ops_quantize_op( ) + +try: + import modelopt.torch.quantization as mtq  # noqa: F401 + + assert torch.ops.tensorrt.dynamic_block_quantize_op.default +except Exception as e: + _LOGGER.warning( + "Unable to import dynamic block quantize op. Please install the modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling dynamic block quantized models" + ) +else: + + @dynamo_tensorrt_converter(torch.ops.tensorrt.dynamic_block_quantize_op.default) + def aten_ops_dynamic_block_quantize_op( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.nvfp4_quantize.nvfp4_quantize( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[6], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True) def aten_ops_squeeze( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index df580b1516..1f2d9d0de1 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -14,6 +14,7 @@ matmul, nccl_ops, normalization, + nvfp4_quantize, pad, permutation, pool, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py new file mode 100644 index 0000000000..2458350715 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -0,0 +1,399 @@ +from typing import Optional, Union + +import numpy as np +import tensorrt as trt +import torch +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + get_trt_tensor, + to_torch, +) +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTTensor + + +def nvfp4_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: TRTTensor, + block_size: int, + amax:
Union[np.ndarray, torch.Tensor], + num_bits: int, + exponent_bits: int, + scale_num_bits: int, + scale_exponent_bits: int, +) -> TRTTensor: + """ + Adds quantize and dequantize ops (QDQ) which quantize the input tensor to FP4 based + on the output_type set and dequantize it back. + """ + if len(input_tensor.shape) not in (2, 3): + raise ValueError( + f"nvfp4_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D" + ) + with unset_fake_temporarily(): + axis = len(input_tensor.shape) - 1 + global_scale = _calculate_global_scale(ctx, name, amax) + if ".weight_quantizer" in name: + _test_weights_scaling_factor(input_tensor, global_scale) + output = _static_double_quantize_without_constant_folding( + ctx, + target, + source_ir, + name, + input_tensor, + global_scale, + axis, + ) + elif ".input_quantizer" in name: + # quantize the input tensor to fp4; the outputs are the data tensor in fp4 and the block scale tensor in fp8 + output = _dynamic_double_quantize( + ctx, + target, + source_ir, + name, + input_tensor, + global_scale, + axis, + ) + + else: + raise ValueError( + f"quantizer received an unexpected node name: {name}. Supported values: weight_quantizer | input_quantizer" + ) + return output + + +def _dynamic_double_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: torch.Tensor, + global_scale: torch.Tensor, + axis: int = -1, + block_size: int = 16, + output_type: trt.DataType = trt.DataType.FP4, + scale_type: trt.DataType = trt.DataType.FP8, +) -> TRTTensor: + """ + Quantize the input tensor to FP4; the outputs are the data tensor in FP4 and the block scale tensor in FP8. + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR] + name: str + input_tensor : Tensor (On GPU) + The input tensor. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis : int + The axis to quantize. Default is -1 (the last axis). + block_size : int + The block size for quantization. Default is 16. + output_type : trt.DataType + The data type for quantized data. Default is FP4. + scale_type : trt.DataType + The data type for block scale. Default is FP8.
+ + """ + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + # dynamic quantize input tensor to fp4 + dynamic_quantize_layer = ctx.net.add_dynamic_quantize( + input_tensor, + axis, + block_size, + output_type, + scale_type, + ) + dynamic_quantize_layer.set_input(1, global_scale) + set_layer_name( + dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir + ) + quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0) + quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1) + + # dequantize scale from fp8 to orignal dtype(default is float32) + dequantize_scale_layer = ctx.net.add_dequantize( + quantized_scale_in_fp8, global_scale, input_tensor.dtype + ) + set_layer_name( + dequantize_scale_layer, target, name + "_dequantize_scale", source_ir + ) + dequantized_scale = dequantize_scale_layer.get_output(0) + + # dequantize quantized_data_in_fp4 from fp4 to orignal dtype(default is float32) + dequantize_data_layer = ctx.net.add_dequantize( + quantized_data_in_fp4, dequantized_scale, input_tensor.dtype + ) + dequantize_data_layer.axis = axis + set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) + dequantized_data = dequantize_data_layer.get_output(0) + return dequantized_data + + +# TODO: to remove it this is to make sure our global scale and block scale calculation is correct during debugging +def _test_weights_scaling_factor(weights_tensor, global_scale): + + import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor + import modelopt.onnx.quantization.quant_utils as quant_utils + + weights_scaling_factor_2 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor_2( + weights_tensor + ) + torch.allclose(weights_scaling_factor_2, global_scale) + + block_scale_f32 = quant_utils.get_weights_scaling_factor( + weights_tensor.numpy(), 16, np.float32(global_scale) + ) + block_scale_f32 = torch.from_numpy(block_scale_f32) + + block_scale = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + weights_tensor, 16, global_scale, True + )[0] + torch.allclose(block_scale_f32, block_scale) + block_scale_fp8 = block_scale.to(torch.float8_e4m3fn) + + +def _static_double_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor: torch.Tensor, + global_scale: torch.Tensor, + axis: int, +) -> TRTTensor: + """ + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor : Tensor (On GPU) + The input tensor for weights. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis: int + The axis to quantize. Default is -1 (the last axis). 
+ Returns: + the dequantized weights tensor after the QDQ pattern + """ + + import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor + + # import modelopt.onnx.quantization.quant_utils as quant_utils + + block_scale_fp32 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + weights_tensor, 16, global_scale, True + )[0] + block_scale_fp8 = block_scale_fp32.to(torch.float8_e4m3fn) + + global_scale = to_torch(global_scale, None) + + # # TODO: issue1: not sure whether we need to quantize the weights tensor here, because ICastLayer does not support casting to FP4: + # IBuilder::buildSerializedNetwork: Error Code 4: API Usage Error (Cast ITensor linear1.weight_quantizer/dynamic_block_quantize_op_1_weights_tensor_scaled_output from DataType.FLOAT to DataType.FP4 - [unknown_ir_ops]-[linear1.weight_quantizer/dynamic_block_quantize_op_1_cast_weights_tensor_scaled_to_fp4]: unsupported input type and output type for ICastLayer, unsupported types are: {FP8, Int4, FP4}, current input type: Float, output type: FP4) + # reference https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/onnx/quantization/qdq_utils.py#L955 + # weights_tensor_scaled = quant_utils.quantize(weights_tensor.numpy(), 16, block_scale_fp32.numpy(),global_scale.numpy()) + # weights_tensor_scaled = torch.from_numpy(weights_tensor_scaled) + # weights_tensor_scaled = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_tensor_scaled") + # weights_fp4 = cast_trt_tensor(ctx, weights_tensor_scaled, trt.DataType.FP4, name + "_cast_weights_tensor_scaled_to_fp4") + + # # TODO: issue2: weights_tensor_scaled is in torch.uint8 format; not sure how it can be converted into float4_e2m1fn_x2 + # reference: https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/core/torch/quantization/qtensor/nvfp4_tensor.py#L136 + weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize( + weights_tensor, + 16, + block_scale_fp32, + global_scale, + )[0]._quantized_data + + # # TODO: issue3: torch does not support converting to float4_e2m1fn_x2 directly; got RuntimeError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2' + # weights_fp4 = weights_tensor_scaled.to(torch.float4_e2m1fn_x2) + # weights_fp4 = get_trt_tensor(ctx, weights_fp4, name + "_weights_fp4") + + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + block_scale = get_trt_tensor(ctx, block_scale_fp32, name + "_block_scale") + block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") + # # quantize block scale to fp8 + # block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) + # set_layer_name( + # block_scale_quantize_layer, + # target, + # name + "_block_scale_quantize", + # source_ir, + # ) + # block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) + # quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0) + + # dequantize block scale from fp8 to float32 + dequantize_block_scale_layer = ctx.net.add_dequantize( + block_scale_fp8, + global_scale, + block_scale.dtype, + ) + set_layer_name( + dequantize_block_scale_layer, + target, + name + "_dequantize_block_scale", + source_ir, + ) + dequantized_block_scale = dequantize_block_scale_layer.get_output(0) + + # dequantize weights tensor from fp4 to the original dtype (default is float32) + # NOTE: weights_fp4 is never materialized above (see TODO issue1-issue3), so this path is not functional yet; nvfp4_quantize currently uses _static_double_quantize_without_constant_folding instead + dequantize_data_layer = ctx.net.add_dequantize( + weights_fp4, + dequantized_block_scale, + trt.DataType.FLOAT, + ) + set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) + dequantized_data =
dequantize_data_layer.get_output(0) + return dequantized_data + + +def _static_double_quantize_without_constant_folding( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor: torch.Tensor, + global_scale: torch.Tensor, + axis: int, +) -> TRTTensor: + """ + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor : Tensor (On GPU) + The input tensor for weights. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis: int + The axis to quantize. Default is -1 (the last axis). + Returns: + the dequantized weights tensor after the QDQ pattern + """ + + import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor + + # import modelopt.onnx.quantization.quant_utils as quant_utils + + block_scale = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + weights_tensor, 16, global_scale, True + )[0] + global_scale = to_torch(global_scale, None) + + # block_scale_fp8 = block_scale.to(torch.float8_e4m3fn) + # block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") + + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + block_scale = get_trt_tensor(ctx, block_scale, name + "_block_scale") + weights_tensor = get_trt_tensor(ctx, weights_tensor, name + "_weights_tensor") + + # quantize block scale to fp8 + block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) + set_layer_name( + block_scale_quantize_layer, + target, + name + "_block_scale_quantize_to_fp8", + source_ir, + ) + block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) + block_scale_fp8 = block_scale_quantize_layer.get_output(0) + + # dequantize block scale from fp8 to float32 + dequantize_block_scale_layer = ctx.net.add_dequantize( + block_scale_fp8, + global_scale, + block_scale.dtype, + ) + set_layer_name( + dequantize_block_scale_layer, + target, + name + "_dequantize_block_scale_from_fp8", + source_ir, + ) + dequantized_block_scale = dequantize_block_scale_layer.get_output(0) + + # quantize weights tensor to fp4 + quantize_weights_layer = ctx.net.add_quantize( + weights_tensor, dequantized_block_scale + ) + set_layer_name( + quantize_weights_layer, + target, + name + "_quantize_weights_to_fp4", + source_ir, + ) + quantize_weights_layer.set_output_type(0, trt.DataType.FP4) + weights_fp4 = quantize_weights_layer.get_output(0) + + # dequantize weights tensor from fp4 to the original dtype (default is float32) + dequantize_weights_layer = ctx.net.add_dequantize( + weights_fp4, + dequantized_block_scale, + trt.DataType.FLOAT, + ) + set_layer_name( + dequantize_weights_layer, + target, + name + "_dequantize_weights_from_fp4", + source_ir, + ) + dequantized_data = dequantize_weights_layer.get_output(0) + return dequantized_data + + +def _calculate_global_scale( + ctx: ConversionContext, + name: str, + amax: torch.Tensor, +) -> torch.Tensor: + # calculate the global scale (the global per-tensor scaling factor; it should contain only 1 element) + if amax is None or amax == 0: + amax = 1.0 + amax = to_torch( + amax, None + )  # amax is calculated from input_tensor.abs().amax().float() + global_scale = torch.divide(amax, 6 * 448) + if global_scale == 0: + global_scale = torch.tensor(1.0) + return global_scale + + +def _calculate_block_scale( + ctx: ConversionContext, + name: str, + weights_tensor: TRTTensor, + block_size: int, +) -> TRTTensor: + amax = weights_tensor.abs().amax().float() + # reference:
https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/onnx/quantization/quant_utils.py#L122 + weights_scaling_factor_2 = amax / 6 / 448 + if weights_scaling_factor_2 == 0: + weights_scaling_factor_2 = 1.0 + + # reference: https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/onnx/quantization/quant_utils.py#L131 + [n, k] = weights_tensor.shape[-2:] + assert block_size != 0, "block_size must be non-zero" + assert k % block_size == 0, "k must be a multiple of block_size" + reshaped_input_tensor = weights_tensor.reshape( + tuple(weights_tensor.shape[:-2]) + (n, k // block_size, block_size) + ) + + per_block_amax = reshaped_input_tensor.abs().amax(dim=-1).float() + per_block_scale = torch.divide(per_block_amax, 6) + q_per_block_scale = torch.divide(per_block_scale, weights_scaling_factor_2) + # TODO:set all zero values in scale to 1.0 + # block_scale = get_trt_tensor(ctx, q_per_block_scale, name + "_block_scale") + return q_per_block_scale diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index e472ed3092..192f9c648a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -28,6 +28,8 @@ def quantize( """ with unset_fake_temporarily(): + if not isinstance(input_tensor, TRTTensor): + input_tensor = get_trt_tensor(ctx, input_tensor, name + "_quantize_input") if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( trt.float32, trt.float16, @@ -67,3 +69,252 @@ def quantize( dq_output = dequantize_layer.get_output(0) return dq_output + + +# def nvfp4_quantize( +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR], +# name: str, +# input_tensor: TRTTensor, +# block_size: int, +# amax: Union[np.ndarray, torch.Tensor], +# num_bits: int, +# exponent_bits: int, +# scale_num_bits: int, +# scale_exponent_bits: int, +# ) -> TRTTensor: +# """ +# Adds quantize and dequantize ops (QDQ) which quantize to FP4 based +# on the output_type set and dequantizes them back. +# """ +# print( +# f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" +# ) +# with unset_fake_temporarily(): +# if input_tensor.dtype not in ( +# trt.float32, +# trt.float16, +# trt.bfloat16, +# torch.float32, +# torch.float16, +# torch.bfloat16, +# ): +# raise ValueError( +# f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16" +# ) +# if len(input_tensor.shape) not in (2, 3): +# raise ValueError( +# f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. 
Supported shapes: 2D or 3D" +# ) +# axis = len(input_tensor.shape) - 1 + +# # TODO: ADD PADDING IF NEEDED +# # TODO: ADD DYNAMIC SHAPE SUPPORT + +# global_scale = _calculate_global_scale(ctx, name, amax) + +# if ".weight_quantizer" in name: +# block_scale = _calculate_block_scale( +# ctx, +# name, +# input_tensor, +# block_size, +# ) +# input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") +# output = _static_double_quantize( +# ctx, +# target, +# source_ir, +# name, +# input_tensor, +# block_scale, +# global_scale, +# ) +# elif ".input_quantizer" in name: +# # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 +# output = _dynamic_double_quantize( +# ctx, +# target, +# source_ir, +# name, +# input_tensor, +# global_scale, +# ) + +# else: +# raise ValueError( +# f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer" +# ) +# return output + + +# def _dynamic_double_quantize( +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR], +# name: str, +# input_tensor: TRTTensor, +# global_scale: TRTTensor, +# axis: int = -1, +# block_size: int = 16, +# output_type: trt.DataType = trt.DataType.FP4, +# scale_type: trt.DataType = trt.DataType.FP8, +# ) -> TRTTensor: +# """ +# quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 +# Parameters: +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR] +# name: str +# input_tensor : Tensor (On GPU) +# The input tensor. +# global_scale : Tensor (On GPU) +# The global per-tensor scaling factor. It should contain only 1 element. +# axis : int +# The axis to quantize. Default is -1 (the last axis). +# block_size : int +# The block size for quantization. Default is 16. +# output_type : trt.DataType +# The data type for quantized data. Default is FP4. +# scale_type : trt.DataType +# The data type for block scale. Default is FP8. 
+ +# """ +# # dynamic quantize input tensor to fp4 +# dynamic_quantize_layer = ctx.net.add_dynamic_quantize( +# input_tensor, +# axis, +# block_size, +# output_type, +# scale_type, +# ) +# dynamic_quantize_layer.set_input(1, global_scale) +# set_layer_name( +# dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir +# ) +# quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0) +# quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1) + +# # dequantize scale from fp8 to orignal dtype(default is float32) +# dequantize_scale_layer = ctx.net.add_dequantize( +# quantized_scale_in_fp8, global_scale, input_tensor.dtype +# ) +# set_layer_name( +# dequantize_scale_layer, target, name + "_dequantize_scale", source_ir +# ) +# dequantized_scale = dequantize_scale_layer.get_output(0) + +# # dequantize quantized_data_in_fp4 from fp4 to orignal dtype(default is float32) +# dequantize_data_layer = ctx.net.add_dequantize( +# quantized_data_in_fp4, dequantized_scale, input_tensor.dtype +# ) +# dequantize_data_layer.axis = axis +# set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) +# dequantized_data = dequantize_data_layer.get_output(0) +# return dequantized_data + + +# def _static_double_quantize( +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR], +# name: str, +# input_tensor: TRTTensor, +# block_scale: TRTTensor, +# global_scale: TRTTensor, +# ) -> TRTTensor: +# """ +# Parameters: +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR], +# name: str, +# input_tensor : Tensor (On GPU) +# The input tensor. +# block_scale : Tensor (On GPU) +# The per-block scaling factor. +# global_scale : Tensor (On GPU) +# The global per-tensor scaling factor. It should contain only 1 element. 
+# Returns: +# A tuple of two tensors: quantized data tensor in fp4 and quantized block scaling factor tensor in fp8 +# """ +# # quantize block scale to fp8 +# block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) +# set_layer_name( +# block_scale_quantize_layer, +# target, +# name + "_block_scale_quantize", +# source_ir, +# ) +# block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) +# quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0) + +# # dequantize block scale from fp8 to original dtype(default is float32) +# dequantize_block_scale_layer = ctx.net.add_dequantize( +# quantized_block_scale_in_fp8, +# global_scale, +# block_scale.dtype, +# ) +# set_layer_name( +# dequantize_block_scale_layer, +# target, +# name + "_dequantize_block_scale", +# source_ir, +# ) +# dequantized_block_scale = dequantize_block_scale_layer.get_output(0) + +# # quantize input tensor to fp4 +# data_quantize_layer = ctx.net.add_quantize(input_tensor, dequantized_block_scale) +# set_layer_name(data_quantize_layer, target, name + "_data_quantize", source_ir) +# data_quantize_layer.set_output_type(0, trt.DataType.FP4) +# quantized_data_in_fp4 = data_quantize_layer.get_output(0) + +# # dequantize input tensor from fp4 to originaldtype(default is float32) +# dequantize_data_layer = ctx.net.add_dequantize( +# quantized_data_in_fp4, +# dequantized_block_scale, +# input_tensor.dtype, +# ) +# set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) +# dequantized_data = dequantize_data_layer.get_output(0) +# return dequantized_data + + +# def _calculate_global_scale( +# ctx: ConversionContext, +# name: str, +# amax: TRTTensor, +# ) -> TRTTensor: +# # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) +# if amax is None or amax == 0: +# amax = 1.0 +# amax = to_torch( +# amax, None +# ) # amax is calculated from input_tensor.abs().amax().float() +# global_scale = torch.divide(amax, 6) # 6*448 +# global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") +# return global_scale + + +# def _calculate_block_scale( +# ctx: ConversionContext, +# name: str, +# input_tensor: TRTTensor, +# block_size: int, +# ) -> TRTTensor: + +# [n, k] = input_tensor.shape[-2:] +# assert block_size != 0, "block_size must be non-zero" +# assert k % block_size == 0, "k must be a multiple of block_size" +# reshaped_input_tensor = input_tensor.reshape( +# tuple(input_tensor.shape[:-2]) + (n, k // block_size, block_size) +# ) +# amax = input_tensor.abs().amax().float() +# amax = torch.divide(amax, 6*448) +# block_amax = reshaped_input_tensor.abs().amax(dim=-1).float() +# block_scale = torch.divide(block_amax, 6) +# block_scale = torch.divide(block_scale, amax) +# block_scale = get_trt_tensor(ctx, block_scale, name + "_block_scale") +# return block_scale diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 6ebefc5509..190b6752b4 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -101,4 +101,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # TODO: Update this function when quantization is added def is_impure(self, node: torch.fx.node.Node) -> bool: + if node.target in ( + torch.ops.tensorrt.quantize_op.default, + torch.ops.tensorrt.dynamic_block_quantize_op.default, + ): + return True return False diff --git a/pyproject.toml 
b/pyproject.toml index 87c70fec17..3bb857e3e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ dev = [ torchvision = [ "torchvision", ] #Leaving torchvisions dependency unconstrained so uv can just install something that should work for the torch we have. TV's on PyT makes it hard to put version constrains in -quantization = ["nvidia-modelopt[deploy,hf,torch]>=0.17.0"] +quantization = ["nvidia-modelopt[all]>=0.27.1"] monitoring-tools = ["rich>=13.7.1"] jupyter = ["rich[jupyter]>=13.7.1"] distributed = ["tensorrt-llm>=0.16.0"] diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 0c28b23bba..52d746f995 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -199,6 +199,59 @@ def test_resnet18_half(ir): torch._dynamo.reset() + +# @unittest.skipIf( +# torch.cuda.get_device_capability() < (10, 0), +# "FP4 quantization requires compute capability 10.0 or later", +# ) +@unittest.skipIf( + not importlib.util.find_spec("modelopt"), + "ModelOpt is required to run this test", +) +@pytest.mark.unit +def test_base_fp4(ir): + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear(in_features=16, out_features=3) + + def forward(self, x): + x = self.linear1(x) + return x + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(input_tensor) + + input_tensor = torch.randn(5, 16).cuda() + model = SimpleNetwork().eval().cuda() + + quant_cfg = mtq.NVFP4_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has qdq nodes at this point + output_pyt = model(input_tensor) + + with torch.no_grad(): + with export_torch_mode(): + exp_program = torch.export.export(model, (input_tensor,), strict=False) + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[input_tensor], + enabled_precisions={torch.float4_e2m1fn_x2}, + min_block_size=1, + debug=True, + cache_built_engines=False, + reuse_cached_engines=False, + ) + outputs_trt = trt_model(input_tensor) + assert torch.allclose(output_pyt, outputs_trt, rtol=4e-1, atol=4e-1) + + @unittest.skipIf( torch.cuda.get_device_capability() < (8, 9), "FP8 quantization requires compute capability 8.9 or later", @@ -230,8 +283,8 @@ def calibrate_loop(model): input_tensor = torch.randn(1, 10).cuda() model = SimpleNetwork().eval().cuda() - quant_cfg = mtq.FP8_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has FP8 qdq nodes at this point output_pyt = model(input_tensor)
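For context, here is a minimal end-to-end sketch of the NVFP4 workflow this diff enables. It simply mirrors the `test_base_fp4` flow above (ModelOpt PTQ followed by `torchtrt.dynamo.compile` with FP4 enabled); the toy model, shapes, and the lambda calibration loop are illustrative assumptions, and running it requires a ModelOpt install plus FP4-capable hardware and a TensorRT build with FP4 support.

```python
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt as torchtrt
from modelopt.torch.quantization.utils import export_torch_mode

# Toy linear model and calibration data (illustrative only).
model = torch.nn.Sequential(torch.nn.Linear(16, 3)).eval().cuda()
x = torch.randn(5, 16).cuda()

# Insert NVFP4 QDQ nodes (weight_quantizer / input_quantizer) via ModelOpt PTQ.
mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=lambda m: m(x))

with torch.no_grad(), export_torch_mode():
    exp_program = torch.export.export(model, (x,), strict=False)
    # The exported tensorrt.dynamic_block_quantize_op nodes are handled by the
    # new nvfp4_quantize converter added in this change.
    trt_model = torchtrt.dynamo.compile(
        exp_program,
        inputs=[x],
        enabled_precisions={torch.float4_e2m1fn_x2},
        min_block_size=1,
    )
    print(trt_model(x))
```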