feat: Add dynamic shapes support for torch.compile workflow #2627

Merged: 63 commits, merged Apr 16, 2024

Commits
d06c74a
chore: Switch to new export apis
peri044 Oct 3, 2023
47e0997
chore: rebase with main
peri044 Oct 9, 2023
fd29fe0
Merge branch 'main' into export_2.2
peri044 Oct 11, 2023
ad3b031
feat: Add support for dynamic shapes and remove constraints API
peri044 Oct 19, 2023
1582b72
chore: add dynamic shape support for certain converters
peri044 Oct 23, 2023
4d01545
chore: minor updates
peri044 Oct 25, 2023
6731a57
chore: updates
peri044 Oct 26, 2023
a8a194b
chore: rebase with main
peri044 Nov 15, 2023
0b60aae
chore: add sym int converter
peri044 Nov 15, 2023
634612f
feat: Replace the existing shape propagation with symbolic shape prop…
peri044 Nov 16, 2023
93edba4
chore: fix imports
peri044 Nov 16, 2023
7ad9272
chore: fix imports
peri044 Nov 16, 2023
f444d54
chore: updates
peri044 Nov 21, 2023
6e5c582
chore: change device calls
peri044 Nov 28, 2023
83791f8
chore: fix metadata check
peri044 Dec 5, 2023
8375996
chore: rebase with main
peri044 Dec 15, 2023
aba91fa
Merge branch 'main' into dyn_2.2
peri044 Dec 22, 2023
16394d9
chore: minor fixes
peri044 Jan 7, 2024
b9a7ccd
chore: Add sym_size converter tests
peri044 Jan 8, 2024
15cc643
chore: Update test utilities
peri044 Jan 8, 2024
5234d74
chore: add testcase for sym_size.int
peri044 Jan 8, 2024
fd2dae1
Merge branch 'main' into dyn_2.2
peri044 Jan 26, 2024
51e8bb7
chore: revert output type change
peri044 Jan 26, 2024
19c3fad
chore: add update_metadata utility
peri044 Jan 27, 2024
ed48551
chore: change debug to warning if the graph does not have metadata
peri044 Jan 27, 2024
18b7e11
feat: add lowering passes to support dynamic shapes for torch.compile
peri044 Jan 30, 2024
3a39d27
chore: add test case
peri044 Jan 30, 2024
abb2677
chore: add view test case
peri044 Feb 2, 2024
9aff04b
chore: gpt2 changes + linting
peri044 Feb 7, 2024
440fcd5
chore: gpt2 changes + linting
peri044 Feb 7, 2024
a2d38f3
chore: rebase with main
peri044 Feb 7, 2024
002db3c
chore: add fallback option if val is missing in metadata
peri044 Feb 7, 2024
00cd17b
chore: tmp changes
peri044 Feb 13, 2024
6ac70cd
chore: tmp changes
peri044 Feb 13, 2024
b827070
Merge branch 'main' into dyn_2.2
peri044 Feb 16, 2024
8f9bca0
Merge branch 'main' into dyn_2.2
peri044 Feb 21, 2024
4399d57
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Feb 21, 2024
39615a2
chore: fixes
peri044 Feb 26, 2024
cd86660
feat: Add save API for torch-trt compiled models
peri044 Mar 14, 2024
3ece71b
chore: resolve merge conflicts
peri044 Mar 15, 2024
1fa1771
Merge branch 'main' into dyn_2.2
peri044 Mar 15, 2024
febf05b
Merge branch 'save' into dyn_2.2
peri044 Mar 15, 2024
eab0dba
chore: Fix save failures
peri044 Mar 18, 2024
b191d62
chore: update to 2.3 rc build
peri044 Mar 18, 2024
380477b
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Mar 19, 2024
5f34d4f
chore: minor fixes
peri044 Mar 19, 2024
ce606fe
chore: rebase with release/2.3 branch
peri044 Mar 19, 2024
8674a3c
chore: minor fixes
peri044 Mar 19, 2024
f4e8fe9
chore: remove duplicate bert test case
peri044 Mar 20, 2024
4ae6ab9
chore: remove comments
peri044 Mar 20, 2024
c14f28d
Merge branch 'save' into dyn_2.2
peri044 Mar 20, 2024
3295c02
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Mar 20, 2024
4188173
chore: rebase with release/2.3
peri044 Apr 2, 2024
f6b758e
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Apr 2, 2024
78f7eb5
chore: updates
peri044 Apr 2, 2024
fe13c2a
chore: Update mypy type for sample_inputs
peri044 Apr 2, 2024
e9b649d
chore: revert changes
peri044 Apr 5, 2024
03ecc61
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Apr 5, 2024
978c039
Merge branch 'release/2.3' into dyn_2.2
peri044 Apr 5, 2024
ccb88c8
Merge branch 'dyn_2.2' into dyn_2.2_tc
peri044 Apr 5, 2024
3cccf8a
chore: rebase
peri044 Apr 15, 2024
2d24686
chore: update to use test channel
peri044 Apr 15, 2024
8e36525
chore: updates
peri044 Apr 16, 2024
2 changes: 2 additions & 0 deletions .github/workflows/build-test.yml
@@ -21,6 +21,7 @@ jobs:
       os: linux
       test-infra-repository: pytorch/test-infra
       test-infra-ref: main
+      channel: test
       with-rocm: false
       with-cpu: false
@@ -197,6 +198,7 @@ jobs:
     ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
     ${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
     ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
+    ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py
     popd

  tests-py-dynamo-core:
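To mirror the new CI step locally, a sketch using pytest's Python entry point (run from the repository's dynamo test directory — an assumption, since the pushd target sits outside this hunk; the --ir option is the custom flag the CI invocation above also uses):

import pytest

# Equivalent of the new CI invocation, minus the JUnit XML report path.
pytest.main(["--ir", "torch_compile", "models/test_dyn_models.py"])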
12 changes: 5 additions & 7 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -317,14 +317,12 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 return False
         return True

-    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+    # Check if the module has metadata (shape, dtype).
     if not contains_metadata(gm):
-        from torch._inductor.compile_fx import fake_tensor_prop
-
-        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
-        with torch.no_grad():
-            # This fails if the module has data-dependent shape operators.
-            fake_tensor_prop(gm, torch_inputs)
+        # TODO: For future, explore when nodes don't have metadata and if fake_tensor_prop can resolve this.
+        logger.warning(
+            "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
+        )

     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
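For context, the metadata being checked is the "val" entry that FX tracing records in each node's node.meta (a fake tensor carrying shape and dtype). A minimal sketch of a check along the lines of contains_metadata above — the helper name is illustrative, not from the PR:

from typing import List

import torch

def nodes_missing_metadata(gm: torch.fx.GraphModule) -> List[torch.fx.Node]:
    # Collect call_function nodes whose FX metadata lacks the "val" entry,
    # i.e. nodes with no recorded shape/dtype information.
    return [
        node
        for node in gm.graph.nodes
        if node.op == "call_function" and "val" not in node.meta
    ]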
20 changes: 14 additions & 6 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -13,6 +13,7 @@
 from torch_tensorrt.dynamo.lowering import (
     apply_lowering_passes,
     get_decompositions,
+    remove_sym_nodes,
     repair_input_aliasing,
 )
 from torch_tensorrt.dynamo.utils import (
@@ -27,7 +28,7 @@
 @td.register_backend(name="tensorrt")  # type: ignore[misc]
 @td.register_backend(name="torch_tensorrt")  # type: ignore[misc]
 def torch_tensorrt_backend(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     # Set log level at the top of compilation (torch_tensorrt.dynamo)
     if (
@@ -44,15 +45,15 @@ def torch_tensorrt_backend(

 @td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
 def aot_torch_tensorrt_aten_backend(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)
     return _pretraced_backend(gm, sample_inputs, settings)


 def _pretraced_backend(
     gm: torch.fx.GraphModule,
-    sample_inputs: Sequence[torch.Tensor],
+    sample_inputs: Sequence[Any],
     settings: CompilationSettings = CompilationSettings(),
 ) -> torch.fx.GraphModule | Callable[..., Any]:
     """Helper function to manage translation of traced FX module to TRT engines
@@ -74,10 +75,17 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             repair_input_aliasing(gm)
+
+            # Remove sym_int placeholders and inputs
+            remove_sym_nodes(gm)
+            torch_inputs = [
+                input for input in sample_inputs if isinstance(input, torch.Tensor)
+            ]
+
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,
-                sample_inputs,
+                torch_inputs,
                 trace_joint=False,
                 decompositions=get_decompositions(
                     settings.enable_experimental_decompositions
@@ -86,10 +94,10 @@

             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

-            gm = apply_lowering_passes(gm, sample_inputs)
+            gm = apply_lowering_passes(gm, torch_inputs)

             torchtrt_inputs = prepare_inputs(
-                sample_inputs, disable_memory_format_check=True
+                torch_inputs, disable_memory_format_check=True
             )
             trt_compiled = compile_module(
                 gm,
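For illustration, this is the workflow the backend change enables. A minimal sketch, assuming a CUDA device; TinyModel is a hypothetical stand-in, while the backend name comes from the register_backend calls above and dynamic=True is standard torch.compile usage:

import torch

class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(64, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))

model = TinyModel().eval().cuda()

# With dynamic=True, torch.compile hands the backend SymInt placeholders
# alongside the Tensor inputs; the backend strips them (remove_sym_nodes)
# and filters sample_inputs down to Tensors before aot_export_joint_simple.
compiled = torch.compile(model, backend="torch_tensorrt", dynamic=True)

out1 = compiled(torch.randn(8, 64, device="cuda"))
out2 = compiled(torch.randn(4, 64, device="cuda"))  # new batch size, no retrace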
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -11,7 +11,6 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     get_positive_dim,
     get_trt_tensor,
-    to_numpy,
 )
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
@@ -87,8 +86,9 @@ def get_shape_with_dynamic_shape(
     scale_res = scale_layer.get_output(0)

     length = input_shape.shape[0]
+
     zero_layer = ctx.net.add_constant(
-        input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32))
+        input_shape.shape, np.zeros((length), dtype=np.int32)
     )
     set_layer_name(zero_layer, target, f"{name}_zeros")
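The swap is behavior-preserving — both paths hand add_constant the same int32 zero buffer — while dropping a needless torch-to-numpy round trip. A quick sketch of the equivalence:

import numpy as np
import torch

length = 4
new_buf = np.zeros((length), dtype=np.int32)                # new path
old_buf = torch.zeros((length), dtype=torch.int32).numpy()  # roughly the old path
assert (new_buf == old_buf).all() and new_buf.dtype == old_buf.dtype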
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -3,6 +3,6 @@
     torch_enabled_decompositions,
 )
 from ._decompositions import get_decompositions  # noqa: F401
-from ._fusers import *  # noqa: F401
+from ._remove_sym_nodes import remove_sym_nodes
 from ._repair_input_aliasing import repair_input_aliasing
 from .passes import apply_lowering_passes
82 changes: 0 additions & 82 deletions py/torch_tensorrt/dynamo/lowering/_fusers.py

This file was deleted.

30 changes: 30 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py
@@ -0,0 +1,30 @@
import logging

import torch

logger = logging.getLogger(__name__)


def remove_sym_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Remove sym_int placeholders which get inserted due to torch.compile's
    dynamic=True behavior
    """
    # Extract SymInt placeholder Tensors
    placeholders = [
        node
        for node in gm.graph.nodes
        if (
            node.op == "placeholder"
            and isinstance(node.type, type)
            and issubclass(node.type, torch.SymInt)
        )
    ]

    for node in placeholders:
        gm.graph.erase_node(node)

    gm.graph.lint()
    gm.recompile()
    logger.debug(f"Removed SymInt placeholders:\n{gm.graph}")

    return gm
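A sketch of how this pass pairs with the input filtering done in _pretraced_backend above (drop_sym_inputs is a hypothetical helper; the import path matches the lowering/__init__.py change in this PR):

from typing import Any, List, Sequence, Tuple

import torch
from torch_tensorrt.dynamo.lowering import remove_sym_nodes

def drop_sym_inputs(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any]
) -> Tuple[torch.fx.GraphModule, List[torch.Tensor]]:
    # Erase SymInt placeholders from the graph, then keep only the Tensor
    # sample inputs so graph inputs and sample inputs stay aligned.
    gm = remove_sym_nodes(gm)
    tensor_inputs = [i for i in sample_inputs if isinstance(i, torch.Tensor)]
    return gm, tensor_inputs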