Skip to content

Commit 0f95b96

Browse files
authored
removing the fuse distributed ops lowering pass for tegra platforms (#3411)
1 parent be652e9 commit 0f95b96

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
from torch_tensorrt.dynamo._settings import CompilationSettings
6+
from torch_tensorrt.dynamo.utils import is_tegra_platform
67

78
from .accumulate_fp32_matmul import accumulate_fp32_matmul
89
from .constant_folding import constant_fold
@@ -15,18 +16,20 @@
1516
from .repair_input_as_output import repair_input_as_output
1617
from .replace_max_pool_with_indices import replace_max_pool_with_indices
1718

18-
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
19-
[
20-
remove_input_alias_fixing_clones,
21-
constant_fold,
22-
repair_input_as_output,
23-
fuse_prims_broadcast,
24-
fuse_distributed_ops,
25-
replace_max_pool_with_indices,
26-
remove_assert_nodes,
27-
accumulate_fp32_matmul,
28-
]
29-
)
19+
pass_list = [
20+
remove_input_alias_fixing_clones,
21+
constant_fold,
22+
repair_input_as_output,
23+
fuse_prims_broadcast,
24+
replace_max_pool_with_indices,
25+
remove_assert_nodes,
26+
accumulate_fp32_matmul,
27+
]
28+
29+
if not is_tegra_platform():
30+
pass_list.append(fuse_distributed_ops)
31+
32+
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)
3033

3134
ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
3235
[

py/torch_tensorrt/dynamo/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,3 +806,9 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
806806
f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node"
807807
)
808808
return output_dtypes
809+
810+
811+
def is_tegra_platform() -> bool:
812+
if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
813+
return True
814+
return False

0 commit comments

Comments
 (0)