diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
index 2a252bd965..6c12b80f8b 100644
--- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
+++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
@@ -9,7 +9,7 @@
 import torch._dynamo as torchdynamo
 from torch.fx.passes.infra.pass_base import PassResult
 
-
+from torch_tensorrt.fx.utils import req_torch_version
 from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
     compose_bmm,
     compose_chunk,
@@ -91,6 +91,7 @@ def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None,
         sys.setrecursionlimit(default)
 
 
+@req_torch_version("2.0")
 def dynamo_trace(
     f: Callable[..., Value],
     # pyre-ignore
@@ -104,11 +105,6 @@ def dynamo_trace(
     this config option alltogether. For now, it helps with quick
     experiments with playing around with TorchDynamo
     """
-    if torch.__version__.startswith("1"):
-        raise ValueError(
-            f"The aten tracer requires Torch version >= 2.0. Detected version {torch.__version__}"
-        )
-
     if dynamo_config is None:
         dynamo_config = DynamoConfig()
     with using_config(dynamo_config), setting_python_recursive_limit(2000):
@@ -131,11 +127,13 @@ def dynamo_trace(
         ) from exc
 
 
+@req_torch_version("2.0")
 def trace(f, args, *rest):
     graph_module, guards = dynamo_trace(f, args, True, "symbolic")
     return graph_module, guards
 
 
+@req_torch_version("2.0")
 def opt_trace(f, args, *rest):
     """
     Optimized trace with necessary passes which re-compose some ops or replace some ops
diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py
index 1055621ce5..79779f604e 100644
--- a/py/torch_tensorrt/fx/utils.py
+++ b/py/torch_tensorrt/fx/utils.py
@@ -1,5 +1,6 @@
 from enum import Enum
-from typing import List
+from typing import List, Callable
+from packaging import version
 
 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
@@ -104,3 +105,36 @@ def f(*inp):
     mod = run_const_fold(mod)
     mod = replace_op_with_indices(mod)
     return mod
+
+
+def req_torch_version(min_torch_version: str = "2.dev"):
+    """
+    Create a decorator which verifies the Torch version installed
+    against a specified version range
+
+    Args:
+        min_torch_version (str): The minimum required Torch version
+        for the decorated function to work properly
+
+    Returns:
+        A decorator which raises a descriptive error message if
+        an unsupported Torch version is used
+    """
+
+    def nested_decorator(f: Callable):
+        def function_wrapper(*args, **kwargs):
+            # Parse minimum and current Torch versions
+            min_version = version.parse(min_torch_version)
+            current_version = version.parse(torch.__version__)
+
+            if current_version < min_version:
+                raise AssertionError(
+                    f"Expected Torch version {min_torch_version} or greater, "
+                    + f"when calling {f}. Detected version {torch.__version__}"
+                )
+            else:
+                return f(*args, **kwargs)
+
+        return function_wrapper
+
+    return nested_decorator
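
Usage sketch (not part of the patch): the decorator replaces the inline version guard previously embedded in dynamo_trace and extends the same check to trace and opt_trace. Because the check runs inside function_wrapper, it fires at call time rather than import time, so modules using the decorator still import cleanly on Torch 1.x. The function name sample_aten_pass below is hypothetical, for illustration only.

    from torch_tensorrt.fx.utils import req_torch_version

    # Hypothetical function gated on Torch >= 2.0; the version check is
    # deferred until the function is actually called.
    @req_torch_version("2.0")
    def sample_aten_pass(x):
        return x

    # Under a 1.x install (e.g. 1.13.1), calling it raises:
    # AssertionError: Expected Torch version 2.0 or greater, when calling
    # <function sample_aten_pass at 0x...>. Detected version 1.13.1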