
feat: cherry-pick of Implement symbolic shape propagation, sym_size converter #2751


Merged: 7 commits merged on Apr 26, 2024
27 changes: 20 additions & 7 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -163,10 +163,10 @@ def compile(
)
gm = exported_program.module()
logger.debug("Input graph: " + str(gm.graph))

# Apply lowering on the graph module
torch_inputs = get_torch_inputs(inputs, device)
gm = apply_lowering_passes(gm, torch_inputs)

logger.debug("Lowered Input graph: " + str(gm.graph))

compilation_options = {
@@ -264,6 +264,24 @@ def compile_module(
f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
)

def contains_metadata(gm: torch.fx.GraphModule) -> bool:
for node in gm.graph.nodes:
if node.op != "output" and (not node.meta) and "val" not in node.meta:
logger.warning(
f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
)
return False
return True

# Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
if not contains_metadata(gm):
from torch._inductor.compile_fx import fake_tensor_prop

torch_inputs = get_torch_inputs(sample_inputs, settings.device)
with torch.no_grad():
# This fails if the module has data-dependent shape operators.
fake_tensor_prop(gm, torch_inputs)

# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False
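
For readers unfamiliar with the torch-internal helper imported above, here is a minimal, hypothetical sketch (not part of this PR) of what fake_tensor_prop provides for the contains_metadata check: the graph is executed with FakeTensors so every node records a shape/dtype entry in node.meta["val"]. The toy function and input shape are illustrative only.

```python
# Hypothetical sketch: populating node.meta["val"] via fake-tensor propagation.
import torch
from torch._inductor.compile_fx import fake_tensor_prop


def toy(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(toy)

# Freshly traced graphs typically carry no "val" metadata yet.
missing = [n.name for n in gm.graph.nodes if n.op != "output" and "val" not in n.meta]
print("nodes missing metadata:", missing)

with torch.no_grad():
    # Runs the graph with FakeTensors; fails on data-dependent shape operators.
    fake_tensor_prop(gm, [torch.randn(2, 3)])

for node in gm.graph.nodes:
    if node.op != "output":
        print(node.name, node.meta.get("val"))
```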

@@ -322,12 +340,7 @@ def compile_module(
)

# Get the submodule inputs for min, opt, max shapes of the graph inputs
submodule_inputs = partitioning.get_submod_inputs(
partitioned_module,
submodule,
sample_inputs,
to_torch_device(settings.device),
)
submodule_inputs = partitioning.construct_submodule_inputs(submodule)

logger.debug(
"Submodule name: %s\n Input shapes: %s\n %s",
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -74,7 +74,6 @@ def _pretraced_backend(
fake_mode, "allow_non_fake_inputs", True
), fake_mode:
repair_input_aliasing(gm)

# Invoke AOTAutograd to translate operators to aten
gm = aot_export_joint_simple(
gm,
16 changes: 13 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -38,12 +38,22 @@ def infer_module_output_dtypes(
# such as aten.sum - such outputs can be truncated
output_dtypes = []
for output in module_outputs:
if truncate_long_and_double and output.dtype == dtype.float64:
output_ = output
# We don't need to check if output is nested here because the input module will be flattened
if not isinstance(output, torch.Tensor):
if isinstance(output, str):
raise ValueError(
f"Receieved an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
)
else:
output_ = torch.tensor(output)

if truncate_long_and_double and output_.dtype == dtype.float64:
output_dtypes.append(dtype.float32)
elif truncate_long_and_double and output.dtype == dtype.int64:
elif truncate_long_and_double and output_.dtype == dtype.int64:
output_dtypes.append(dtype.int32)
else:
output_dtypes.append(dtype._from(output.dtype))
output_dtypes.append(dtype._from(output_.dtype))

return output_dtypes

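A hedged, standalone sketch (not part of this diff) of the non-Tensor handling added above: scalar outputs such as the Python ints produced by sym_size are wrapped in a tensor so a dtype can be read and, if requested, truncated from 64-bit to 32-bit. It assumes torch_tensorrt is installed; the helper name infer_dtype is illustrative only.

```python
import torch
from torch_tensorrt import dtype


def infer_dtype(output, truncate_long_and_double: bool = True) -> dtype:
    # Wrap non-Tensor outputs (e.g. Python ints from sym_size) so a dtype exists.
    out = output if isinstance(output, torch.Tensor) else torch.tensor(output)
    if truncate_long_and_double and out.dtype == torch.float64:
        return dtype.float32
    if truncate_long_and_double and out.dtype == torch.int64:
        return dtype.int32
    return dtype._from(out.dtype)


print(infer_dtype(torch.randn(2, 2)))  # tensor output: float32 stays float32
print(infer_dtype(7))                  # Python int -> int64 -> truncated to int32
```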
16 changes: 16 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -394,6 +394,22 @@ def aten_ops_sigmoid(
)


@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
def aten_ops_symsize_int(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])


def index_dtype_validator(node: Node) -> bool:
index = node.args[1]
for ind in index:
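
A hypothetical example (not part of this diff) of where aten.sym_size.int appears: exporting a module with a dynamic batch dimension turns x.shape[0] into a sym_size.int node, which the converter above lowers via impl.shape. The module, dimension name, and shapes below are illustrative only.

```python
import torch


class Reflatten(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.shape[0]  # becomes torch.ops.aten.sym_size.int(x, 0) under dynamic shapes
        return x.reshape(batch, -1)


ep = torch.export.export(
    Reflatten(),
    (torch.randn(4, 3, 8),),
    dynamic_shapes={"x": {0: torch.export.Dim("batch", min=2, max=32)}},
)
print(ep.graph)  # look for torch.ops.aten.sym_size.int in the printed graph
```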
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/grid.py
@@ -1,13 +1,12 @@
from typing import Optional

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor

import tensorrt as trt

# nearest, linear, cubic
GridSamplerInterpolationMode = {
0: trt.InterpolationMode.NEAREST,
20 changes: 18 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -3,7 +3,7 @@
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor

@@ -17,7 +17,23 @@ def reshape(
shape: Sequence[int],
) -> TRTTensor:
layer = ctx.net.add_shuffle(input)
layer.reshape_dims = tuple(shape)
if all(isinstance(s, int) for s in shape):
layer.reshape_dims = tuple(shape)
else:
# Convert all the dimensions to trt Tensors.
trt_shape = []

for i, s in enumerate(shape):
if isinstance(s, TRTTensor):
trt_shape.append(s)
else:
a = get_trt_tensor(ctx, s, f"{name}_{i}")
trt_shape.append(a)
shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
shape_layer.axis = 0
shape_layer.name = f"{name}_output_shape"
layer.set_input(1, shape_layer.get_output(0))

set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)

Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -69,7 +69,6 @@ def expand(
) -> TRTTensor:
shape_rank = len(shape)
initial_tensor_rank = len(input_t.shape)

# If the rank of the input tensor is less than the shape's rank, pad with ones
if initial_tensor_rank < shape_rank:
input_t = prepend_ones(
@@ -99,6 +98,7 @@
stride = tuple(
[int(i == o) for i, o in zip(input_tensor_shape, shape)]
) # stride == 1 if dimensions match, 0 otherwise

layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)
23 changes: 22 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -1,4 +1,4 @@
from typing import List
from typing import Any, List

import torch

@@ -29,3 +29,24 @@ def get_tensor_placeholders(
]

return placeholders


def get_metadata(
gm: torch.fx.GraphModule, target_op: Any
) -> List[torch._ops.OpOverload]:
"""
Return a list of the metadata (node.meta) of all the target_op nodes present in the graph.
"""
return [node.meta for node in gm.graph.nodes if node.target == target_op]


def set_metadata(
gm: torch.fx.GraphModule, target_op: Any, metadata: List[torch._ops.OpOverload]
) -> None:
"""
Set the metadata of all the target_op nodes present in the graph.
"""
target_nodes = [node for node in gm.graph.nodes if node.target == target_op]
assert len(target_nodes) == len(metadata)
for idx, node in enumerate(target_nodes):
node.meta = metadata[idx]
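
A small, hypothetical illustration (not part of this diff) of the two helpers above: capture node.meta for every node matching a target op, then write it back after a graph modification. The traced lambda and the metadata payload are illustrative only.

```python
import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import get_metadata, set_metadata

gm = torch.fx.symbolic_trace(lambda x: torch.ops.aten.relu.default(x))
target = torch.ops.aten.relu.default

metadata = get_metadata(gm, target)   # one entry per relu node in the graph
metadata[0]["val"] = "example-entry"  # stand-in for real shape/dtype metadata
set_metadata(gm, target, metadata)    # copies the entries back onto the nodes

relu_node = next(n for n in gm.graph.nodes if n.target == target)
print(relu_node.meta)
```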
36 changes: 18 additions & 18 deletions py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -1,9 +1,11 @@
import logging
from typing import Callable, List, Sequence, Tuple
from typing import List, Sequence

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
get_metadata,
set_metadata,
)

logger = logging.getLogger(__name__)
@@ -13,27 +15,25 @@ def view_to_reshape(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
orig, replacement = view_replacement()

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

return gm


def view_replacement() -> Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
]:
"""Constructs the original and replacement functions for view"""
orig_op = torch.ops.aten.view.default
replacement_op = torch.ops.aten.reshape.default

# Original graph
def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.view.default(input, shape)
return orig_op(input, shape)

# Replacement graph
def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.reshape.default(input, shape)
return replacement_op(input, shape)

return orig, replacement
# Store metadata of the orig_op
metadata = get_metadata(gm, orig_op)

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

# Copy the orig_op's metadata to the replacement op
set_metadata(gm, replacement_op, metadata)

return gm
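
A hypothetical usage sketch (not part of this diff) of the rewritten pass: trace a graph containing aten.view.default, run view_to_reshape, and confirm the op was swapped for aten.reshape.default. It assumes torch_tensorrt is installed; the traced lambda and input shape are illustrative only.

```python
import torch
from torch_tensorrt.dynamo.lowering.passes.view_to_reshape import view_to_reshape

gm = torch.fx.symbolic_trace(
    lambda x: torch.ops.aten.view.default(x, [2, 12])
)
sample_inputs = [torch.randn(2, 3, 4)]
gm = view_to_reshape(gm, sample_inputs)

# After the pass, view should be gone and reshape should appear in its place.
targets = {n.target for n in gm.graph.nodes if n.op == "call_function"}
assert torch.ops.aten.reshape.default in targets
assert torch.ops.aten.view.default not in targets
```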
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -1,3 +1,7 @@
from ._adjacency_partitioner import partition as fast_partition
from ._global_partitioner import partition as global_partition
from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis
from .common import (
construct_submodule_inputs,
get_graph_converter_support,
run_shape_analysis,
)