From c9703a7103b2838e897406774c86001d741b1d90 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 20 Sep 2023 15:28:08 -0700 Subject: [PATCH 1/7] update batch_norm and layer_norm --- .../dynamo/conversion/aten_ops_converters.py | 60 ++++++++++--------- .../conversion/impl/normalization/ops.py | 10 ++-- ...chnorm_aten.py => test_batch_norm_aten.py} | 0 3 files changed, 37 insertions(+), 33 deletions(-) rename tests/py/dynamo/conversion/{test_batchnorm_aten.py => test_batch_norm_aten.py} (100%) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 873f531b71..162bccc999 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -59,14 +59,37 @@ def aten_ops_batch_norm( target, SourceIR.ATEN, name, - args[0], - args[1], - args[2], - args[3], - args[4], - args[5], - args[6], - args[7], + input=args[0], + weight=args_bounds_check(args, 1, replacement=1), + bias=args_bounds_check(args, 2, replacement=0), + running_mean=args_bounds_check(args, 3), + running_var=args_bounds_check(args, 4), + training=args_bounds_check(args, 5), + momentum=args_bounds_check(args, 6, replacement=0.1), + eps=args_bounds_check(args, 7, replacement=1e-05), + cudnn_enabled=args_bounds_check(args, 8, replacement=False), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc] +def aten_ops_layer_norm( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.layer_norm( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + normalized_shape=args[1], + weight=args_bounds_check(args, 2, replacement=1), + bias=args_bounds_check(args, 3, replacement=0), + eps=args_bounds_check(args, 4, replacement=1e-05), + cudnn_enable=args_bounds_check(args, 5, replacement=True), ) @@ -328,27 +351,6 @@ def aten_ops_matmul( ) -@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc] -def aten_ops_layernorm( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.normalization.layer_norm( - ctx, - target, - SourceIR.ATEN, - name, - args[0], - args[1], - args[2], - args[3], - args[4], - ) - - @dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default) # type: ignore[misc] def aten_ops_rsqrt( ctx: ConversionContext, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index 81bd88cd4f..2453bc77db 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -36,6 +36,7 @@ def batch_norm( training: torch.Tensor, momentum: torch.Tensor, eps: List[float], + cudnn_enabled: bool, ) -> Union[TRTTensor, Sequence[TRTTensor]]: if not isinstance(input, TRTTensor): raise RuntimeError( @@ -69,7 +70,7 @@ def batch_norm( input.shape[2], 1, ) - set_layer_name(reshape_layer, target, f"{name}_reshape_2d") + set_layer_name(reshape_layer, target, f"{name}_reshape_2d", source_ir) input = reshape_layer.get_output(0) layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power) set_layer_name(layer, target, name) @@ -78,7 +79,7 @@ def batch_norm( if not ctx.net.has_implicit_batch_dimension and len(output_shape) 
< 4: reshape_output_layer = ctx.net.add_shuffle(layer.get_output(0)) reshape_output_layer.reshape_dims = tuple(output_shape) - set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d") + set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d", source_ir) layer = reshape_output_layer return layer.get_output(0) @@ -93,6 +94,7 @@ def layer_norm( weight: torch.Tensor, bias: torch.Tensor, eps: List[float], + cudnn_enable: bool, ) -> Union[TRTTensor, Sequence[TRTTensor]]: if not isinstance(input, trt.tensorrt.ITensor): raise RuntimeError( @@ -173,7 +175,7 @@ def layer_norm_no_plugin( mean_expected_layer = ctx.net.add_reduce( input, trt.ReduceOperation.AVG, axes, keep_dims=True ) - set_layer_name(mean_expected_layer, target, f"{name}_mean_expected") + set_layer_name(mean_expected_layer, target, f"{name}_mean_expected", source_ir) # X-E[x] sub_trt = convert_binary_elementwise( @@ -203,7 +205,7 @@ def layer_norm_no_plugin( mean_trt_layer = ctx.net.add_reduce( pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True ) - set_layer_name(mean_trt_layer, target, f"{name}_mean") + set_layer_name(mean_trt_layer, target, f"{name}_mean", source_ir) # Variance + eps eps_tensor = ctx.net.add_constant( (1,) * len(input.shape), diff --git a/tests/py/dynamo/conversion/test_batchnorm_aten.py b/tests/py/dynamo/conversion/test_batch_norm_aten.py similarity index 100% rename from tests/py/dynamo/conversion/test_batchnorm_aten.py rename to tests/py/dynamo/conversion/test_batch_norm_aten.py From c977b21134a4fe359ba215dde913b098a407738d Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 22 Sep 2023 16:53:50 -0700 Subject: [PATCH 2/7] fix bugs --- .../dynamo/conversion/aten_ops_converters.py | 24 +++++---- .../conversion/impl/normalization/ops.py | 52 ++++++++++++++----- 2 files changed, 51 insertions(+), 25 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 162bccc999..6f3a26922a 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -46,6 +46,7 @@ def get_ir(target: Target) -> SourceIR: return SourceIR.UNKNOWN +@dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.batch_norm) # type: ignore[misc] def aten_ops_batch_norm( ctx: ConversionContext, @@ -60,17 +61,18 @@ def aten_ops_batch_norm( SourceIR.ATEN, name, input=args[0], - weight=args_bounds_check(args, 1, replacement=1), - bias=args_bounds_check(args, 2, replacement=0), - running_mean=args_bounds_check(args, 3), - running_var=args_bounds_check(args, 4), - training=args_bounds_check(args, 5), - momentum=args_bounds_check(args, 6, replacement=0.1), - eps=args_bounds_check(args, 7, replacement=1e-05), - cudnn_enabled=args_bounds_check(args, 8, replacement=False), + weight=args[1], + bias=args[2], + running_mean=args[3], + running_var=args[4], + training=args[5], + momentum=args[6], + eps=args[7], + cudnn_enabled=args_bounds_check(args, 8, replacement=True), ) +@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc] def aten_ops_layer_norm( ctx: ConversionContext, @@ -86,9 +88,9 @@ def aten_ops_layer_norm( name, input=args[0], normalized_shape=args[1], - weight=args_bounds_check(args, 2, replacement=1), - bias=args_bounds_check(args, 3, replacement=0), - 
eps=args_bounds_check(args, 4, replacement=1e-05), + weight=args[2], + bias=args[3], + eps=args[4], cudnn_enable=args_bounds_check(args, 5, replacement=True), ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index 2453bc77db..a2cbb3520b 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -29,13 +29,13 @@ def batch_norm( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - weight: torch.Tensor, - bias: torch.Tensor, - running_mean: torch.Tensor, - running_var: torch.Tensor, - training: torch.Tensor, - momentum: torch.Tensor, - eps: List[float], + weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], + bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], + running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], + running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], + training: bool, + momentum: float, + eps: float, cudnn_enabled: bool, ) -> Union[TRTTensor, Sequence[TRTTensor]]: if not isinstance(input, TRTTensor): @@ -47,8 +47,20 @@ def batch_norm( if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm." + if weight is None: + weight = np.array(1.0) + + if bias is None: + bias = np.array(0.0) + + if running_mean is None: + running_mean = np.array(0.0) + + if running_var is None: + running_var = np.array(1.0) + scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt( - cast(torch.Tensor, to_numpy(running_var)) + cast(float, eps) + cast(torch.Tensor, to_numpy(running_var)) + eps ) bias = to_numpy(bias) - to_numpy(running_mean) * scale @@ -91,9 +103,9 @@ def layer_norm( name: str, input: TRTTensor, normalized_shape: List[int], - weight: torch.Tensor, - bias: torch.Tensor, - eps: List[float], + weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], + bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], + eps: float, cudnn_enable: bool, ) -> Union[TRTTensor, Sequence[TRTTensor]]: if not isinstance(input, trt.tensorrt.ITensor): @@ -102,6 +114,12 @@ def layer_norm( "of the TensorRT region!" ) + if weight is None: + weight = np.array(1.0) + + if bias is None: + bias = np.array(0.0) + gamma = ( weight.detach().cpu().float().numpy() if isinstance(weight, torch.Tensor) @@ -152,9 +170,9 @@ def layer_norm_no_plugin( name: str, input: TRTTensor, normalized_shape: List[int], - weight: torch.Tensor, - bias: torch.Tensor, - eps: List[float], + weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], + bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], + eps: float, ) -> Union[TRTTensor, Sequence[TRTTensor]]: if not isinstance(input, TRTTensor): raise RuntimeError( @@ -162,6 +180,12 @@ def layer_norm_no_plugin( "of the TensorRT region!" 
) + if weight is None: + weight = np.array(1.0) + + if bias is None: + bias = np.array(0.0) + shape = weight.shape broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape gamma = to_numpy(weight.reshape(*shape)) From 0b430f07be50fd2fdafcca22c6cc1b14c27b0772 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 25 Sep 2023 14:24:03 -0700 Subject: [PATCH 3/7] fix type bug support group norm, and improve batch and layer norms --- .../dynamo/conversion/aten_ops_converters.py | 99 +++++++++- .../conversion/impl/normalization/ops.py | 182 ++++++++++++++---- .../dynamo/conversion/ops_evaluators.py | 2 + .../dynamo/conversion/test_group_norm_aten.py | 52 +++++ 4 files changed, 290 insertions(+), 45 deletions(-) create mode 100644 tests/py/dynamo/conversion/test_group_norm_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 6f3a26922a..0141ab2da5 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -47,6 +47,29 @@ def get_ir(target: Target) -> SourceIR: @dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default) # type: ignore[misc] +def aten_ops_native_batch_norm( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.native_batch_norm( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + weight=args[1], + bias=args[2], + running_mean=args[3], + running_var=args[4], + training=args[5], + momentum=args[6], + eps=args[7], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.batch_norm) # type: ignore[misc] def aten_ops_batch_norm( ctx: ConversionContext, @@ -68,20 +91,19 @@ def aten_ops_batch_norm( training=args[5], momentum=args[6], eps=args[7], - cudnn_enabled=args_bounds_check(args, 8, replacement=True), + cudnn_enabled=args[8], ) @dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default) # type: ignore[misc] -@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc] -def aten_ops_layer_norm( +def aten_ops_native_layer_norm( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.normalization.layer_norm( + return impl.normalization.native_layer_norm( ctx, target, SourceIR.ATEN, @@ -91,7 +113,74 @@ def aten_ops_layer_norm( weight=args[2], bias=args[3], eps=args[4], - cudnn_enable=args_bounds_check(args, 5, replacement=True), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc] +def aten_ops_layer_norm( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.layer_norm( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + normalized_shape=args[1], + weight=args_bounds_check(args, 2), + bias=args_bounds_check(args, 3), + eps=args_bounds_check(args, 4, 1e-05), + cudnn_enable=args_bounds_check(args, 5, True), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.native_group_norm.default) # type: ignore[misc] +def aten_ops_native_group_norm( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.native_group_norm( + 
ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + weight=args[1], + bias=args[2], + N=args[3], + C=args[4], + HxW=args[5], + group=args[6], + eps=args[7], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default) # type: ignore[misc] +def aten_ops_group_norm( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.group_norm( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + num_groups=args[1], + weight=args_bounds_check(args, 2, None), + bias=args_bounds_check(args, 3, None), + eps=args_bounds_check(args, 4, 1e-05), + cudnn_enabled=args_bounds_check(args, 5, True), ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index a2cbb3520b..5f11a436d4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -1,5 +1,5 @@ import logging -from typing import Any, List, Optional, Sequence, Union, cast +from typing import Any, List, Optional, Sequence, Union, cast, Tuple import numpy as np import tensorrt as trt @@ -14,6 +14,7 @@ from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary from torch_tensorrt.fx.converters.converter_utils import ( get_trt_plugin, + get_trt_tensor, has_dynamic_shape, set_layer_name, ) @@ -23,41 +24,34 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -def batch_norm( +def native_batch_norm( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input: TRTTensor, - weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], - bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], - running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], - running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], + weight: Optional[Union[torch.Tensor, np.ndarray]], + bias: Optional[Union[torch.Tensor, np.ndarray]], + running_mean: Optional[Union[torch.Tensor, np.ndarray]], + running_var: Optional[Union[torch.Tensor, np.ndarray]], training: bool, momentum: float, eps: float, - cudnn_enabled: bool, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"BatchNorm2d received input {input} that is not part " - "of the TensorRT region!" - ) - +) -> Tuple[TRTTensor, TRTTensor, TRTTensor]: if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm." 
if weight is None: - weight = np.array(1.0) + weight = 1.0 if bias is None: - bias = np.array(0.0) + bias = 0.0 if running_mean is None: - running_mean = np.array(0.0) + running_mean = 0.0 if running_var is None: - running_var = np.array(1.0) + running_var = 1.0 scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt( cast(torch.Tensor, to_numpy(running_var)) + eps @@ -93,7 +87,40 @@ def batch_norm( reshape_output_layer.reshape_dims = tuple(output_shape) set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d", source_ir) layer = reshape_output_layer - return layer.get_output(0) + + + # 1 / sqrt((var + eps)) + save_rstd = 1 / (torch.sqrt(running_var + eps)) + + # eps_tensor = ctx.net.add_constant( + # (1,) * len(running_var.shape), + # trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), + # ) + + # eps_tensor = ctx.net.add_constant( + # (1,) * len(input.shape), + # trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), + # ) + + # add_trt = convert_binary_elementwise( + # ctx, + # target, + # source_ir, + # f"{name}_add", + # running_var, + # eps_tensor, + # ) + + # sqrt_trt = convert_unary( + # ctx, + # target, + # source_ir, + # f"{name}_sqrt", + # add_trt, + # ) + + + return layer.get_output(0), running_mean, save_rstd def layer_norm( @@ -103,8 +130,8 @@ def layer_norm( name: str, input: TRTTensor, normalized_shape: List[int], - weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], - bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], + weight: Optional[Union[torch.Tensor, np.ndarray]], + bias: Optional[Union[torch.Tensor, np.ndarray]], eps: float, cudnn_enable: bool, ) -> Union[TRTTensor, Sequence[TRTTensor]]: @@ -115,10 +142,10 @@ def layer_norm( ) if weight is None: - weight = np.array(1.0) + weight = to_numpy(1.0) if bias is None: - bias = np.array(0.0) + bias = to_numpy(0.0) gamma = ( weight.detach().cpu().float().numpy() @@ -170,8 +197,8 @@ def layer_norm_no_plugin( name: str, input: TRTTensor, normalized_shape: List[int], - weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], - bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], + weight: Optional[Union[torch.Tensor, np.ndarray]], + bias: Optional[Union[torch.Tensor, np.ndarray]], eps: float, ) -> Union[TRTTensor, Sequence[TRTTensor]]: if not isinstance(input, TRTTensor): @@ -181,19 +208,18 @@ def layer_norm_no_plugin( ) if weight is None: - weight = np.array(1.0) + weight = to_numpy(1.0) if bias is None: - bias = np.array(0.0) + bias = to_numpy(0.0) shape = weight.shape broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape - gamma = to_numpy(weight.reshape(*shape)) - beta = to_numpy(bias.reshape(*shape)) + gamma = to_numpy(weight).reshape(shape) + beta = to_numpy(bias).reshape(shape) - axes = 0 - for d in range(len(shape)): - axes |= 1 << (len(input.shape) - d - 1) + dims = list(range(len(input.shape) - len(shape), len(input.shape))) + axes = get_axes_for_reduce_op(dims) # E[x] mean_expected_layer = ctx.net.add_reduce( @@ -207,7 +233,6 @@ def layer_norm_no_plugin( target, source_ir, f"{name}_sub", - trt.ElementWiseOperation.SUB, input, mean_expected_layer.get_output(0), ) @@ -222,7 +247,6 @@ def layer_norm_no_plugin( target, source_ir, f"{name}_pow_var", - trt.ElementWiseOperation.POW, sub_trt, pow_tensor.get_output(0), ) @@ -241,7 +265,6 @@ def layer_norm_no_plugin( target, source_ir, f"{name}_add", - trt.ElementWiseOperation.SUM, mean_trt_layer.get_output(0), eps_tensor.get_output(0), ) @@ -251,7 +274,6 @@ def layer_norm_no_plugin( target, source_ir, 
f"{name}_sqrt", - trt.UnaryOperation.SQRT, add_trt, ) # (x - E[x]) / sqrt((var + eps)) @@ -260,7 +282,6 @@ def layer_norm_no_plugin( target, source_ir, f"{name}_div_trt", - trt.ElementWiseOperation.DIV, sub_trt, sqrt_trt, ) @@ -270,18 +291,19 @@ def layer_norm_no_plugin( gamma.shape, trt.Weights(np.ascontiguousarray(gamma)) ) gamma_tensor.name = f"{name}_gamma" + assert beta is not None beta_tensor = ctx.net.add_constant( gamma.shape, trt.Weights(np.ascontiguousarray(beta)) ) beta_tensor.name = f"{name}_beta" + # y * gamma + beta scale_layer = convert_binary_elementwise( ctx, target, source_ir, f"{name}_scale", - trt.ElementWiseOperation.PROD, div_trt, gamma_tensor.get_output(0), ) @@ -290,12 +312,92 @@ def layer_norm_no_plugin( target, source_ir, name, - trt.ElementWiseOperation.SUM, - scale_layer, + scaled_y, beta_tensor.get_output(0), ) +def native_group_norm( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + weight: Optional[Union[torch.Tensor, np.ndarray]], + bias: Optional[Union[torch.Tensor, np.ndarray]], + N: int, + C: int, + HxW: int, + group: int, + eps: float, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return group_norm( + ctx, + target, + source_ir, + name, + input, + group, + weight, + bias, + eps, + cudnn_enabled=True, + ) + + +def group_norm( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + num_groups: int, + weight: Optional[Union[torch.Tensor, np.ndarray]], + bias: Optional[Union[torch.Tensor, np.ndarray]], + eps: float, + cudnn_enabled: bool, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + if not isinstance(input, trt.tensorrt.ITensor): + raise RuntimeError( + f"LayerNorm received input {input} that is not part " + "of the TensorRT region!" + ) + + if weight is None: + weight = to_numpy(1.0) + + if bias is None: + bias = to_numpy(0.0) + + scale = get_trt_tensor(network, weight, "scale") + bias = get_trt_tensor(network, bias, "bias") + + eps_field = trt.PluginField( + "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32 + ) + num_groups_filed = trt.PluginField( + "num_groups", np.array(num_groups), trt.PluginFieldType.INT32 + ) + + field_collection = trt.PluginFieldCollection([eps_field, num_groups_filed]) + + try: + # Here's the schema of the plugin: + # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml + plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1") + except AssertionError: + _LOGGER.error( + "Unable to find group norm plugin, fall back to TensorRT implementation." 
+ ) + + layer = network.add_plugin_v2([input, scale, bias], plugin) + set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir) + + # PyTorch requires three return values: (out, mean, rstd) + dummy_tensor = torch.tensor(0) + return layer.get_output(0), dummy_tensor, dummy_tensor + + def softmax( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py index 5cd09a010c..7a980327a2 100644 --- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py +++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py @@ -2,6 +2,7 @@ import operator from typing import Dict, Sequence, Tuple, Union +import torch from torch.fx.node import Argument, Node, Target from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.types import TRTTensor @@ -20,6 +21,7 @@ def getitem_validator(getitem_node: Node) -> bool: # TODO: Subsequent evaluators should be registered here with their own validators @dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.detach.default) # type: ignore[misc] def generic_evaluator( ctx: ConversionContext, target: Target, diff --git a/tests/py/dynamo/conversion/test_group_norm_aten.py b/tests/py/dynamo/conversion/test_group_norm_aten.py new file mode 100644 index 0000000000..718a692773 --- /dev/null +++ b/tests/py/dynamo/conversion/test_group_norm_aten.py @@ -0,0 +1,52 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestGroupNormConverter(DispatchTestCase): + def test_groupnorm(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.gn = torch.nn.GroupNorm(2, 6) + + def forward(self, x): + return self.gn(x) + + inputs = [torch.randn(1, 6, 224, 224)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.native_group_norm.default}, + disable_passes=True, + ) + + def test_groupnorm_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.gn = torch.nn.GroupNorm(2, 6) + + def forward(self, x): + return self.gn(x) + + input_specs = [ + Input( + shape=(-1, 6, 5), + dtype=torch.float32, + shape_ranges=[((2, 6, 5), (6, 6, 5), (10, 6, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten.native_group_norm.default}, + disable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() From 928d17220b96ee966ae4d73a447651fabd5c629a Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 28 Sep 2023 16:53:04 -0700 Subject: [PATCH 4/7] update decomposition_groups --- py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index f1cfaae348..bdeef49168 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -99,15 +99,12 @@ aten.nan_to_num, aten.narrow, # TODO: Disable the below operators once freezing is done - aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit, aten._native_batch_norm_legit_functional, aten._native_batch_norm_legit_no_training, 
aten.native_dropout_backward, - aten.native_group_norm, aten.native_group_norm_backward, - aten.native_layer_norm, aten.native_layer_norm_backward, aten.new_empty, aten.new_full, From d8a7c2dcbba3c6c1819ca2d658bcb6c3d7636a1a Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 29 Sep 2023 19:51:07 -0700 Subject: [PATCH 5/7] update group_norm with native ops --- .../conversion/impl/normalization/ops.py | 267 +++++++++++------- 1 file changed, 163 insertions(+), 104 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index 5f11a436d4..1136c4acb5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -13,8 +13,7 @@ ) from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary from torch_tensorrt.fx.converters.converter_utils import ( - get_trt_plugin, - get_trt_tensor, + get_positive_dim, has_dynamic_shape, set_layer_name, ) @@ -53,10 +52,7 @@ def native_batch_norm( if running_var is None: running_var = 1.0 - scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt( - cast(torch.Tensor, to_numpy(running_var)) + eps - ) - + scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps) bias = to_numpy(bias) - to_numpy(running_mean) * scale power = np.ones_like(scale) @@ -135,78 +131,6 @@ def layer_norm( eps: float, cudnn_enable: bool, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if not isinstance(input, trt.tensorrt.ITensor): - raise RuntimeError( - f"LayerNorm received input {input} that is not part " - "of the TensorRT region!" - ) - - if weight is None: - weight = to_numpy(1.0) - - if bias is None: - bias = to_numpy(0.0) - - gamma = ( - weight.detach().cpu().float().numpy() - if isinstance(weight, torch.Tensor) - else weight - ) - gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32) - beta = ( - bias.detach().cpu().float().numpy() if isinstance(bias, torch.Tensor) else bias - ) - beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32) - eps_field = trt.PluginField( - "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32 - ) - try: - normalized_shape_arr = np.array(normalized_shape, dtype=np.int32) - except TypeError: - _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []") - normalized_shape_arr = np.array([], dtype=np.int32) - - normalized_shape_filed = trt.PluginField( - "normalized_shape", normalized_shape_arr, trt.PluginFieldType.INT32 - ) - field_collection = trt.PluginFieldCollection( - [gamma_field, beta_field, eps_field, normalized_shape_filed] - ) - - try: - if ctx.net.has_implicit_batch_dimension: - plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt") - else: - plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt") - except AssertionError: - _LOGGER.error( - "Unable to find layer norm plugin, fall back to TensorRT implementation." 
- ) - return layer_norm_no_plugin( - ctx, target, source_ir, name, input, normalized_shape, weight, bias, eps - ) - layer = ctx.net.add_plugin_v2([input], plugin) - layer.name = name - return layer.get_output(0) - - -def layer_norm_no_plugin( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, - normalized_shape: List[int], - weight: Optional[Union[torch.Tensor, np.ndarray]], - bias: Optional[Union[torch.Tensor, np.ndarray]], - eps: float, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"LayerNorm received input {input} that is not part " - "of the TensorRT region!" - ) - if weight is None: weight = to_numpy(1.0) @@ -357,45 +281,180 @@ def group_norm( eps: float, cudnn_enabled: bool, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if not isinstance(input, trt.tensorrt.ITensor): - raise RuntimeError( - f"LayerNorm received input {input} that is not part " - "of the TensorRT region!" - ) - if weight is None: weight = to_numpy(1.0) if bias is None: bias = to_numpy(0.0) - scale = get_trt_tensor(network, weight, "scale") - bias = get_trt_tensor(network, bias, "bias") + assert ( + len(input.shape) >= 3 + ), f"The input dimension should not be less than 3, got {len(input.shape)}!" + B, C = input.shape[0], input.shape[1] - eps_field = trt.PluginField( - "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32 + # Groups are a subdivision of the channel dimension. + assert ( + C % num_groups == 0 + ), f"The num of channels ({C}) should be divisible by num_groups ({num_groups})!" + + # Normalize every group. + reshaped_input = impl.shuffle.reshape( + network, + target, + SourceIR.ATEN, + name, + input, + shape=(B * num_groups, -1), ) - num_groups_filed = trt.PluginField( - "num_groups", np.array(num_groups), trt.PluginFieldType.INT32 + dim = ( + len(reshaped_input.shape) - 1 + ) # TODO: PR #2347 supported negtive dimension in reduce, could be -1 + + # E[X] + mean_trt = impl.reduce.mean( + network, + target, + SourceIR.ATEN, + f"{name}_mean", + reshaped_input, + dim=dim, + keepdim=True, ) - field_collection = trt.PluginFieldCollection([eps_field, num_groups_filed]) + # X - E[X] + sub_trt = impl.elementwise.sub( + network, + target, + source_ir, + f"{name}_sub", + reshaped_input, + mean_trt, + ) - try: - # Here's the schema of the plugin: - # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml - plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1") - except AssertionError: - _LOGGER.error( - "Unable to find group norm plugin, fall back to TensorRT implementation." 
- ) + # variance = mean(pow(sub_trt, 2)) + pow_layer = network.add_constant( + (1,) * len(sub_trt.shape), + trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), + ) + pow_layer.name = f"{name}_power" + + pow_var = impl.elementwise.pow( + network, + target, + source_ir, + f"{name}_pow", + sub_trt, + pow_layer.get_output(0), + ) + + var_trt = impl.reduce.mean( + network, + target, + SourceIR.ATEN, + f"{name}_mean_var", + pow_var, + dim=dim, + keepdim=True, + ) + + # sqrt((var + eps)) + eps_layer = network.add_constant( + (1,) * len(reshaped_input.shape), + trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), + ) + eps_layer.name = f"{name}_eps" + + add_trt = impl.elementwise.add( + network, + target, + source_ir, + f"{name}_add", + var_trt, + eps_layer.get_output(0), + ) + sqrt_trt = impl.unary.sqrt( + network, + target, + source_ir, + f"{name}_sqrt", + add_trt, + ) + + # (X - E[X]) / sqrt((var + eps)) + div_trt = impl.elementwise.div( + network, + target, + source_ir, + f"{name}_div", + sub_trt, + sqrt_trt, + ) + + # Apply per-channel scale and bias. + output = impl.shuffle.reshape( + network, + target, + SourceIR.ATEN, + f"{name}_reshape_div", + div_trt, + shape=input.shape, + ) + + weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2) + + reshaped_weight = impl.shuffle.reshape( + network, + target, + SourceIR.ATEN, + f"{name}_reshape_weight", + weight, + shape=weight_bias_shape, + ) + + output = impl.elementwise.mul( + network, + target, + SourceIR.ATEN, + f"{name}_mul_scale", + output, + reshaped_weight, + ) - layer = network.add_plugin_v2([input, scale, bias], plugin) - set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir) + reshaped_bias = impl.shuffle.reshape( + network, + target, + SourceIR.ATEN, + f"{name}_reshape_bias", + bias, + shape=weight_bias_shape, + ) + + add_trt = impl.elementwise.add( + network, + target, + source_ir, + f"{name}_add_bias", + output, + reshaped_bias, + ) + + # TODO: compute the last two return values + # const1_layer = network.add_constant( + # (1,) * len(sqrt_trt.shape), + # trt.Weights(np.ascontiguousarray([1.0], dtype=np.float32)), + # ) + # const1_layer.name = f"{name}_const1" + + # rsqrt_trt = impl.elementwise.div( + # network, + # target, + # source_ir, + # f"{name}_rsqrt", + # const1_layer.get_output(0), + # sqrt_trt, + # ) - # PyTorch requires three return values: (out, mean, rstd) - dummy_tensor = torch.tensor(0) - return layer.get_output(0), dummy_tensor, dummy_tensor + return add_trt, torch.tensor(0), torch.tensor(0) def softmax( From 4cc353d88e61434ba0f5308faf342e369d708d2f Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 6 Oct 2023 19:30:31 -0700 Subject: [PATCH 6/7] rebase and update three norms --- .../dynamo/conversion/aten_ops_converters.py | 87 ++-- .../conversion/impl/normalization/ops.py | 381 ++++++++---------- .../dynamo/conversion/test_batch_norm_aten.py | 148 +++++-- .../dynamo/conversion/test_group_norm_aten.py | 101 +++-- .../dynamo/conversion/test_layer_norm_aten.py | 81 +++- 5 files changed, 451 insertions(+), 347 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 0141ab2da5..938db3e6e0 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -46,31 +46,22 @@ def get_ir(target: Target) -> SourceIR: return SourceIR.UNKNOWN -@dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default) # type: 
ignore[misc] -def aten_ops_native_batch_norm( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.normalization.native_batch_norm( - ctx, - target, - SourceIR.ATEN, - name, - input=args[0], - weight=args[1], - bias=args[2], - running_mean=args[3], - running_var=args[4], - training=args[5], - momentum=args[6], - eps=args[7], +def one_user_validator(node: Node) -> bool: + # Validate only one user, which is a getitem node that accesses the first element in the list + return ( + len(node.users) == 1 + and list(node.users)[0].target == operator.getitem + and list(node.users)[0].args[1] == 0 ) -@dynamo_tensorrt_converter(torch.ops.aten.batch_norm) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default, capability_validator=one_user_validator) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.batch_norm.default) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] def aten_ops_batch_norm( ctx: ConversionContext, target: Target, @@ -91,32 +82,18 @@ def aten_ops_batch_norm( training=args[5], momentum=args[6], eps=args[7], - cudnn_enabled=args[8], - ) - - -@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default) # type: ignore[misc] -def aten_ops_native_layer_norm( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.normalization.native_layer_norm( - ctx, - target, - SourceIR.ATEN, - name, - input=args[0], - normalized_shape=args[1], - weight=args[2], - bias=args[3], - eps=args[4], + cudnn_enabled=args_bounds_check(args, 8, True), + return_mean_rstd=(target == torch.ops.aten.native_batch_norm.default), ) +@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] def aten_ops_layer_norm( ctx: ConversionContext, target: Target, @@ -135,10 +112,16 @@ def aten_ops_layer_norm( bias=args_bounds_check(args, 3), eps=args_bounds_check(args, 4, 1e-05), cudnn_enable=args_bounds_check(args, 5, True), + return_mean_rstd=(target == torch.ops.aten.native_layer_norm.default), ) -@dynamo_tensorrt_converter(torch.ops.aten.native_group_norm.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.native_group_norm.default, capability_validator=one_user_validator) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] def aten_ops_native_group_norm( ctx: ConversionContext, target: Target, @@ -163,6 +146,11 @@ def aten_ops_native_group_norm( @dynamo_tensorrt_converter(torch.ops.aten.group_norm.default) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] def aten_ops_group_norm( ctx: ConversionContext, target: Target, @@ -856,15 +844,6 @@ def aten_ops_prod( ) -def one_user_validator(node: Node) -> bool: - # Validate only one user, which is a getitem node that accesses the first element in the list - return ( - len(node.users) == 1 - and list(node.users)[0].target == operator.getitem - and list(node.users)[0].args[1] == 0 - ) - - @dynamo_tensorrt_converter(torch.ops.aten.max.default) # type: ignore[misc] 
@dynamo_tensorrt_converter(torch.ops.aten.max.dim, capability_validator=one_user_validator) # type: ignore[misc] def aten_ops_max( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index 1136c4acb5..05821f8d90 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -1,29 +1,26 @@ -import logging -from typing import Any, List, Optional, Sequence, Union, cast, Tuple +from typing import Any, List, Optional, Sequence, Tuple, Union, cast import numpy as np import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim, to_numpy -from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( - convert_binary_elementwise, +from torch_tensorrt.dynamo.conversion.converter_utils import ( + get_positive_dim, + get_trt_tensor, + to_numpy, ) -from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary from torch_tensorrt.fx.converters.converter_utils import ( - get_positive_dim, has_dynamic_shape, set_layer_name, ) from torch_tensorrt.fx.types import TRTTensor from torch_tensorrt.fx.utils import get_dynamic_dims -_LOGGER: logging.Logger = logging.getLogger(__name__) - -def native_batch_norm( +def batch_norm( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], @@ -36,7 +33,9 @@ def native_batch_norm( training: bool, momentum: float, eps: float, -) -> Tuple[TRTTensor, TRTTensor, TRTTensor]: + cudnn_enabled: bool, + return_mean_rstd: bool, +) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]: if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm." @@ -62,61 +61,34 @@ def native_batch_norm( assert ( len(get_dynamic_dims(input.shape)) <= 1 ), "BatchNorm1D with more than one dynamic dims is not currently supported." 
- reshape_layer = ctx.net.add_shuffle(input) - if len(input.shape) == 2: - reshape_layer.reshape_dims = (input.shape[0], input.shape[1], 1, 1) - else: # len(input_val.shape) == 3 - reshape_layer.reshape_dims = ( - input.shape[0], - input.shape[1], - input.shape[2], - 1, - ) - set_layer_name(reshape_layer, target, f"{name}_reshape_2d", source_ir) - input = reshape_layer.get_output(0) + new_shape = ( + (input.shape[0], input.shape[1], 1, 1) + if len(input.shape) == 2 + else (input.shape[0], input.shape[1], input.shape[2], 1) + ) + input = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape + ) layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power) - set_layer_name(layer, target, name) + set_layer_name(layer, target, name, source_ir) + output = layer.get_output(0) # For BatchNorm1d, reshape output back to 1d if not ctx.net.has_implicit_batch_dimension and len(output_shape) < 4: - reshape_output_layer = ctx.net.add_shuffle(layer.get_output(0)) - reshape_output_layer.reshape_dims = tuple(output_shape) - set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d", source_ir) - layer = reshape_output_layer - - - # 1 / sqrt((var + eps)) - save_rstd = 1 / (torch.sqrt(running_var + eps)) - - # eps_tensor = ctx.net.add_constant( - # (1,) * len(running_var.shape), - # trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), - # ) - - # eps_tensor = ctx.net.add_constant( - # (1,) * len(input.shape), - # trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), - # ) - - # add_trt = convert_binary_elementwise( - # ctx, - # target, - # source_ir, - # f"{name}_add", - # running_var, - # eps_tensor, - # ) - - # sqrt_trt = convert_unary( - # ctx, - # target, - # source_ir, - # f"{name}_sqrt", - # add_trt, - # ) - - - return layer.get_output(0), running_mean, save_rstd + output = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_1d", + layer.get_output(0), + output_shape, + ) + + if return_mean_rstd: + # return fake mean and rstd for now + return output, None, None + + return output def layer_norm( @@ -130,7 +102,8 @@ def layer_norm( bias: Optional[Union[torch.Tensor, np.ndarray]], eps: float, cudnn_enable: bool, -) -> Union[TRTTensor, Sequence[TRTTensor]]: + return_mean_rstd: bool, +) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]: if weight is None: weight = to_numpy(1.0) @@ -138,108 +111,96 @@ def layer_norm( bias = to_numpy(0.0) shape = weight.shape - broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape gamma = to_numpy(weight).reshape(shape) beta = to_numpy(bias).reshape(shape) dims = list(range(len(input.shape) - len(shape), len(input.shape))) - axes = get_axes_for_reduce_op(dims) # E[x] - mean_expected_layer = ctx.net.add_reduce( - input, trt.ReduceOperation.AVG, axes, keep_dims=True + mean_expected_trt = impl.reduce.mean( + ctx, target, source_ir, f"{name}_mean_expected", input, dims, True ) - set_layer_name(mean_expected_layer, target, f"{name}_mean_expected", source_ir) # X-E[x] - sub_trt = convert_binary_elementwise( + sub_trt = impl.elementwise.sub( ctx, target, source_ir, f"{name}_sub", input, - mean_expected_layer.get_output(0), - ) - # Variance = mean(pow(x_sub_mean,2)) - pow_tensor = ctx.net.add_constant( - (1,) * len(input.shape), - trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), + mean_expected_trt, ) - pow_tensor.name = f"{name}_power" - pow_var = convert_binary_elementwise( + + # Variance = mean(pow(x_sub_mean, 2)) + pow_trt = get_trt_tensor(ctx, 2, 
f"{name}_power", np.float32) + pow_var = impl.elementwise.pow( ctx, target, source_ir, f"{name}_pow_var", sub_trt, - pow_tensor.get_output(0), - ) - mean_trt_layer = ctx.net.add_reduce( - pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True + pow_trt, ) - set_layer_name(mean_trt_layer, target, f"{name}_mean", source_ir) - # Variance + eps - eps_tensor = ctx.net.add_constant( - (1,) * len(input.shape), - trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), + mean_trt = impl.reduce.mean( + ctx, target, source_ir, f"{name}_mean", pow_var, dims, True ) - eps_tensor.name = f"{name}_eps" - add_trt = convert_binary_elementwise( + + # sqrt((var + eps)) + eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32) + add_trt = impl.elementwise.add( ctx, target, source_ir, f"{name}_add", - mean_trt_layer.get_output(0), - eps_tensor.get_output(0), + mean_trt, + eps_trt, ) - # SQRT((Var + eps)) - sqrt_trt = convert_unary( + sqrt_trt = impl.unary.sqrt( ctx, target, source_ir, f"{name}_sqrt", add_trt, ) - # (x - E[x]) / sqrt((var + eps)) - div_trt = convert_binary_elementwise( + + # (X - E[X]) / sqrt((var + eps)) + div_trt = impl.elementwise.div( ctx, target, source_ir, - f"{name}_div_trt", + f"{name}_div", sub_trt, sqrt_trt, ) - assert gamma is not None - gamma_tensor = ctx.net.add_constant( - gamma.shape, trt.Weights(np.ascontiguousarray(gamma)) - ) - gamma_tensor.name = f"{name}_gamma" - - assert beta is not None - beta_tensor = ctx.net.add_constant( - gamma.shape, trt.Weights(np.ascontiguousarray(beta)) - ) - beta_tensor.name = f"{name}_beta" + gamma_trt = get_trt_tensor(ctx, weight, f"{name}_gamma") + beta_trt = get_trt_tensor(ctx, bias, f"{name}_beta") # y * gamma + beta - scale_layer = convert_binary_elementwise( + scaled_y = impl.elementwise.mul( ctx, target, source_ir, - f"{name}_scale", + f"{name}_mul_gamma", div_trt, - gamma_tensor.get_output(0), + gamma_trt, ) - return convert_binary_elementwise( + + output = impl.elementwise.add( ctx, target, source_ir, - name, + f"{name}_add_beta", scaled_y, - beta_tensor.get_output(0), + beta_trt, ) + if return_mean_rstd: + # return fake mean and rstd for now + return output, None, None + + return output + def native_group_norm( ctx: ConversionContext, @@ -254,39 +215,8 @@ def native_group_norm( HxW: int, group: int, eps: float, + return_mean_rstd: bool = True, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return group_norm( - ctx, - target, - source_ir, - name, - input, - group, - weight, - bias, - eps, - cudnn_enabled=True, - ) - - -def group_norm( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, - num_groups: int, - weight: Optional[Union[torch.Tensor, np.ndarray]], - bias: Optional[Union[torch.Tensor, np.ndarray]], - eps: float, - cudnn_enabled: bool, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - if weight is None: - weight = to_numpy(1.0) - - if bias is None: - bias = to_numpy(0.0) - assert ( len(input.shape) >= 3 ), f"The input dimension should not be less than 3, got {len(input.shape)}!" @@ -294,36 +224,41 @@ def group_norm( # Groups are a subdivision of the channel dimension. assert ( - C % num_groups == 0 - ), f"The num of channels ({C}) should be divisible by num_groups ({num_groups})!" + C % group == 0 + ), f"The num of channels ({C}) should be divisible by num_groups ({group})!" + + if weight is None: + weight = to_numpy(1.0) + + if bias is None: + bias = to_numpy(0.0) # Normalize every group. 
reshaped_input = impl.shuffle.reshape( - network, + ctx, target, - SourceIR.ATEN, + source_ir, name, input, - shape=(B * num_groups, -1), + (B * group, -1), ) - dim = ( - len(reshaped_input.shape) - 1 - ) # TODO: PR #2347 supported negtive dimension in reduce, could be -1 + + dim = 1 # E[X] mean_trt = impl.reduce.mean( - network, + ctx, target, - SourceIR.ATEN, + source_ir, f"{name}_mean", reshaped_input, - dim=dim, - keepdim=True, + dim, + True, ) # X - E[X] sub_trt = impl.elementwise.sub( - network, + ctx, target, source_ir, f"{name}_sub", @@ -332,57 +267,47 @@ def group_norm( ) # variance = mean(pow(sub_trt, 2)) - pow_layer = network.add_constant( - (1,) * len(sub_trt.shape), - trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), - ) - pow_layer.name = f"{name}_power" - + pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32) pow_var = impl.elementwise.pow( - network, + ctx, target, source_ir, f"{name}_pow", sub_trt, - pow_layer.get_output(0), + pow_trt, ) var_trt = impl.reduce.mean( - network, + ctx, target, - SourceIR.ATEN, + source_ir, f"{name}_mean_var", pow_var, - dim=dim, - keepdim=True, + dim, + True, ) # sqrt((var + eps)) - eps_layer = network.add_constant( - (1,) * len(reshaped_input.shape), - trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), - ) - eps_layer.name = f"{name}_eps" - + eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32) add_trt = impl.elementwise.add( - network, + ctx, target, source_ir, f"{name}_add", var_trt, - eps_layer.get_output(0), + eps_trt, ) sqrt_trt = impl.unary.sqrt( - network, + ctx, target, source_ir, f"{name}_sqrt", add_trt, ) - # (X - E[X]) / sqrt((var + eps)) + # y = (X - E[X]) / sqrt((var + eps)) div_trt = impl.elementwise.div( - network, + ctx, target, source_ir, f"{name}_div", @@ -390,71 +315,91 @@ def group_norm( sqrt_trt, ) - # Apply per-channel scale and bias. 
+ # y * gamma + beta + gamma_trt = get_trt_tensor(ctx, weight, f"{name}_gamma") + beta_trt = get_trt_tensor(ctx, bias, f"{name}_beta") + output = impl.shuffle.reshape( - network, + ctx, target, - SourceIR.ATEN, + source_ir, f"{name}_reshape_div", div_trt, - shape=input.shape, + input.shape, ) weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2) - reshaped_weight = impl.shuffle.reshape( - network, + reshaped_gamma = impl.shuffle.reshape( + ctx, target, - SourceIR.ATEN, - f"{name}_reshape_weight", - weight, - shape=weight_bias_shape, + source_ir, + f"{name}_reshape_gamma", + gamma_trt, + weight_bias_shape, ) output = impl.elementwise.mul( - network, + ctx, target, - SourceIR.ATEN, - f"{name}_mul_scale", + source_ir, + f"{name}_mul_gamma", output, - reshaped_weight, + reshaped_gamma, ) reshaped_bias = impl.shuffle.reshape( - network, + ctx, target, - SourceIR.ATEN, - f"{name}_reshape_bias", - bias, - shape=weight_bias_shape, + source_ir, + f"{name}_reshape_beta", + beta_trt, + weight_bias_shape, ) - add_trt = impl.elementwise.add( - network, + output = impl.elementwise.add( + ctx, target, source_ir, - f"{name}_add_bias", + f"{name}_add_beta", output, reshaped_bias, ) - # TODO: compute the last two return values - # const1_layer = network.add_constant( - # (1,) * len(sqrt_trt.shape), - # trt.Weights(np.ascontiguousarray([1.0], dtype=np.float32)), - # ) - # const1_layer.name = f"{name}_const1" + if return_mean_rstd: + # return fake mean and rstd for now + return output, None, None - # rsqrt_trt = impl.elementwise.div( - # network, - # target, - # source_ir, - # f"{name}_rsqrt", - # const1_layer.get_output(0), - # sqrt_trt, - # ) + return output - return add_trt, torch.tensor(0), torch.tensor(0) + +def group_norm( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + num_groups: int, + weight: Optional[Union[torch.Tensor, np.ndarray]], + bias: Optional[Union[torch.Tensor, np.ndarray]], + eps: float, + cudnn_enabled: bool, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return native_group_norm( + ctx, + target, + source_ir, + name, + input, + weight, + bias, + 0, + 0, + 0, + num_groups, + eps, + return_mean_rstd=False, + ) def softmax( @@ -493,5 +438,5 @@ def get_softmax_dim(ndim: int) -> int: layer = ctx.net.add_softmax(input) layer.axes = 1 << dim - set_layer_name(layer, target, name) + set_layer_name(layer, target, name, source_ir) return layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_batch_norm_aten.py b/tests/py/dynamo/conversion/test_batch_norm_aten.py index cb946bcc40..680e2264d1 100644 --- a/tests/py/dynamo/conversion/test_batch_norm_aten.py +++ b/tests/py/dynamo/conversion/test_batch_norm_aten.py @@ -1,35 +1,48 @@ -import unittest - import torch from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input from .harness import DispatchTestCase +FEATURE_NUM = 3 + class TestBatchNormConverter(DispatchTestCase): - @unittest.skip("Pending ongoing work on batchnorm converter in Dynamo") def test_batchnorm(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm2d(3) - + class BatchNorm(torch.nn.Module): def forward(self, x): - return self.bn(x) + return torch.ops.aten.batch_norm.default( + x, + torch.ones((FEATURE_NUM,)), + torch.zeros((FEATURE_NUM,)), + torch.zeros((FEATURE_NUM,)), + torch.ones((FEATURE_NUM,)), + False, + 0.1, + 1e-05, + True, + ) inputs = [torch.randn(1, 3, 224, 224)] - self.run_test(TestModule(), 
inputs, expected_ops={torch.ops.aten.batch_norm}) + self.run_test( + BatchNorm(), + inputs, + ) - @unittest.skip("Pending ongoing work on batchnorm converter in Dynamo") def test_batchnorm1d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm1d(3) - + class BatchNorm(torch.nn.Module): def forward(self, x): - return self.bn(x) + return torch.ops.aten.batch_norm.default( + x, + torch.ones((FEATURE_NUM,)), + torch.zeros((FEATURE_NUM,)), + torch.zeros((FEATURE_NUM,)), + torch.ones((FEATURE_NUM,)), + False, + 0.1, + 1e-05, + True, + ) input_specs = [ Input( @@ -40,18 +53,24 @@ def forward(self, x): ] self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + BatchNorm(), + input_specs, ) - @unittest.skip("Pending ongoing work on batchnorm converter in Dynamo") def test_batchnorm_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm2d(3) - + class BatchNorm(torch.nn.Module): def forward(self, x): - return self.bn(x) + return torch.ops.aten.batch_norm.default( + x, + torch.ones((FEATURE_NUM,)), + torch.zeros((FEATURE_NUM,)), + torch.zeros((FEATURE_NUM,)), + torch.ones((FEATURE_NUM,)), + False, + 0.1, + 1e-05, + True, + ) input_specs = [ Input( @@ -62,10 +81,85 @@ def forward(self, x): ] self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + BatchNorm(), + input_specs, + ) + + +class TestNativeBatchNormConverter(DispatchTestCase): + def test_batchnorm(self): + class BatchNorm(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.native_batch_norm.default( + x, + torch.ones((FEATURE_NUM,)), + torch.zeros((FEATURE_NUM,)), + torch.zeros((FEATURE_NUM,)), + torch.ones((FEATURE_NUM,)), + False, + 0.1, + 1e-05, + )[0] + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test( + BatchNorm(), + inputs, + ) + + def test_batchnorm1d_with_dynamic_shape(self): + class BatchNorm(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.native_batch_norm.default( + x, + torch.ones((FEATURE_NUM,)), + torch.zeros((FEATURE_NUM,)), + torch.zeros((FEATURE_NUM,)), + torch.ones((FEATURE_NUM,)), + False, + 0.1, + 1e-05, + )[0] + + input_specs = [ + Input( + shape=(-1, 3, 5), + dtype=torch.float32, + shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + BatchNorm(), + input_specs, ) - # Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm. 
+ def test_batchnorm_with_dynamic_shape(self): + class BatchNorm(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.native_batch_norm.default( + x, + torch.ones((FEATURE_NUM,)), + torch.zeros((FEATURE_NUM,)), + torch.zeros((FEATURE_NUM,)), + torch.ones((FEATURE_NUM,)), + False, + 0.1, + 1e-05, + )[0] + + input_specs = [ + Input( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + BatchNorm(), + input_specs, + ) if __name__ == "__main__": diff --git a/tests/py/dynamo/conversion/test_group_norm_aten.py b/tests/py/dynamo/conversion/test_group_norm_aten.py index 718a692773..1e3fec058d 100644 --- a/tests/py/dynamo/conversion/test_group_norm_aten.py +++ b/tests/py/dynamo/conversion/test_group_norm_aten.py @@ -6,47 +6,86 @@ class TestGroupNormConverter(DispatchTestCase): - def test_groupnorm(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.gn = torch.nn.GroupNorm(2, 6) - + def test_groupnorm1d(self): + class GroupNorm(torch.nn.Module): def forward(self, x): - return self.gn(x) + return torch.ops.aten.group_norm.default( + x, + 2, + torch.ones((6,)), + torch.zeros((6,)), + 1e-05, + True, + ) - inputs = [torch.randn(1, 6, 224, 224)] + inputs = [torch.randn(3, 6, 224)] self.run_test( - TestModule(), + GroupNorm(), inputs, - expected_ops={torch.ops.aten.native_group_norm.default}, - disable_passes=True, ) - def test_groupnorm_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.gn = torch.nn.GroupNorm(2, 6) + def test_groupnorm2d(self): + class GroupNorm(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.group_norm.default( + x, + 2, + torch.ones((6,)), + torch.zeros((6,)), + 1e-05, + True, + ) + + inputs = [torch.randn(3, 6, 224, 224)] + with torch.no_grad(): + self.run_test( + GroupNorm(), + inputs, + ) + +class TestNativeGroupNormConverter(DispatchTestCase): + def test_groupnorm1d(self): + class GroupNorm(torch.nn.Module): def forward(self, x): - return self.gn(x) - - input_specs = [ - Input( - shape=(-1, 6, 5), - dtype=torch.float32, - shape_ranges=[((2, 6, 5), (6, 6, 5), (10, 6, 5))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), - input_specs, - expected_ops={torch.ops.aten.native_group_norm.default}, - disable_passes=True, + return torch.ops.aten.native_group_norm.default( + x, + torch.ones((6,)), + torch.zeros((6,)), + 3, + 6, + 224, + 2, + 1e-05, + )[0] + + inputs = [torch.randn(3, 6, 224)] + self.run_test( + GroupNorm(), + inputs, ) + def test_groupnorm2d(self): + class GroupNorm(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.native_group_norm.default( + x, + torch.ones((6,)), + torch.zeros((6,)), + 3, + 6, + 224 * 224, + 2, + 1e-05, + )[0] + + inputs = [torch.randn(3, 6, 224, 224)] + with torch.no_grad(): + self.run_test( + GroupNorm(), + inputs, + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_layer_norm_aten.py b/tests/py/dynamo/conversion/test_layer_norm_aten.py index 0cc374e307..8013768214 100644 --- a/tests/py/dynamo/conversion/test_layer_norm_aten.py +++ b/tests/py/dynamo/conversion/test_layer_norm_aten.py @@ -1,5 +1,3 @@ -import unittest - import torch from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input @@ -8,30 +6,78 @@ class TestLayerNormConverter(DispatchTestCase): - @unittest.skip("Pending ongoing work on layernorm converter 
in Dynamo") def test_layer_norm(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.ln = torch.nn.LayerNorm([3, 224, 224]) - + class LayerNorm(torch.nn.Module): def forward(self, x): - return self.ln(x) + return torch.ops.aten.layer_norm.default( + x, + torch.tensor([3, 224, 224]), + torch.ones((3, 224, 224)), + torch.zeros((3, 224, 224)), + 1e-05, + True, + ) inputs = [torch.randn(1, 3, 224, 224)] self.run_test( - TestModule(), inputs, expected_ops={torch.ops.aten.layer_norm.default} + LayerNorm(), + inputs, ) - @unittest.skip("Pending ongoing work on layernorm converter in Dynamo") def test_layernorm_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.ln = torch.nn.LayerNorm([3, 224, 224]) + class LayerNorm(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.layer_norm.default( + x, + torch.tensor([3, 224, 224]), + torch.ones((3, 224, 224)), + torch.zeros((3, 224, 224)), + 1e-05, + True, + ) + + input_specs = [ + Input( + shape=(-1, 3, 224, 224), + dtype=torch.float32, + shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))], + ), + ] + self.run_test_with_dynamic_shape( + LayerNorm(), + input_specs, + ) + + +class TestNativeLayerNormConverter(DispatchTestCase): + def test_layer_norm(self): + class LayerNorm(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.native_layer_norm.default( + x, + torch.tensor([3, 224, 224]), + torch.ones((3, 224, 224)), + torch.zeros((3, 224, 224)), + 1e-05, + )[0] + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test( + LayerNorm(), + inputs, + ) + + def test_layernorm_with_dynamic_shape(self): + class LayerNorm(torch.nn.Module): def forward(self, x): - return self.ln(x) + return torch.ops.aten.native_layer_norm.default( + x, + torch.tensor([3, 224, 224]), + torch.ones((3, 224, 224)), + torch.zeros((3, 224, 224)), + 1e-05, + )[0] input_specs = [ Input( @@ -42,7 +88,8 @@ def forward(self, x): ] self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={torch.ops.aten.layer_norm.default} + LayerNorm(), + input_specs, ) From 84b58dd73e22527cab52a70cd14840ec8337770b Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 9 Oct 2023 10:56:06 -0700 Subject: [PATCH 7/7] add decorators --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 938db3e6e0..a0d2ce1fde 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -57,6 +57,7 @@ def one_user_validator(node: Node) -> bool: @dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default, capability_validator=one_user_validator) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.batch_norm.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.batch_norm) # type: ignore[misc] @enforce_tensor_types( { 0: (TRTTensor,), @@ -89,6 +90,7 @@ def aten_ops_batch_norm( @dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.layer_norm) # type: ignore[misc] @enforce_tensor_types( { 0: (TRTTensor,), @@ -146,6 +148,7 @@ def aten_ops_native_group_norm( 
@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.group_norm) # type: ignore[misc] @enforce_tensor_types( { 0: (TRTTensor,),
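The plugin-free `native_group_norm` converter added in patches 5-6 above lowers group normalization to elementwise/reduce/shuffle TensorRT layers that compute y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta per group, after reshaping the input to (B * num_groups, -1). The pure-PyTorch sketch below mirrors that same decomposition for comparison against torch.nn.functional.group_norm; the helper name `group_norm_reference` and the tolerance check are illustrative assumptions and are not part of the patch series.

import torch

def group_norm_reference(x, num_groups, gamma, beta, eps=1e-5):
    # x: (N, C, *spatial); gamma/beta: (C,). Illustrative helper, not from the patch.
    n, c = x.shape[0], x.shape[1]
    # Collapse to (N * num_groups, -1), mirroring the converter's reshape.
    grouped = x.reshape(n * num_groups, -1)
    mean = grouped.mean(dim=1, keepdim=True)                 # E[X]
    var = ((grouped - mean) ** 2).mean(dim=1, keepdim=True)  # mean(pow(X - E[X], 2))
    normalized = (grouped - mean) / torch.sqrt(var + eps)    # (X - E[X]) / sqrt(var + eps)
    normalized = normalized.reshape(x.shape)
    # Per-channel scale and bias, broadcast as (1, C, 1, ..., 1).
    shape = (1, c) + (1,) * (x.dim() - 2)
    return normalized * gamma.reshape(shape) + beta.reshape(shape)

x = torch.randn(3, 6, 8, 8)
gamma, beta = torch.randn(6), torch.randn(6)
out = group_norm_reference(x, num_groups=2, gamma=gamma, beta=beta)
assert torch.allclose(
    out, torch.nn.functional.group_norm(x, 2, gamma, beta, eps=1e-5), atol=1e-4
)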