Skip to content

Add fp4 support(WIP) #3496

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ repos:
- id: clang-format
types_or: [c++, c, cuda]
- repo: https://github.com/keith/pre-commit-buildifier
rev: 6.4.0
rev: 8.0.3
hooks:
- id: buildifier
args:
- --warnings=all
- id: buildifier-lint
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.23
rev: v0.24.1
hooks:
- id: validate-pyproject
- repo: https://github.com/pycqa/isort
Expand All @@ -37,17 +37,17 @@ repos:
- id: isort
name: isort (python)
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.9.0"
rev: "v1.15.0"
hooks:
- id: mypy
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.3
rev: v0.11.7
hooks:
- id: ruff
- repo: https://github.com/psf/black
rev: 24.3.0
rev: 25.1.0
hooks:
- id: black
exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
Expand All @@ -57,7 +57,7 @@ repos:
- id: typos
- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
rev: 0.5.5
rev: 0.7.1
hooks:
# Update the uv lockfile
- id: uv-lock
Expand Down
4 changes: 4 additions & 0 deletions examples/dynamo/vgg16_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def calibrate_loop(model):
quant_cfg = mtq.INT8_DEFAULT_CFG
elif args.quantize_type == "fp8":
quant_cfg = mtq.FP8_DEFAULT_CFG
elif args.quantize_type == "fp4":
quant_cfg = mtq.NVFP4_DEFAULT_CFG
# PTQ with in-place replacement to quantized modules
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# model has FP8 qdq nodes at this point
Expand Down Expand Up @@ -239,6 +241,8 @@ def calibrate_loop(model):
enabled_precisions = {torch.int8}
elif args.quantize_type == "fp8":
enabled_precisions = {torch.float8_e4m3fn}
elif args.quantize_type == "fp4":
enabled_precisions = {torch.float4_e2m1fn_x2}
trt_model = torchtrt.dynamo.compile(
exp_program,
inputs=[input_tensor],
Expand Down
19 changes: 19 additions & 0 deletions py/torch_tensorrt/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ class dtype(Enum):
:meta hide-value:
"""

f4 = auto()
"""4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4``

:meta hide-value:
"""

uint8 = u8
int8 = i8

Expand All @@ -91,6 +97,9 @@ class dtype(Enum):
float8 = f8
fp8 = f8

float4 = f4
fp4 = f4

half = f16
fp16 = f16
float16 = f16
Expand Down Expand Up @@ -162,6 +171,8 @@ def _from(
return dtype.i32
elif t == torch.float8_e4m3fn:
return dtype.f8
elif t == torch.float4_e2m1fn_x2:
return dtype.f4
elif t == torch.half:
return dtype.f16
elif t == torch.float:
Expand All @@ -188,6 +199,8 @@ def _from(
return dtype.i8
elif t == trt.DataType.FP8:
return dtype.f8
elif t == trt.DataType.FP4:
return dtype.fp4
elif t == trt.DataType.INT32:
return dtype.i32
elif t == trt.DataType.INT64:
Expand Down Expand Up @@ -357,6 +370,8 @@ def to(
return torch.long
elif self == dtype.f8:
return torch.float8_e4m3fn
elif self == dtype.f4:
return torch.float4_e2m1fn_x2
elif self == dtype.f16:
return torch.half
elif self == dtype.f32:
Expand Down Expand Up @@ -394,6 +409,8 @@ def to(
return trt.DataType.BOOL
elif self == dtype.bf16:
return trt.DataType.BF16
elif self == dtype.f4:
return trt.DataType.FP4
elif use_default:
return trt.DataType.FLOAT
else:
Expand All @@ -410,6 +427,8 @@ def to(
return np.int64
elif self == dtype.f16:
return np.float16
elif self == dtype.f4:
return np.float4_e2m1fn_x2
elif self == dtype.f32:
return np.float32
elif self == dtype.f64:
Expand Down
9 changes: 8 additions & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
SUPPORTED_KERNEL_PRECISIONS = {
dtype.f32,
dtype.f16,
dtype.bf16,
dtype.i8,
dtype.f8,
dtype.f4,
}
TIMING_CACHE_PATH = os.path.join(
tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
)
Expand Down
33 changes: 33 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,39 @@ def aten_ops_quantize_op(
)


try:
import modelopt.torch.quantization as mtq # noqa: F401

assert torch.ops.tensorrt.dynamic_block_quantize_op.default
except Exception as e:
_LOGGER.warning(
"Unable to import dynamic block quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling dynamic blockquantized models"
)
else:

@dynamo_tensorrt_converter(torch.ops.tensorrt.dynamic_block_quantize_op.default)
def aten_ops_dynamic_block_quantize_op(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.nvfp4_quantize.nvfp4_quantize(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
)


@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
def aten_ops_squeeze(
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
matmul,
nccl_ops,
normalization,
nvfp4_quantize,
pad,
permutation,
pool,
Expand Down
Loading
Loading