# ============================================================================
# py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
# (new converter, registered after aten_ops_slice)
# ============================================================================


@dynamo_tensorrt_converter(torch.ops.aten.chunk.default)  # type: ignore[misc]
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)  # type: ignore[misc]
def aten_ops_chunk(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter for ``aten.chunk.default``.

    Splits ``args[0]`` into ``args[1]`` pieces along dimension ``args[2]``
    (defaults to 0 when absent, matching ``torch.chunk``), delegating to
    ``impl.slice.chunk``.
    """
    return impl.slice.chunk(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
        args_bounds_check(args, 2, 0),
    )


# ============================================================================
# py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
# (appended after expand; requires `List` in this file's `typing` import)
# ============================================================================


def chunk(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    chunks: int,
    dim: int,
) -> List[TRTTensor]:
    """Split *input* into up to *chunks* slices along dimension *dim*.

    Mirrors ``torch.chunk`` semantics: every chunk except possibly the last
    has ``ceil(shape[dim] / chunks)`` elements along *dim*, and fewer than
    *chunks* tensors are returned when the dimension size is not evenly
    divisible.

    Args:
        ctx: Conversion context holding the TensorRT network being built.
        target: The FX node target being converted.
        source_ir: Origin IR of the op (ATen here).
        name: Base layer name; each slice layer is suffixed ``_slice_<i>``.
        input: The ITensor to split.
        chunks: Number of chunks requested; must be > 0.
        dim: Dimension to split along; negative values are normalized.

    Returns:
        The list of slice-layer output tensors.

    Raises:
        RuntimeError: if ``chunks`` is not positive, ``dim`` is out of
            range, or ``dim`` is a dynamic-shape (-1) dimension.
    """
    if chunks <= 0:
        raise RuntimeError(
            f"chunk expects `chunks` to be greater than 0, got: {chunks}"
        )

    shape = input.shape
    dim = get_positive_dim(dim, len(shape))

    if dim >= len(shape):
        raise RuntimeError(
            f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
        )

    # Chunking along a dynamic (-1) dimension is unsupported: the chunk
    # boundaries cannot be computed at network-build time. Raise instead of
    # assert so the check survives `python -O`, consistent with the other
    # validations above.
    if has_dynamic_shape(input.shape) and input.shape[dim] == -1:
        raise RuntimeError("Can't chunk on dynamic shape dimension!")

    size_dim = shape[dim]
    # torch.chunk semantics: each chunk takes ceil(size/chunks) elements,
    # with the final chunk absorbing whatever remains.
    chunk_size = math.ceil(size_dim / chunks)
    result: List[TRTTensor] = []
    start = 0
    cnt = 0

    while start < size_dim:
        end = min(start + chunk_size, size_dim)
        result.append(
            slice_op(
                ctx,
                target,
                source_ir,
                f"{name}_slice_{cnt}",
                input,
                dim,
                start,
                end,
                1,
            )
        )
        start = end
        cnt += 1

    return result


# ============================================================================
# tests/py/dynamo/conversion/test_chunk_aten.py (new file)
# ============================================================================
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestChunkConverter(DispatchTestCase):
    """Dispatch tests comparing the chunk converter against torch.chunk."""

    @parameterized.expand(
        [
            ((1,), 3, 0),
            ((3,), 3, 0),
            ((4,), 3, 0),
            ((6,), 3, 0),
            ((3,), 1, -1),
            ((3,), 3, -1),
            ((3,), 4, -1),
        ]
    )
    def test_chunk_1D(self, shape, chunks, dim):
        class TestChunk(torch.nn.Module):
            def forward(self, input):
                out = torch.ops.aten.chunk.default(input, chunks, dim)
                return out

        input = [torch.randn(shape)]
        self.run_test(
            TestChunk(),
            input,
        )

    @parameterized.expand(
        [
            ((3, 4), 1, 0),
            ((3, 4), 3, 0),
            ((3, 4), 4, 0),
            ((3, 4), 2, -2),
            ((3, 4), 6, -2),
            ((3, 4), 3, 1),
            ((3, 4), 4, 1),
            ((3, 4), 5, -1),
        ]
    )
    def test_chunk_2D(self, shape, chunks, dim):
        class TestChunk(torch.nn.Module):
            def forward(self, input):
                out = torch.ops.aten.chunk.default(input, chunks, dim)
                return out

        input = [torch.randn(shape)]
        self.run_test(
            TestChunk(),
            input,
        )

    @parameterized.expand(
        [
            ((3, 4, 2), 1, 0),
            ((3, 4, 2), 3, -3),
            ((3, 4, 2), 3, 1),
            ((3, 4, 2), 4, 1),
            ((3, 4, 2), 6, -2),
            ((3, 4, 2), 1, 2),
            ((3, 4, 2), 3, -1),
            ((3, 4, 2), 4, -1),
        ]
    )
    def test_chunk_3D(self, shape, chunks, dim):
        class TestChunk(torch.nn.Module):
            def forward(self, input):
                out = torch.ops.aten.chunk.default(input, chunks, dim)
                return out

        input = [torch.randn(shape)]
        self.run_test(
            TestChunk(),
            input,
        )


if __name__ == "__main__":
    run_tests()