
Commit 5042c13

feat: support aten.as_strided converter
1 parent 7d30714 commit 5042c13

File tree

3 files changed: +143 -0 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 20 additions & 0 deletions
@@ -2133,6 +2133,26 @@ def aten_ops_linear(
    )


@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default)
def aten_ops_as_strided(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.slice.as_strided(
        ctx,
        target,
        source_ir=SourceIR.ATEN,
        name=name,
        input=args[0],
        size=args[1],
        stride=args[2],
        storage_offset=args_bounds_check(args, 3, 0),
    )


def avg_pool_param_validator(pool_node: Node) -> bool:
    ceil_mode = args_bounds_check(pool_node.args, 4, False)
    divisor_override = args_bounds_check(pool_node.args, 6)
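
For reference, the converter unpacks the node's positional arguments in ATen schema order: args[0] is the input tensor, args[1] the output size, args[2] the stride, and args[3] the optional storage_offset (defaulting to 0 via args_bounds_check when absent). A minimal eager-mode sketch of the call being lowered (illustrative, not part of the diff):

import torch

x = torch.arange(25.0).reshape(5, 5)
# args map as: input=x, size=(2, 3), stride=(1, 2), storage_offset=0
y = torch.ops.aten.as_strided.default(x, (2, 3), (1, 2), 0)
print(y)  # tensor([[0., 2., 4.], [1., 3., 5.]])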

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 58 additions & 0 deletions
@@ -3,11 +3,13 @@

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    flatten_dims,
    get_positive_dim,
    get_trt_tensor,
)
@@ -259,3 +261,59 @@ def flip(
    )
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)


def as_strided(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    size: Sequence[int],
    stride: Sequence[int],
    storage_offset: int,
) -> TRTTensor:
    assert len(size) == len(stride), "size and stride shapes must be the same"

    # Flatten the input so that strided accesses become 1-D gather indices
    flatten_shape = flatten_dims(input, 0, -1)
    flatten_output = impl.shuffle.reshape(
        ctx, target, source_ir, f"{name}_reshape", input, flatten_shape
    )

    indices = []

    # Recursive function to compute flat-storage indices for the as_strided operation
    def nested(
        rank: int, size: Sequence[int], stride: Sequence[int], current: int, dim: int
    ) -> None:
        # If the current dimension equals the rank, append the computed index
        if dim == rank:
            indices.append(current)
            return
        # Advance by stride[dim] per step, then recurse into the next dimension
        for i in range(size[dim]):
            nested(rank, size, stride, current + stride[dim] * i, dim + 1)

    nested(len(size), size, stride, storage_offset, 0)

    indices = torch.tensor(indices, dtype=torch.int)

    indices_tensor = get_trt_tensor(ctx, indices, f"{name}_indices")

    # Use gather to reorder elements based on the computed indices
    gather_layer = ctx.net.add_gather(flatten_output, indices_tensor, axis=0)
    gather_output = gather_layer.get_output(0)

    # Reshape the gathered tensor to the desired size
    reshape_output = impl.shuffle.reshape(
        ctx,
        target,
        source_ir,
        f"{name}_reshape",
        gather_output,
        tuple(size),
    )

    return reshape_output
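
Each output element (i_0, ..., i_{n-1}) reads flat-storage index storage_offset + sum_k i_k * stride[k]; the recursion above enumerates those indices in row-major order so a single gather can fetch them all. A pure-Python sketch of the same enumeration, checked against eager as_strided (illustrative only, not part of the commit):

import torch

def as_strided_indices(size, stride, storage_offset=0):
    # Iterative equivalent of the recursive `nested` helper above:
    # expand one dimension at a time, preserving row-major order
    indices = [storage_offset]
    for s, st in zip(size, stride):
        indices = [base + st * i for base in indices for i in range(s)]
    return indices

x = torch.arange(25.0).reshape(5, 5)
size, stride, offset = (2, 3), (1, 2), 0
idx = torch.tensor(as_strided_indices(size, stride, offset))
# Same plan the converter lowers to: flatten, gather, reshape
out = x.flatten()[idx].reshape(size)
assert torch.equal(out, torch.as_strided(x, size, stride, offset))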
Lines changed: 65 additions & 0 deletions (new test file)
@@ -0,0 +1,65 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestAsStridedConverter(DispatchTestCase):
    @parameterized.expand(
        [
            (
                (5, 5),
                (2, 3),
                (1, 2),
                0,
            ),
            (
                (5, 5),
                (2, 3),
                (2, 2),
                1,
            ),
            (
                (20, 20),
                (2, 3, 2),
                (2, 2, 2),
                0,
            ),
            (
                (8, 8, 8),
                (2, 2, 3),
                (1, 2, 2),
                1,
            ),
            (
                (200, 200, 200),
                (9, 9, 3, 2),
                (2, 2, 2, 3),
                1,
            ),
        ]
    )
    def test_as_strided(
        self,
        input_shape,
        output_size,
        stride,
        storage_offset=0,
    ):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.as_strided.default(
                    x, output_size, stride, storage_offset
                )

        inputs = [torch.randn(input_shape)]
        self.run_test(
            TestModule(),
            inputs,
        )


if __name__ == "__main__":
    run_tests()
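
Note that the second case (stride (2, 2) with storage_offset 1) produces overlapping reads: the same storage element lands in multiple output positions, which the gather-based lowering handles naturally since gather indices may repeat. A quick eager check of that parameterization (illustrative):

import torch

x = torch.arange(25.0).reshape(5, 5)
# Flat indices are 1 + 2*i + 2*j: [[1, 3, 5], [3, 5, 7]], note the repeated 3 and 5
out = torch.as_strided(x, (2, 3), (2, 2), 1)
print(out)  # tensor([[1., 3., 5.], [3., 5., 7.]])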
