Skip to content

fix: refactor layer norm converter with INormalization Layer #2755

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 20 additions & 91 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_trt_tensor,
get_axes_for_reduce_op,
get_positive_dim,
get_trt_tensor,
to_numpy,
Expand Down Expand Up @@ -105,102 +106,30 @@ def layer_norm(
cudnn_enable: bool,
return_mean_rstd: bool,
) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
if weight is None:
weight = to_numpy(1.0)

if bias is None:
bias = to_numpy(0.0)

shape = weight.shape
gamma = to_numpy(weight).reshape(shape)
beta = to_numpy(bias).reshape(shape)

dims = list(range(len(input.shape) - len(shape), len(input.shape)))

# E[x]
mean_expected_trt = impl.reduce.mean(
ctx, target, source_ir, f"{name}_mean_expected", input, dims, True
)

# X-E[x]
sub_trt = impl.elementwise.sub(
ctx,
target,
source_ir,
f"{name}_sub",
input,
mean_expected_trt,
)

# Variance = mean(pow(x_sub_mean, 2))
pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
pow_var = impl.elementwise.pow(
ctx,
target,
source_ir,
f"{name}_pow_var",
sub_trt,
pow_trt,
)
mean_trt = impl.reduce.mean(
ctx, target, source_ir, f"{name}_mean", pow_var, dims, True
)

# sqrt((var + eps))
eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
add_trt = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_add",
mean_trt,
eps_trt,
)
sqrt_trt = impl.unary.sqrt(
ctx,
target,
source_ir,
f"{name}_sqrt",
add_trt,
)

# (X - E[X]) / sqrt((var + eps))
div_trt = impl.elementwise.div(
ctx,
target,
source_ir,
f"{name}_div",
sub_trt,
sqrt_trt,
)

gamma_trt = get_trt_tensor(ctx, weight, f"{name}_gamma")
beta_trt = get_trt_tensor(ctx, bias, f"{name}_beta")

# y * gamma + beta
scaled_y = impl.elementwise.mul(
ctx,
target,
source_ir,
f"{name}_mul_gamma",
div_trt,
gamma_trt,
)
dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
axes = get_axes_for_reduce_op(dims)

weight = get_trt_tensor(ctx, weight, f"{name}_weight")
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
if tuple(input.shape) != tuple(weight.shape):
weight = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape
)
if tuple(input.shape) != tuple(bias.shape):
bias = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_bias", bias, input.shape
)

output = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_add_beta",
scaled_y,
beta_trt,
)
layer_norm = ctx.net.add_normalization(input, weight, bias, axes)
layer_norm.epsilon = eps
layer_norm.compute_precision = input.dtype
set_layer_name(layer_norm, target, f"{name}_layer_norm", source_ir)

if return_mean_rstd:
# return fake mean and rstd for now
return output, None, None
return layer_norm.get_output(0), None, None

return output
return layer_norm.get_output(0)


def native_group_norm(
Expand Down
72 changes: 61 additions & 11 deletions tests/py/dynamo/conversion/test_layer_norm_aten.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,75 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestLayerNormConverter(DispatchTestCase):
def test_layer_norm(self):
@parameterized.expand(
[
(
(5, 3, 2, 4),
[
4,
],
),
((5, 3, 2, 4), [2, 4]),
((5, 3, 2, 4), [3, 2, 4]),
((5, 3, 2, 4), [5, 3, 2, 4]),
]
)
def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
class LayerNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.layer_norm.default(
x,
torch.tensor([3, 224, 224]),
torch.ones((3, 224, 224)),
torch.zeros((3, 224, 224)),
1e-05,
True,
normalized_shape,
torch.randn(normalized_shape),
torch.randn(normalized_shape),
eps,
)

inputs = [torch.randn(1, 3, 224, 224)]
inputs = [torch.randn(input_shape)]
self.run_test(
LayerNorm(),
inputs,
)


class TestNativeLayerNormConverter(DispatchTestCase):
def test_layer_norm(self):
@parameterized.expand(
[
(
(5, 3, 2, 4),
[
4,
],
),
((5, 3, 2, 4), [2, 4]),
((5, 3, 2, 4), [3, 2, 4]),
((5, 3, 2, 4), [5, 3, 2, 4]),
]
)
def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
class LayerNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_layer_norm.default(
x,
normalized_shape,
torch.randn(normalized_shape),
torch.randn(normalized_shape),
eps,
)[0]

inputs = [torch.randn(input_shape)]
self.run_test(
LayerNorm(),
inputs,
)

def test_layernorm_with_dynamic_shape(self):
class LayerNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_layer_norm.default(
Expand All @@ -37,10 +80,17 @@ def forward(self, x):
1e-05,
)[0]

inputs = [torch.randn(1, 3, 224, 224)]
self.run_test(
input_specs = [
Input(
shape=(-1, 3, 224, 224),
dtype=torch.float32,
shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))],
),
]

self.run_test_with_dynamic_shape(
LayerNorm(),
inputs,
input_specs,
)


Expand Down
Loading