Skip to content

Commit 3c9b77f

Browse files
committed
removing the fuse distributed ops lowering pass for tegra platforms
1 parent 26ea41e commit 3c9b77f

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,22 @@
1515
from .repair_input_as_output import repair_input_as_output
1616
from .replace_max_pool_with_indices import replace_max_pool_with_indices
1717

18-
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
19-
[
20-
remove_input_alias_fixing_clones,
21-
constant_fold,
22-
repair_input_as_output,
23-
fuse_prims_broadcast,
24-
fuse_distributed_ops,
25-
replace_max_pool_with_indices,
26-
remove_assert_nodes,
27-
accumulate_fp32_matmul,
28-
]
29-
)
18+
pass_list = [
19+
remove_input_alias_fixing_clones,
20+
constant_fold,
21+
repair_input_as_output,
22+
fuse_prims_broadcast,
23+
replace_max_pool_with_indices,
24+
lower_scaled_dot_product_attention,
25+
view_to_reshape,
26+
remove_assert_nodes,
27+
accumulate_fp32_matmul,
28+
]
29+
30+
if torch.cuda.get_device_capability() not in [(8, 7), (7, 2)]:
31+
pass_list.append(fuse_distributed_ops)
32+
33+
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)
3034

3135
ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
3236
[

0 commit comments

Comments
 (0)