diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index eac8a4b70c..4f2d168d29 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -7,7 +7,6 @@
 import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
-
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -650,6 +649,11 @@ def aten_ops_erf(
 @dynamo_tensorrt_converter(
     torch.ops.aten.unsqueeze.default, supports_dynamic_shapes=True
 )
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
 def aten_ops_unsqueeze(
     ctx: ConversionContext,
     target: Target,
@@ -657,9 +661,7 @@
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.unsqueeze.unsqueeze(
-        ctx, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1]
-    )
+    return impl.unsqueeze.unsqueeze(ctx, target, SourceIR.ATEN, name, args[0], args[1])
 
 
 @dynamo_tensorrt_converter(
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
index 0db246a7c6..3dacc2fbe4 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -1,14 +1,13 @@
-from typing import List, Optional, Sequence, cast
+from typing import List, Optional, Sequence
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
-    get_positive_dim,
     get_trt_tensor,
+    set_layer_name,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import Shape, TRTTensor
+from torch_tensorrt.dynamo.types import TRTTensor
 
 
 def unsqueeze(
@@ -16,64 +15,11 @@ def unsqueeze(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input_t: TRTTensor,
-    dim: Shape,
+    input: TRTTensor,
+    dim: int,
 ) -> TRTTensor:
-    input_val = get_trt_tensor(ctx, input_t, f"{name}_input_t")
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"unsqueeze received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    dim = cast(int, dim)
-
-    input_shape_size = len(input_val.shape)
-    dim = get_positive_dim(dim, input_shape_size + 1)
-
-    intermediate_dim = 0
-    dynamic_shape_cnt = 0
-    # if unsqueeze the last dimensions, we can directly append to the shape
-    if dim == input_shape_size:
-        intermediate_dim = dim
-    else:
-        # since maximum of one dimension is permitted to be specified as -1
-        # find the intermediate_dim which has only 1 dynamic_shape_cnt
-        # and then we can add a transpose after reshape if it is not the final shape we want
-        for i, s in reversed(list(enumerate(input_val.shape))):
-            if i >= dim:
-                if s == -1:
-                    dynamic_shape_cnt += 1
-                    if dynamic_shape_cnt > 1:
-                        intermediate_dim = i + 1
-                        break
-            if i == dim:
-                intermediate_dim = i
-                break
-    # calculate the new_shape for the shuffle layer's reshape_dims
-    new_shape = list(
-        tuple(input_val.shape)[:intermediate_dim]
-        + (1,)
-        + tuple(input_val.shape)[intermediate_dim:]
-    )
-    for i, s in enumerate(new_shape):
-        if i < intermediate_dim and s == -1:
-            new_shape[i] = 0
-    layer = ctx.net.add_shuffle(input_val)
-    layer.reshape_dims = tuple(new_shape)
-    # if the intermediate_dim is not the final dim we want to unsqueeze, add a second_transpose after reshape
-    if intermediate_dim != dim:
-        # calculate the second_transpose for the shuffle layer
-        permutation = [*range(0, len(new_shape))]
-        # for example: if the reshape_dims is (3, 3, 5, 1, 5) and the final shape we want is (3, 1, 3, 5, 5)
-        # here intermediate_dim=3, dim=1, we need to move intermediate_dim before [dim: intermediate_dim)
-        new_permutation = (
-            tuple(permutation[:dim])
-            + (intermediate_dim,)
-            + tuple(permutation[dim:intermediate_dim])
-            + tuple(permutation[intermediate_dim + 1 :])
-        )
-        layer.second_transpose = new_permutation
+    axes = get_trt_tensor(ctx, dim, f"{name}_axes")
+    layer = ctx.net.add_unsqueeze(input, axes)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
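
Note: the rewrite drops the manual IShuffleLayer emulation (a reshape plus an optional second_transpose to keep dynamic dimensions legal) and delegates to TensorRT's native unsqueeze layer. Below is a minimal standalone sketch of that underlying API, assuming TensorRT >= 10, where INetworkDefinition.add_unsqueeze was introduced; the builder/network boilerplate is illustrative only and not part of this patch.

    import numpy as np
    import tensorrt as trt

    # Throwaway network, built solely to exercise the unsqueeze layer.
    builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
    network = builder.create_network()
    x = network.add_input("x", trt.float32, (2, 3, 5))

    # `axes` lists the positions at which size-1 dims are inserted; this
    # mirrors the constant that get_trt_tensor(ctx, dim, f"{name}_axes")
    # produces in the new impl (here dim=1).
    axes = network.add_constant((1,), np.array([1], dtype=np.int32)).get_output(0)

    layer = network.add_unsqueeze(x, axes)  # native unsqueeze; no shuffle/transpose
    print(layer.get_output(0).shape)  # expected: (2, 1, 3, 5)

Because the layer takes the axes as a tensor, the dynamic-shape bookkeeping of the old implementation (intermediate_dim, dynamic_shape_cnt) is no longer needed.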