Commit e6e5170

bowang007 and gs-olive committed

support argmax converter (#2291)

Signed-off-by: Bo Wang <[email protected]>
Co-authored-by: gs-olive <[email protected]>

1 parent 738771a

4 files changed, +137 -1 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 21 additions & 1 deletion
@@ -175,7 +175,7 @@ def aten_ops_group_norm(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.cat.default)
+@dynamo_tensorrt_converter(torch.ops.aten.cat.default)  # type: ignore[misc]
 def aten_ops_cat(
     ctx: ConversionContext,
     target: Target,
@@ -1797,3 +1797,23 @@ def aten_ops_reshape(
         input=args[0],
         shape=args[1],
     )
+
+
+@enforce_tensor_types({0: (TRTTensor,)})  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)  # type: ignore[misc]
+def aten_ops_argmax(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.argmax.argmax(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        dim=args_bounds_check(args, 1),
+        keep_dim=args_bounds_check(args, 2, False),
+    )
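As a usage sketch (not part of this commit): a module whose torch.argmax call lowers to torch.ops.aten.argmax.default would now be routed through the registered aten_ops_argmax converter. The torch_tensorrt.compile settings and input shape below are illustrative assumptions, not something the commit prescribes.

import torch
import torch_tensorrt  # assumes a Torch-TensorRT build that includes this commit


class ArgMaxModule(torch.nn.Module):
    def forward(self, x):
        # Traces to torch.ops.aten.argmax.default(x, 1, True)
        return torch.argmax(x, dim=1, keepdim=True)


model = ArgMaxModule().eval().cuda()
inputs = [torch.randn(4, 8, device="cuda")]

# ir="dynamo" selects the frontend whose converter registry this commit extends
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))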

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from . import (
     activation,
     attention,
+    argmax,
     cast,
     cat,
     condition,
py/torch_tensorrt/dynamo/conversion/impl/argmax.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+from typing import Optional
+
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
+    flatten_dims,
+    get_axes_for_reduce_op,
+    get_positive_dim,
+)
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def argmax(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: Optional[int],
+    keep_dim: bool = False,
+) -> TRTTensor:
+    if input.dtype == trt.int32:
+        input = cast_trt_tensor(ctx, input, trt.float32, name, target, source_ir)
+
+    # Three cases to handle here:
+    # 1. dim is None: flatten the input tensor first; the result is reshaped back to
+    #    the input rank (all ones) when keep_dim is True, or to a scalar otherwise
+    # 2. input rank == 1: the TopK layer does not support 1-dimensional topk, so
+    #    broadcast the input to rank 2
+    # 3. normal case: no additional handling needed
+    out = input
+
+    if dim is None:
+        new_shape = (*flatten_dims(input, 0, -1), 1)
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_flatten", input, new_shape
+        )
+    elif len(input.shape) == 1:
+        new_shape = (*input.shape, 1)
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_broadcast", input, new_shape
+        )
+
+    # Reduce over the flattened input if the dimension is None, otherwise over the specified dimension
+    reduce_mask = get_axes_for_reduce_op(
+        get_positive_dim(dim if dim is not None else 0, len(out.shape))
+    )
+
+    topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
+    set_layer_name(topk_layer, target, name, source_ir)
+
+    # Output 1 of the TopK layer holds the indices, i.e. the argmax result
+    out = topk_layer.get_output(1)
+
+    if dim is None:
+        new_shape = ((1,) * len(input.shape)) if keep_dim else ()
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_unflatten", out, new_shape
+        )
+    elif len(input.shape) == 1:
+        out = impl.squeeze.squeeze(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_squeeze",
+            out,
+            1 if keep_dim else (0, 1),
+        )
+    elif not keep_dim:
+        out = impl.squeeze.squeeze(ctx, target, source_ir, f"{name}_squeeze", out, dim)
+
+    return out
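A brief eager-mode aside (not part of this commit) on the two facts the lowering above relies on: an argmax with dim=None equals the argmax of the flattened tensor, and the index output of a top-1 reduction is the argmax along that dimension, which is why the converter keeps the TopK layer's second output via get_output(1).

import torch

x = torch.randn(3, 4, 5)

# dim=None reduces over the flattened tensor (what the reshape-and-flatten branch handles above)
assert torch.equal(torch.argmax(x), torch.argmax(x.reshape(-1)))

# top-k with k=1 returns (values, indices); the indices are the argmax along that dim,
# mirroring the TensorRT TopK MAX layer whose index output the converter keeps
values, indices = torch.topk(x, k=1, dim=2)
assert torch.equal(indices.squeeze(2), torch.argmax(x, dim=2))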
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestArgmaxConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            # input dimension == 1
+            ("dim_1_keep_dim_true", (3,), 0, True),
+            ("dim_1_keep_dim_false", (3,), 0, False),
+            # dim == None
+            ("dim_none", (3,), None, True),
+            ("dim_none", (3, 3), None, True),
+            ("dim_none", (3, 3, 3), None, False),
+            # common cases
+            ("dim_1_keep_dim_true", (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (4, 4, 4), 0, True),
+            ("dim_0_keep_dim_false", (4, 4, 4), 0, False),
+            ("dim_negative_keep_dim_true", (1, 2, 3), -1, True),
+        ]
+    )
+    def test_argmax(self, _, input_shape, dim, keep_dim):
+        class ArgMax(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.argmax.default(input, dim, keep_dim)
+
+        input = [torch.randn(*input_shape)]
+
+        self.run_test(ArgMax(), input)
+
+
+if __name__ == "__main__":
+    run_tests()
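For reference (not part of the test file), the op-level behavior exercised by one of the parameterized cases can be checked directly in eager mode; the concrete tensor below is illustrative.

import torch

# Eager-mode mirror of the ("dim_1_keep_dim_true", (3, 3), 1, True) case
x = torch.randn(3, 3)
out = torch.ops.aten.argmax.default(x, 1, True)
assert torch.equal(out, torch.argmax(x, dim=1, keepdim=True))
assert out.shape == (3, 1)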
