
feat: Add decorator utility to improve error messaging for legacy support #1738

Merged (1 commit) on Mar 17, 2023
py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py (4 additions, 6 deletions)

@@ -9,7 +9,7 @@
 import torch._dynamo as torchdynamo

 from torch.fx.passes.infra.pass_base import PassResult
-
+from torch_tensorrt.fx.utils import req_torch_version
 from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
     compose_bmm,
     compose_chunk,
@@ -91,6 +91,7 @@ def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None,
     sys.setrecursionlimit(default)


+@req_torch_version("2.0")
 def dynamo_trace(
     f: Callable[..., Value],
     # pyre-ignore
@@ -104,11 +105,6 @@ def dynamo_trace(
     this config option alltogether. For now, it helps with quick
     experiments with playing around with TorchDynamo
     """
-    if torch.__version__.startswith("1"):
-        raise ValueError(
-            f"The aten tracer requires Torch version >= 2.0. Detected version {torch.__version__}"
-        )
-
     if dynamo_config is None:
         dynamo_config = DynamoConfig()
     with using_config(dynamo_config), setting_python_recursive_limit(2000):
@@ -131,11 +127,13 @@ def dynamo_trace(
         ) from exc


+@req_torch_version("2.0")
 def trace(f, args, *rest):
     graph_module, guards = dynamo_trace(f, args, True, "symbolic")
     return graph_module, guards


+@req_torch_version("2.0")
 def opt_trace(f, args, *rest):
     """
     Optimized trace with necessary passes which re-compose some ops or replace some ops
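
With this change, the Torch version check that dynamo_trace previously performed inline is handled by the req_torch_version decorator (defined in utils.py below), and trace and opt_trace are guarded as well. A minimal sketch of the new failure mode, not part of the PR: the example function is hypothetical, and it assumes a Torch 1.x environment and that args is a tuple of example inputs per the trace(f, args, *rest) signature.

    import torch
    from torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer import trace

    def add_one(x):
        return x + 1

    # Under Torch 1.x, the decorator raises AssertionError before any tracing
    # begins; previously only dynamo_trace raised an inline ValueError.
    try:
        graph_module, guards = trace(add_one, (torch.ones(2),))
    except AssertionError as exc:
        print(exc)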

py/torch_tensorrt/fx/utils.py (35 additions, 1 deletion)

@@ -1,5 +1,6 @@
 from enum import Enum
-from typing import List
+from typing import List, Callable
+from packaging import version

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
@@ -104,3 +105,36 @@ def f(*inp):
     mod = run_const_fold(mod)
     mod = replace_op_with_indices(mod)
     return mod
+
+
+def req_torch_version(min_torch_version: str = "2.dev"):

[Inline review comment by the Collaborator (Author), on the min_torch_version default]: Default is "2.dev" since "2.0" is considered to be a larger version than "2.0.0.dev____".

"""
Create a decorator which verifies the Torch version installed
against a specified version range

Args:
min_torch_version (str): The minimum required Torch version
for the decorated function to work properly

Returns:
A decorator which raises a descriptive error message if
an unsupported Torch version is used
"""

def nested_decorator(f: Callable):
def function_wrapper(*args, **kwargs):
# Parse minimum and current Torch versions
min_version = version.parse(min_torch_version)
current_version = version.parse(torch.__version__)

if current_version < min_version:
raise AssertionError(
f"Expected Torch version {min_torch_version} or greater, "
+ f"when calling {f}. Detected version {torch.__version__}"
)

[Inline review comment by the Collaborator (Author), on lines +131 to +134]: The function f is printed in the error message, which includes its name.

+            else:
+                return f(*args, **kwargs)
+
+        return function_wrapper
+
+    return nested_decorator
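
Taken together with the review comments above, here is a minimal, self-contained sketch of how req_torch_version behaves; it is not part of the PR, the decorated function is hypothetical, and the nightly version string is illustrative.

    import torch
    from packaging import version
    from torch_tensorrt.fx.utils import req_torch_version

    # Why the "2.dev" default: pre-releases compare below their final release,
    # so a "2.0" minimum would reject nightly builds, while "2.dev" admits them.
    assert version.parse("2.0") > version.parse("2.0.0.dev20230301")
    assert version.parse("2.dev") < version.parse("2.0.0.dev20230301")

    @req_torch_version("2.0")
    def my_pass(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    try:
        my_pass(torch.ones(3))
    except AssertionError as exc:
        # On Torch 1.x this prints the decorated function's repr, e.g.:
        #   Expected Torch version 2.0 or greater, when calling
        #   <function my_pass at 0x...>. Detected version 1.13.1
        print(exc)

On a Torch 2.x install, my_pass runs normally and no exception is raised.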