diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 676e6e1175..b66f36c11e 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -3,6 +3,7 @@
 
 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.utils import is_tegra_platform
 
 from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .constant_folding import constant_fold
@@ -15,18 +16,20 @@
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
 
-ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
-    [
-        remove_input_alias_fixing_clones,
-        constant_fold,
-        repair_input_as_output,
-        fuse_prims_broadcast,
-        fuse_distributed_ops,
-        replace_max_pool_with_indices,
-        remove_assert_nodes,
-        accumulate_fp32_matmul,
-    ]
-)
+pass_list = [
+    remove_input_alias_fixing_clones,
+    constant_fold,
+    repair_input_as_output,
+    fuse_prims_broadcast,
+    replace_max_pool_with_indices,
+    remove_assert_nodes,
+    accumulate_fp32_matmul,
+]
+
+if not is_tegra_platform():
+    pass_list.append(fuse_distributed_ops)
+
+ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)
 
 ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 557c01667f..e4018ae95c 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -806,3 +806,9 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
             f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node"
         )
     return output_dtypes
+
+
+def is_tegra_platform() -> bool:
+    if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
+        return True
+    return False
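
Note on the new helper: `is_tegra_platform()` calls `torch.cuda.get_device_capability()` unconditionally, and since the pass list is assembled at module scope, the check runs at import time. `get_device_capability()` raises when no CUDA device is visible (e.g. CPU-only builds or driverless CI). Below is a minimal defensive sketch for illustration only; the name `is_tegra_platform_safe` and the string pass names are assumptions, not part of this patch:

```python
import torch

# SM 7.2 (Xavier) and SM 8.7 (Orin) are the Tegra compute
# capabilities the patch checks for.
TEGRA_COMPUTE_CAPABILITIES = [(7, 2), (8, 7)]


def is_tegra_platform_safe() -> bool:
    """Return True on Tegra (Jetson) devices, False elsewhere,
    including environments with no visible CUDA device."""
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() in TEGRA_COMPUTE_CAPABILITIES


# Mirrors the gating in _aten_lowering_pass.py, with pass names
# abbreviated to strings so the sketch runs standalone.
pass_list = ["constant_fold", "remove_assert_nodes"]
if not is_tegra_platform_safe():
    # Presumably skipped on Tegra because multi-GPU distributed
    # collectives do not apply to single-SoC Jetson devices.
    pass_list.append("fuse_distributed_ops")

print(pass_list)
```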