Skip to content

Commit 26f33c7

Browse files
committed
utility function to detect tegra platform
1 parent 3c9b77f commit 26f33c7

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
from torch_tensorrt.dynamo._settings import CompilationSettings
6+
from torch_tensorrt.dynamo.utils import is_tegra_platform
67

78
from .accumulate_fp32_matmul import accumulate_fp32_matmul
89
from .constant_folding import constant_fold
@@ -27,7 +28,7 @@
2728
accumulate_fp32_matmul,
2829
]
2930

30-
if torch.cuda.get_device_capability() not in [(8, 7), (7, 2)]:
31+
if not is_tegra_platform():
3132
pass_list.append(fuse_distributed_ops)
3233

3334
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)

py/torch_tensorrt/dynamo/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,3 +806,9 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
806806
f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node"
807807
)
808808
return output_dtypes
809+
810+
811+
def is_tegra_platform() -> bool:
812+
if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
813+
return True
814+
return False

0 commit comments

Comments
 (0)