diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 873f531b71..a0d2ce1fde 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -46,7 +46,23 @@ def get_ir(target: Target) -> SourceIR:
     return SourceIR.UNKNOWN


+def one_user_validator(node: Node) -> bool:
+    # Validate only one user, which is a getitem node that accesses the first element in the list
+    return (
+        len(node.users) == 1
+        and list(node.users)[0].target == operator.getitem
+        and list(node.users)[0].args[1] == 0
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default, capability_validator=one_user_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.batch_norm.default)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.batch_norm)  # type: ignore[misc]
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)  # type: ignore[misc]
 def aten_ops_batch_norm(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
@@ -59,14 +75,103 @@ def aten_ops_batch_norm(
         target,
         SourceIR.ATEN,
         name,
-        args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4],
-        args[5],
-        args[6],
-        args[7],
+        input=args[0],
+        weight=args[1],
+        bias=args[2],
+        running_mean=args[3],
+        running_var=args[4],
+        training=args[5],
+        momentum=args[6],
+        eps=args[7],
+        cudnn_enabled=args_bounds_check(args, 8, True),
+        return_mean_rstd=(target == torch.ops.aten.native_batch_norm.default),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.layer_norm)  # type: ignore[misc]
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)  # type: ignore[misc]
+def aten_ops_layer_norm(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.layer_norm(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        normalized_shape=args[1],
+        weight=args_bounds_check(args, 2),
+        bias=args_bounds_check(args, 3),
+        eps=args_bounds_check(args, 4, 1e-05),
+        cudnn_enable=args_bounds_check(args, 5, True),
+        return_mean_rstd=(target == torch.ops.aten.native_layer_norm.default),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.native_group_norm.default, capability_validator=one_user_validator)  # type: ignore[misc]
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)  # type: ignore[misc]
+def aten_ops_native_group_norm(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.native_group_norm(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        weight=args[1],
+        bias=args[2],
+        N=args[3],
+        C=args[4],
+        HxW=args[5],
+        group=args[6],
+        eps=args[7],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.group_norm)  # type: ignore[misc]
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)  # type: ignore[misc]
+def aten_ops_group_norm(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.group_norm(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        num_groups=args[1],
+        weight=args_bounds_check(args, 2, None),
+        bias=args_bounds_check(args, 3, None),
+        eps=args_bounds_check(args, 4, 1e-05),
+        cudnn_enabled=args_bounds_check(args, 5, True),
     )


@@ -328,27 +433,6 @@ def aten_ops_matmul(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)  # type: ignore[misc]
-def aten_ops_layernorm(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.normalization.layer_norm(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4],
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default)  # type: ignore[misc]
 def aten_ops_rsqrt(
     ctx: ConversionContext,
     target: Target,
@@ -763,15 +847,6 @@ def aten_ops_prod(
     )


-def one_user_validator(node: Node) -> bool:
-    # Validate only one user, which is a getitem node that accesses the first element in the list
-    return (
-        len(node.users) == 1
-        and list(node.users)[0].target == operator.getitem
-        and list(node.users)[0].args[1] == 0
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.max.default)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.max.dim, capability_validator=one_user_validator)  # type: ignore[misc]
 def aten_ops_max(
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index 81bd88cd4f..05821f8d90 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -1,27 +1,24 @@
-import logging
-from typing import Any, List, Optional, Sequence, Union, cast
+from typing import Any, List, Optional, Sequence, Tuple, Union, cast

 import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim, to_numpy
-from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
-    convert_binary_elementwise,
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
+    get_trt_tensor,
+    to_numpy,
 )
-from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
 from torch_tensorrt.fx.converters.converter_utils import (
-    get_trt_plugin,
     has_dynamic_shape,
     set_layer_name,
 )
 from torch_tensorrt.fx.types import TRTTensor
 from torch_tensorrt.fx.utils import get_dynamic_dims

-_LOGGER: logging.Logger = logging.getLogger(__name__)
-

 def batch_norm(
     ctx: ConversionContext,
@@ -29,27 +26,32 @@ def batch_norm(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    weight: torch.Tensor,
-    bias: torch.Tensor,
-    running_mean: torch.Tensor,
-    running_var: torch.Tensor,
-    training: torch.Tensor,
-    momentum: torch.Tensor,
-    eps: List[float],
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"BatchNorm2d received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
+    weight: Optional[Union[torch.Tensor, np.ndarray]],
+    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    running_mean: Optional[Union[torch.Tensor, np.ndarray]],
+    running_var: Optional[Union[torch.Tensor, np.ndarray]],
+    training: bool,
+    momentum: float,
+    eps: float,
+    cudnn_enabled: bool,
+    return_mean_rstd: bool,
+) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."

-    scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
-        cast(torch.Tensor, to_numpy(running_var)) + cast(float, eps)
-    )
+    if weight is None:
+        weight = 1.0
+
+    if bias is None:
+        bias = 0.0
+
+    if running_mean is None:
+        running_mean = 0.0
+
+    if running_var is None:
+        running_var = 1.0
+
+    scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
     bias = to_numpy(bias) - to_numpy(running_mean) * scale
     power = np.ones_like(scale)

@@ -59,28 +61,34 @@
         assert (
             len(get_dynamic_dims(input.shape)) <= 1
         ), "BatchNorm1D with more than one dynamic dims is not currently supported."
-        reshape_layer = ctx.net.add_shuffle(input)
-        if len(input.shape) == 2:
-            reshape_layer.reshape_dims = (input.shape[0], input.shape[1], 1, 1)
-        else:  # len(input_val.shape) == 3
-            reshape_layer.reshape_dims = (
-                input.shape[0],
-                input.shape[1],
-                input.shape[2],
-                1,
-            )
-        set_layer_name(reshape_layer, target, f"{name}_reshape_2d")
-        input = reshape_layer.get_output(0)
+        new_shape = (
+            (input.shape[0], input.shape[1], 1, 1)
+            if len(input.shape) == 2
+            else (input.shape[0], input.shape[1], input.shape[2], 1)
+        )
+        input = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
+        )
     layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
