From ade7ed6aac1b75efdc90f6e8b1d644510c6ca1be Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Tue, 16 Apr 2024 12:55:29 -0700
Subject: [PATCH 1/2] refactor layernorm converter with INormalization Layer

---
 .../conversion/impl/normalization/ops.py      | 111 ++++--------------
 .../dynamo/conversion/test_layer_norm_aten.py |  51 ++++++--
 2 files changed, 64 insertions(+), 98 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index b1e4fbf24c..6db7ed667e 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -9,6 +9,7 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
+    get_axes_for_reduce_op,
     get_positive_dim,
     get_trt_tensor,
     to_numpy,
@@ -105,102 +106,30 @@ def layer_norm(
     cudnn_enable: bool,
     return_mean_rstd: bool,
 ) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
-    if weight is None:
-        weight = to_numpy(1.0)
-
-    if bias is None:
-        bias = to_numpy(0.0)
-
-    shape = weight.shape
-    gamma = to_numpy(weight).reshape(shape)
-    beta = to_numpy(bias).reshape(shape)
-
-    dims = list(range(len(input.shape) - len(shape), len(input.shape)))
-
-    # E[x]
-    mean_expected_trt = impl.reduce.mean(
-        ctx, target, source_ir, f"{name}_mean_expected", input, dims, True
-    )
-
-    # X-E[x]
-    sub_trt = impl.elementwise.sub(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sub",
-        input,
-        mean_expected_trt,
-    )
-
-    # Variance = mean(pow(x_sub_mean, 2))
-    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
-    pow_var = impl.elementwise.pow(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_pow_var",
-        sub_trt,
-        pow_trt,
-    )
-    mean_trt = impl.reduce.mean(
-        ctx, target, source_ir, f"{name}_mean", pow_var, dims, True
-    )
-
-    # sqrt((var + eps))
-    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
-    add_trt = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add",
-        mean_trt,
-        eps_trt,
-    )
-    sqrt_trt = impl.unary.sqrt(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sqrt",
-        add_trt,
-    )
-
-    # (X - E[X]) / sqrt((var + eps))
-    div_trt = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_div",
-        sub_trt,
-        sqrt_trt,
-    )
-
-    gamma_trt = get_trt_tensor(ctx, weight, f"{name}_gamma")
-    beta_trt = get_trt_tensor(ctx, bias, f"{name}_beta")
-
-    # y * gamma + beta
-    scaled_y = impl.elementwise.mul(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mul_gamma",
-        div_trt,
-        gamma_trt,
-    )
+    dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
+    axes = get_axes_for_reduce_op(dims)
+
+    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+    if tuple(input.shape) != tuple(weight.shape):
+        weight = impl.slice.expand(
+            ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape
+        )
+    if tuple(input.shape) != tuple(bias.shape):
+        bias = impl.slice.expand(
+            ctx, target, source_ir, f"{name}_expand_bias", bias, input.shape
+        )
 
-    output = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add_beta",
-        scaled_y,
-        beta_trt,
-    )
+    layer_norm = ctx.net.add_normalization(input, weight, bias, axes)
+    layer_norm.epsilon = eps
+    layer_norm.compute_precision = input.dtype
+    set_layer_name(layer_norm, target, f"{name}_layer_norm", source_ir)
 
     if return_mean_rstd:
         # return fake mean and rstd for now
-        return output, None, None
+        return layer_norm.get_output(0), None, None
 
-    return output
+    return layer_norm.get_output(0)
 
 
 def native_group_norm(
diff --git a/tests/py/dynamo/conversion/test_layer_norm_aten.py b/tests/py/dynamo/conversion/test_layer_norm_aten.py
index 7f43234211..6b4a1d6961 100644
--- a/tests/py/dynamo/conversion/test_layer_norm_aten.py
+++ b/tests/py/dynamo/conversion/test_layer_norm_aten.py
@@ -1,4 +1,5 @@
 import torch
+from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
 
@@ -6,19 +7,31 @@
 
 
 class TestLayerNormConverter(DispatchTestCase):
-    def test_layer_norm(self):
+    @parameterized.expand(
+        [
+            (
+                (5, 3, 2, 4),
+                [
+                    4,
+                ],
+            ),
+            ((5, 3, 2, 4), [2, 4]),
+            ((5, 3, 2, 4), [3, 2, 4]),
+            ((5, 3, 2, 4), [5, 3, 2, 4]),
+        ]
+    )
+    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
         class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.layer_norm.default(
                     x,
-                    torch.tensor([3, 224, 224]),
-                    torch.ones((3, 224, 224)),
-                    torch.zeros((3, 224, 224)),
-                    1e-05,
-                    True,
+                    normalized_shape,
+                    torch.randn(normalized_shape),
+                    torch.randn(normalized_shape),
+                    eps,
                 )
 
-        inputs = [torch.randn(1, 3, 224, 224)]
+        inputs = [torch.randn(input_shape)]
         self.run_test(
             LayerNorm(),
             inputs,
@@ -43,6 +56,30 @@ def forward(self, x):
             inputs,
         )
 
+    def test_layernorm_with_dynamic_shape(self):
+        class LayerNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_layer_norm.default(
+                    x,
+                    torch.tensor([3, 224, 224]),
+                    torch.ones((3, 224, 224)),
+                    torch.zeros((3, 224, 224)),
+                    1e-05,
+                )[0]
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, 224, 224),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            LayerNorm(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()

From fbed329feecff61554886c13deef033bfb0e8eaf Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Tue, 16 Apr 2024 13:01:31 -0700
Subject: [PATCH 2/2] add more test cases

---
 .../dynamo/conversion/test_layer_norm_aten.py | 25 ++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/tests/py/dynamo/conversion/test_layer_norm_aten.py b/tests/py/dynamo/conversion/test_layer_norm_aten.py
index 6b4a1d6961..c0e055304a 100644
--- a/tests/py/dynamo/conversion/test_layer_norm_aten.py
+++ b/tests/py/dynamo/conversion/test_layer_norm_aten.py
@@ -39,18 +39,31 @@ def forward(self, x):
 
 
 class TestNativeLayerNormConverter(DispatchTestCase):
-    def test_layer_norm(self):
+    @parameterized.expand(
+        [
+            (
+                (5, 3, 2, 4),
+                [
+                    4,
+                ],
+            ),
+            ((5, 3, 2, 4), [2, 4]),
+            ((5, 3, 2, 4), [3, 2, 4]),
+            ((5, 3, 2, 4), [5, 3, 2, 4]),
+        ]
+    )
+    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
         class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_layer_norm.default(
                     x,
-                    torch.tensor([3, 224, 224]),
-                    torch.ones((3, 224, 224)),
-                    torch.zeros((3, 224, 224)),
-                    1e-05,
+                    normalized_shape,
+                    torch.randn(normalized_shape),
+                    torch.randn(normalized_shape),
+                    eps,
                 )[0]
 
-        inputs = [torch.randn(1, 3, 224, 224)]
+        inputs = [torch.randn(input_shape)]
         self.run_test(
             LayerNorm(),
             inputs,
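
Reviewer note: the converter above delegates the whole E[x] / Var[x] / sqrt pipeline to
TensorRT's INormalizationLayer via ctx.net.add_normalization. For anyone unfamiliar with
that layer, below is a minimal standalone sketch of the same API used directly, assuming
TensorRT >= 8.6 (where add_normalization was introduced). The network, shapes, and constant
gamma/beta values are illustrative assumptions matching the (5, 3, 2, 4) test case, not code
from these patches.

    import numpy as np
    import tensorrt as trt

    # Build a tiny explicit-batch network containing only a normalization layer.
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )

    # Input of shape (5, 3, 2, 4); normalize over the last three dims,
    # i.e. normalized_shape = [3, 2, 4] as in one of the parameterized tests.
    x = network.add_input("x", trt.float32, (5, 3, 2, 4))

    # Scale (gamma) and bias (beta) are ITensors that must be broadcastable to
    # the input -- the converter guarantees this with its impl.slice.expand calls.
    gamma = network.add_constant(
        (1, 3, 2, 4), np.ones((1, 3, 2, 4), dtype=np.float32)
    ).get_output(0)
    beta = network.add_constant(
        (1, 3, 2, 4), np.zeros((1, 3, 2, 4), dtype=np.float32)
    ).get_output(0)

    # The last argument is a bitmask of reduction axes (dims 1, 2, 3 here);
    # get_axes_for_reduce_op computes the same mask from `dims` in the converter.
    axes = (1 << 1) | (1 << 2) | (1 << 3)

    layer = network.add_normalization(x, gamma, beta, axes)
    layer.epsilon = 1e-5
    network.mark_output(layer.get_output(0))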