Skip to content

feat: support chunk dynamo converter #2401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,30 @@ def aten_ops_slice(
)


@dynamo_tensorrt_converter(torch.ops.aten.chunk.default) # type: ignore[misc]
@enforce_tensor_types(
{
0: (TRTTensor,),
}
) # type: ignore[misc]
def aten_ops_chunk(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.slice.chunk(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args_bounds_check(args, 2, 0),
)


@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
@enforce_tensor_types(
{
Expand Down
55 changes: 55 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,58 @@ def expand(
layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def chunk(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
chunks: int,
dim: int,
) -> TRTTensor:
if chunks <= 0:
raise RuntimeError(
f"chunk expects `chunks` to be greater than 0, got: {chunks}"
)

shape = input.shape
dim = get_positive_dim(dim, len(shape))

if dim >= len(shape):
raise RuntimeError(
f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
)

dynamic_shape = has_dynamic_shape(input.shape)
if dynamic_shape > 0:
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"

size_dim = shape[dim]
chunk_size = math.ceil(size_dim / chunks)
result = []
start = 0
end = min(start + chunk_size, size_dim)
cnt = 0

while start < end:
result.append(
slice_op(
ctx,
target,
source_ir,
f"{name}_slice_{cnt}",
input,
dim,
start,
end,
1,
)
)
start = end
end = min(start + chunk_size, size_dim)
cnt += 1

return result
82 changes: 82 additions & 0 deletions tests/py/dynamo/conversion/test_chunk_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestChunkConverter(DispatchTestCase):
@parameterized.expand(
[
((1,), 3, 0),
((3,), 3, 0),
((4,), 3, 0),
((6,), 3, 0),
((3,), 1, -1),
((3,), 3, -1),
((3,), 4, -1),
]
)
def test_chunk_1D(self, shape, chunks, dim):
class TestChunk(torch.nn.Module):
def forward(self, input):
out = torch.ops.aten.chunk.default(input, chunks, dim)
return out

input = [torch.randn(shape)]
self.run_test(
TestChunk(),
input,
)

@parameterized.expand(
[
((3, 4), 1, 0),
((3, 4), 3, 0),
((3, 4), 4, 0),
((3, 4), 2, -2),
((3, 4), 6, -2),
((3, 4), 3, 1),
((3, 4), 4, 1),
((3, 4), 5, -1),
]
)
def test_chunk_2D(self, shape, chunks, dim):
class TestChunk(torch.nn.Module):
def forward(self, input):
out = torch.ops.aten.chunk.default(input, chunks, dim)
return out

input = [torch.randn(shape)]
self.run_test(
TestChunk(),
input,
)

@parameterized.expand(
[
((3, 4, 2), 1, 0),
((3, 4, 2), 3, -3),
((3, 4, 2), 3, 1),
((3, 4, 2), 4, 1),
((3, 4, 2), 6, -2),
((3, 4, 2), 1, 2),
((3, 4, 2), 3, -1),
((3, 4, 2), 4, -1),
]
)
def test_chunk_3D(self, shape, chunks, dim):
class TestChunk(torch.nn.Module):
def forward(self, input):
out = torch.ops.aten.chunk.default(input, chunks, dim)
return out

input = [torch.randn(shape)]
self.run_test(
TestChunk(),
input,
)


if __name__ == "__main__":
run_tests()