-    set_layer_name(layer, target, name)
+    set_layer_name(layer, target, name, source_ir)
+    output = layer.get_output(0)

     # For BatchNorm1d, reshape output back to 1d
     if not ctx.net.has_implicit_batch_dimension and len(output_shape) < 4:
-        reshape_output_layer = ctx.net.add_shuffle(layer.get_output(0))
-        reshape_output_layer.reshape_dims = tuple(output_shape)
-        set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d")
-        layer = reshape_output_layer
-    return layer.get_output(0)
+        output = impl.shuffle.reshape(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_reshape_1d",
+            layer.get_output(0),
+            output_shape,
+        )
+
+    if return_mean_rstd:
+        # return fake mean and rstd for now
+        return output, None, None
+
+    return output


 def layer_norm(
@@ -90,183 +98,307 @@
     name: str,
     input: TRTTensor,
     normalized_shape: List[int],
-    weight: torch.Tensor,
-    bias: torch.Tensor,
-    eps: List[float],
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
+    weight: Optional[Union[torch.Tensor, np.ndarray]],
+    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    eps: float,
+    cudnn_enable: bool,
+    return_mean_rstd: bool,
+) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
+    if weight is None:
+        weight = to_numpy(1.0)

-    gamma = (
-        weight.detach().cpu().float().numpy()
-        if isinstance(weight, torch.Tensor)
-        else weight
+    if bias is None:
+        bias = to_numpy(0.0)
+
+    shape = weight.shape
+    gamma = to_numpy(weight).reshape(shape)
+    beta = to_numpy(bias).reshape(shape)
+
+    dims = list(range(len(input.shape) - len(shape), len(input.shape)))
+
+    # E[x]
+    mean_expected_trt = impl.reduce.mean(
+        ctx, target, source_ir, f"{name}_mean_expected", input, dims, True
     )
-    gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
-    beta = (
-        bias.detach().cpu().float().numpy() if isinstance(bias, torch.Tensor) else bias
+
+    # X-E[x]
+    sub_trt = impl.elementwise.sub(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sub",
+        input,
+        mean_expected_trt,
     )
-    beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
-    eps_field = trt.PluginField(
-        "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
+
+    # Variance = mean(pow(x_sub_mean, 2))
+    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
+    pow_var = impl.elementwise.pow(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_pow_var",
+        sub_trt,
+        pow_trt,
     )
-    try:
-        normalized_shape_arr = np.array(normalized_shape, dtype=np.int32)
-    except TypeError:
-        _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
-        normalized_shape_arr = np.array([], dtype=np.int32)
-
-    normalized_shape_filed = trt.PluginField(
-        "normalized_shape", normalized_shape_arr, trt.PluginFieldType.INT32
+    mean_trt = impl.reduce.mean(
+        ctx, target, source_ir, f"{name}_mean", pow_var, dims, True
+    )
+
+    # sqrt((var + eps))
+    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
+    add_trt = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add",
+        mean_trt,
+        eps_trt,
     )
-    field_collection = trt.PluginFieldCollection(
-        [gamma_field, beta_field, eps_field, normalized_shape_filed]
+    sqrt_trt = impl.unary.sqrt(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
     )
-    try:
-        if ctx.net.has_implicit_batch_dimension:
-            plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt")
-        else:
-            plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
-    except AssertionError:
-        _LOGGER.error(
-            "Unable to find layer norm plugin, fall back to TensorRT implementation."
-        )
-        return layer_norm_no_plugin(
-            ctx, target, source_ir, name, input, normalized_shape, weight, bias, eps
-        )
-    layer = ctx.net.add_plugin_v2([input], plugin)
-    layer.name = name
-    return layer.get_output(0)
+    # (X - E[X]) / sqrt((var + eps))
+    div_trt = impl.elementwise.div(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
+    )

+    gamma_trt = get_trt_tensor(ctx, weight, f"{name}_gamma")
+    beta_trt = get_trt_tensor(ctx, bias, f"{name}_beta")

-def layer_norm_no_plugin(
+    # y * gamma + beta
+    scaled_y = impl.elementwise.mul(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mul_gamma",
+        div_trt,
+        gamma_trt,
+    )
+
+    output = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add_beta",
+        scaled_y,
+        beta_trt,
+    )
+
+    if return_mean_rstd:
+        # return fake mean and rstd for now
+        return output, None, None
+
+    return output
+
+
+def native_group_norm(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    normalized_shape: List[int],
-    weight: torch.Tensor,
-    bias: torch.Tensor,
-    eps: List[float],
+    weight: Optional[Union[torch.Tensor, np.ndarray]],
+    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    N: int,
+    C: int,
+    HxW: int,
+    group: int,
+    eps: float,
+    return_mean_rstd: bool = True,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
+    assert (
+        len(input.shape) >= 3
+    ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+    B, C = input.shape[0], input.shape[1]

-    shape = weight.shape
-    broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape
-    gamma = to_numpy(weight.reshape(*shape))
-    beta = to_numpy(bias.reshape(*shape))
+    # Groups are a subdivision of the channel dimension.
+    assert (
+        C % group == 0
+    ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"

-    axes = 0
-    for d in range(len(shape)):
-        axes |= 1 << (len(input.shape) - d - 1)
+    if weight is None:
+        weight = to_numpy(1.0)

-    # E[x]
-    mean_expected_layer = ctx.net.add_reduce(
-        input, trt.ReduceOperation.AVG, axes, keep_dims=True
-    )
-    set_layer_name(mean_expected_layer, target, f"{name}_mean_expected")
+    if bias is None:
+        bias = to_numpy(0.0)

-    # X-E[x]
-    sub_trt = convert_binary_elementwise(
+    # Normalize every group.
+    reshaped_input = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_sub",
-        trt.ElementWiseOperation.SUB,
+        name,
         input,
-        mean_expected_layer.get_output(0),
+        (B * group, -1),
     )
-    # Variance = mean(pow(x_sub_mean,2))
-    pow_tensor = ctx.net.add_constant(
-        (1,) * len(input.shape),
-        trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
+
+    dim = 1
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean",
+        reshaped_input,
+        dim,
+        True,
     )
-    pow_tensor.name = f"{name}_power"
-    pow_var = convert_binary_elementwise(
+
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
         ctx,
         target,
         source_ir,
-        f"{name}_pow_var",
-        trt.ElementWiseOperation.POW,
-        sub_trt,
-        pow_tensor.get_output(0),
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
     )
-    mean_trt_layer = ctx.net.add_reduce(
-        pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True
+
+    # variance = mean(pow(sub_trt, 2))
+    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
+    pow_var = impl.elementwise.pow(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_trt,
     )
-    set_layer_name(mean_trt_layer, target, f"{name}_mean")
-    # Variance + eps
-    eps_tensor = ctx.net.add_constant(
-        (1,) * len(input.shape),
-        trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
+
+    var_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean_var",
+        pow_var,
+        dim,
+        True,
     )
-    eps_tensor.name = f"{name}_eps"
-    add_trt = convert_binary_elementwise(
+
+    # sqrt((var + eps))
+    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
+    add_trt = impl.elementwise.add(
         ctx,
         target,
         source_ir,
         f"{name}_add",
-        trt.ElementWiseOperation.SUM,
-        mean_trt_layer.get_output(0),
-        eps_tensor.get_output(0),
+        var_trt,
+        eps_trt,
     )
-    # SQRT((Var + eps))
-    sqrt_trt = convert_unary(
+    sqrt_trt = impl.unary.sqrt(
         ctx,
         target,
         source_ir,
         f"{name}_sqrt",
-        trt.UnaryOperation.SQRT,
         add_trt,
     )
-    # (x - E[x]) / sqrt((var + eps))
-    div_trt = convert_binary_elementwise(
+
+    # y = (X - E[X]) / sqrt((var + eps))
+    div_trt = impl.elementwise.div(
         ctx,
         target,
         source_ir,
-        f"{name}_div_trt",
-        trt.ElementWiseOperation.DIV,
+        f"{name}_div",
         sub_trt,
         sqrt_trt,
     )

-    assert gamma is not None
-    gamma_tensor = ctx.net.add_constant(
-        gamma.shape, trt.Weights(np.ascontiguousarray(gamma))
-    )
-    gamma_tensor.name = f"{name}_gamma"
-    assert beta is not None
-    beta_tensor = ctx.net.add_constant(
-        gamma.shape, trt.Weights(np.ascontiguousarray(beta))
-    )
-    beta_tensor.name = f"{name}_beta"
-
     # y * gamma + beta
-    scale_layer = convert_binary_elementwise(
+    gamma_trt = get_trt_tensor(ctx, weight, f"{name}_gamma")
+    beta_trt = get_trt_tensor(ctx, bias, f"{name}_beta")
+
+    output = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_scale",
-        trt.ElementWiseOperation.PROD,
+        f"{name}_reshape_div",
         div_trt,
-        gamma_tensor.get_output(0),
+        input.shape,
+    )
+
+    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
+
+    reshaped_gamma = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_gamma",
+        gamma_trt,
+        weight_bias_shape,
     )
-    return convert_binary_elementwise(
+
+    output = impl.elementwise.mul(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mul_gamma",
+        output,
+        reshaped_gamma,
+    )
+
+    reshaped_bias = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_beta",
+        beta_trt,
+        weight_bias_shape,
+    )
+
+    output = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add_beta",
+        output,
+        reshaped_bias,
+    )
+
+    if return_mean_rstd:
+        # return fake mean and rstd for now
+        return output, None, None
+
+    return output
+
+
+def group_norm(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    num_groups: int,
+    weight: Optional[Union[torch.Tensor, np.ndarray]],
+    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    eps: float,
+    cudnn_enabled: bool,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return native_group_norm(
         ctx,
         target,
         source_ir,
         name,
-        trt.ElementWiseOperation.SUM,
-        scale_layer,
-        beta_tensor.get_output(0),
+        input,
+        weight,
+        bias,
+        0,
+        0,
+        0,
+        num_groups,
+        eps,
+        return_mean_rstd=False,
     )
@@ -306,5 +438,5 @@ def get_softmax_dim(ndim: int) -> int:

     layer = ctx.net.add_softmax(input)
     layer.axes = 1 << dim
-    set_layer_name(layer, target, name)
+    set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
index 5cd09a010c..7a980327a2 100644
--- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
+++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -2,6 +2,7 @@
 import operator
 from typing import Dict, Sequence, Tuple, Union

+import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.types import TRTTensor
@@ -20,6 +21,7 @@ def getitem_validator(getitem_node: Node) -> bool:

 # TODO: Subsequent evaluators should be registered here with their own validators
 @dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.detach.default)  # type: ignore[misc]
 def generic_evaluator(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index f1cfaae348..bdeef49168 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -99,15 +99,12 @@
     aten.nan_to_num,
     aten.narrow,
     # TODO: Disable the below operators once freezing is done
-    aten.native_batch_norm,
     aten.native_batch_norm_backward,
     aten._native_batch_norm_legit,
     aten._native_batch_norm_legit_functional,
     aten._native_batch_norm_legit_no_training,
     aten.native_dropout_backward,
-    aten.native_group_norm,
     aten.native_group_norm_backward,
-    aten.native_layer_norm,
     aten.native_layer_norm_backward,
     aten.new_empty,
     aten.new_full,
diff --git a/tests/py/dynamo/conversion/test_batch_norm_aten.py b/tests/py/dynamo/conversion/test_batch_norm_aten.py
new file mode 100644
index 0000000000..680e2264d1
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_batch_norm_aten.py
@@ -0,0 +1,166 @@
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+FEATURE_NUM = 3
+
+
+class TestBatchNormConverter(DispatchTestCase):
+    def test_batchnorm(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.batch_norm.default(
+                    x,
+                    torch.ones((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.ones((FEATURE_NUM,)),
+                    False,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_batchnorm1d_with_dynamic_shape(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.batch_norm.default(
+                    x,
+                    torch.ones((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.ones((FEATURE_NUM,)),
+                    False,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, 5),
+                dtype=torch.float32,
+                shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            BatchNorm(),
+            input_specs,
+        )
+
+    def test_batchnorm_with_dynamic_shape(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.batch_norm.default(
+                    x,
+                    torch.ones((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.ones((FEATURE_NUM,)),
+                    False,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            BatchNorm(),
+            input_specs,
+        )
+
+
+class TestNativeBatchNormConverter(DispatchTestCase):
+    def test_batchnorm(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_batch_norm.default(
+                    x,
+                    torch.ones((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.ones((FEATURE_NUM,)),
+                    False,
+                    0.1,
+                    1e-05,
+                )[0]
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_batchnorm1d_with_dynamic_shape(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_batch_norm.default(
+                    x,
+                    torch.ones((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.ones((FEATURE_NUM,)),
+                    False,
+                    0.1,
+                    1e-05,
+                )[0]
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, 5),
+                dtype=torch.float32,
+                shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            BatchNorm(),
+            input_specs,
+        )
+
+    def test_batchnorm_with_dynamic_shape(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_batch_norm.default(
+                    x,
+                    torch.ones((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.ones((FEATURE_NUM,)),
+                    False,
+                    0.1,
+                    1e-05,
+                )[0]
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            BatchNorm(),
+            input_specs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/conversion/test_batchnorm_aten.py b/tests/py/dynamo/conversion/test_batchnorm_aten.py
deleted file mode 100644
index cb946bcc40..0000000000
--- a/tests/py/dynamo/conversion/test_batchnorm_aten.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import unittest
-
-import torch
-from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt import Input
-
-from .harness import DispatchTestCase
-
-
-class TestBatchNormConverter(DispatchTestCase):
-    @unittest.skip("Pending ongoing work on batchnorm converter in Dynamo")
-    def test_batchnorm(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.bn = torch.nn.BatchNorm2d(3)
-
-            def forward(self, x):
-                return self.bn(x)
-
-        inputs = [torch.randn(1, 3, 224, 224)]
-        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.batch_norm})
-
-    @unittest.skip("Pending ongoing work on batchnorm converter in Dynamo")
-    def test_batchnorm1d_with_dynamic_shape(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.bn = torch.nn.BatchNorm1d(3)
-
-            def forward(self, x):
-                return self.bn(x)
-
-        input_specs = [
-            Input(
-                shape=(-1, 3, 5),
-                dtype=torch.float32,
-                shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))],
-            ),
-        ]
-
-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm}
-        )
-
-    @unittest.skip("Pending ongoing work on batchnorm converter in Dynamo")
-    def test_batchnorm_with_dynamic_shape(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.bn = torch.nn.BatchNorm2d(3)
-
-            def forward(self, x):
-                return self.bn(x)
-
-        input_specs = [
-            Input(
-                shape=(-1, 3, -1, -1),
-                dtype=torch.float32,
-                shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
-            ),
-        ]
-
-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm}
-        )
-
-    # Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm.
-
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/tests/py/dynamo/conversion/test_group_norm_aten.py b/tests/py/dynamo/conversion/test_group_norm_aten.py
new file mode 100644
index 0000000000..1e3fec058d
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_group_norm_aten.py
@@ -0,0 +1,91 @@
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestGroupNormConverter(DispatchTestCase):
+    def test_groupnorm1d(self):
+        class GroupNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.group_norm.default(
+                    x,
+                    2,
+                    torch.ones((6,)),
+                    torch.zeros((6,)),
+                    1e-05,
+                    True,
+                )
+
+        inputs = [torch.randn(3, 6, 224)]
+        self.run_test(
+            GroupNorm(),
+            inputs,
+        )
+
+    def test_groupnorm2d(self):
+        class GroupNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.group_norm.default(
+                    x,
+                    2,
+                    torch.ones((6,)),
+                    torch.zeros((6,)),
+                    1e-05,
+                    True,
+                )
+
+        inputs = [torch.randn(3, 6, 224, 224)]
+        with torch.no_grad():
+            self.run_test(
+                GroupNorm(),
+                inputs,
+            )
+
+
+class TestNativeGroupNormConverter(DispatchTestCase):
+    def test_groupnorm1d(self):
+        class GroupNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_group_norm.default(
+                    x,
+                    torch.ones((6,)),
+                    torch.zeros((6,)),
+                    3,
+                    6,
+                    224,
+                    2,
+                    1e-05,
+                )[0]
+
+        inputs = [torch.randn(3, 6, 224)]
+        self.run_test(
+            GroupNorm(),
+            inputs,
+        )
+
+    def test_groupnorm2d(self):
+        class GroupNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_group_norm.default(
+                    x,
+                    torch.ones((6,)),
+                    torch.zeros((6,)),
+                    3,
+                    6,
+                    224 * 224,
+                    2,
+                    1e-05,
+                )[0]
+
+        inputs = [torch.randn(3, 6, 224, 224)]
+        with torch.no_grad():
+            self.run_test(
+                GroupNorm(),
+                inputs,
+            )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/conversion/test_layer_norm_aten.py b/tests/py/dynamo/conversion/test_layer_norm_aten.py
index 0cc374e307..8013768214 100644
--- a/tests/py/dynamo/conversion/test_layer_norm_aten.py
+++ b/tests/py/dynamo/conversion/test_layer_norm_aten.py
@@ -1,5 +1,3 @@
-import unittest
-
 import torch
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
@@ -8,30 +6,78 @@


 class TestLayerNormConverter(DispatchTestCase):
-    @unittest.skip("Pending ongoing work on layernorm converter in Dynamo")
     def test_layer_norm(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.ln = torch.nn.LayerNorm([3, 224, 224])
-
+        class LayerNorm(torch.nn.Module):
             def forward(self, x):
-                return self.ln(x)
+                return torch.ops.aten.layer_norm.default(
+                    x,
+                    torch.tensor([3, 224, 224]),
+                    torch.ones((3, 224, 224)),
+                    torch.zeros((3, 224, 224)),
+                    1e-05,
+                    True,
+                )

         inputs = [torch.randn(1, 3, 224, 224)]
         self.run_test(
-            TestModule(), inputs, expected_ops={torch.ops.aten.layer_norm.default}
+            LayerNorm(),
+            inputs,
         )

-    @unittest.skip("Pending ongoing work on layernorm converter in Dynamo")
     def test_layernorm_with_dynamic_shape(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.ln = torch.nn.LayerNorm([3, 224, 224])
+        class LayerNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.layer_norm.default(
+                    x,
+                    torch.tensor([3, 224, 224]),
+                    torch.ones((3, 224, 224)),
+                    torch.zeros((3, 224, 224)),
+                    1e-05,
+                    True,
+                )
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, 224, 224),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            LayerNorm(),
+            input_specs,
+        )
+
+
+class TestNativeLayerNormConverter(DispatchTestCase):
+    def test_layer_norm(self):
+        class LayerNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_layer_norm.default(
+                    x,
+                    torch.tensor([3, 224, 224]),
+                    torch.ones((3, 224, 224)),
+                    torch.zeros((3, 224, 224)),
+                    1e-05,
+                )[0]
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            LayerNorm(),
+            inputs,
+        )
+
+    def test_layernorm_with_dynamic_shape(self):
+        class LayerNorm(torch.nn.Module):
             def forward(self, x):
-                return self.ln(x)
+                return torch.ops.aten.native_layer_norm.default(
+                    x,
+                    torch.tensor([3, 224, 224]),
+                    torch.ones((3, 224, 224)),
+                    torch.zeros((3, 224, 224)),
+                    1e-05,
+                )[0]

         input_specs = [
             Input(
@@ -42,7 +88,8 @@ def forward(self, x):
         ]

         self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={torch.ops.aten.layer_norm.default}
+            LayerNorm(),
+            input_specs,
         )