diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index bf07cc73e7..5bcac8a237 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -163,10 +163,10 @@ def compile(
     )
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))
-
     # Apply lowering on the graph module
     torch_inputs = get_torch_inputs(inputs, device)
     gm = apply_lowering_passes(gm, torch_inputs)
+
     logger.debug("Lowered Input graph: " + str(gm.graph))

     compilation_options = {
@@ -264,6 +264,24 @@ def compile_module(
         f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
     )

+    def contains_metadata(gm: torch.fx.GraphModule) -> bool:
+        for node in gm.graph.nodes:
+            if node.op != "output" and (not node.meta) and "val" not in node.meta:
+                logger.warning(
+                    f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
+                )
+                return False
+        return True
+
+    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+    if not contains_metadata(gm):
+        from torch._inductor.compile_fx import fake_tensor_prop
+
+        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
+        with torch.no_grad():
+            # This fails if the module has data-dependent shape operators.
+            fake_tensor_prop(gm, torch_inputs)
+
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False

@@ -322,12 +340,7 @@ def compile_module(
         )

         # Get the submodule inputs for min, opt, max shapes of the graph inputs
-        submodule_inputs = partitioning.get_submod_inputs(
-            partitioned_module,
-            submodule,
-            sample_inputs,
-            to_torch_device(settings.device),
-        )
+        submodule_inputs = partitioning.construct_submodule_inputs(submodule)

         logger.debug(
             "Submodule name: %s\n Input shapes: %s\n %s",
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 1fa2806181..bade91c553 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -74,7 +74,6 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             repair_input_aliasing(gm)
-
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index 373c128920..93fc73b4e2 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -38,12 +38,22 @@ def infer_module_output_dtypes(
     # such as aten.sum - such outputs can be truncated
     output_dtypes = []
     for output in module_outputs:
-        if truncate_long_and_double and output.dtype == dtype.float64:
+        output_ = output
+        # We don't need to check if output is nested here because the input module will be flattened
+        if not isinstance(output, torch.Tensor):
+            if isinstance(output, str):
+                raise ValueError(
+                    f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
+                )
+            else:
+                output_ = torch.tensor(output)
+
+        if truncate_long_and_double and output_.dtype == dtype.float64:
             output_dtypes.append(dtype.float32)
-        elif truncate_long_and_double and output.dtype == dtype.int64:
+        elif truncate_long_and_double and output_.dtype == dtype.int64:
             output_dtypes.append(dtype.int32)
         else:
-            output_dtypes.append(dtype._from(output.dtype))
+            output_dtypes.append(dtype._from(output_.dtype))

     return output_dtypes

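
Illustrative sketch (not part of the patch): the truncation logic above, applied to a single output in isolation. Wrapping a non-tensor output in torch.tensor lets the same dtype checks cover scalar outputs such as the Python ints produced by aten.sym_size. infer_one_output_dtype is a hypothetical helper written only for this example, and the dtype import assumes the same public enum that _conversion.py uses.

import torch
from torch_tensorrt import dtype  # assumed to be the enum used by _conversion.py

def infer_one_output_dtype(output, truncate_long_and_double=True):
    # Mirror of the loop body above: promote scalars to tensors, then truncate wide types.
    output_ = output if isinstance(output, torch.Tensor) else torch.tensor(output)
    if truncate_long_and_double and output_.dtype == torch.float64:
        return dtype.float32
    if truncate_long_and_double and output_.dtype == torch.int64:
        return dtype.int32
    return dtype._from(output_.dtype)

print(infer_one_output_dtype(torch.randn(2, 2)))  # float32 passes through unchanged
print(infer_one_output_dtype(7))                  # Python int -> int64 -> truncated to int32
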
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 08827d8fa8..c566d9de0a 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -394,6 +394,22 @@ def aten_ops_sigmoid(
     )


+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
+def aten_ops_symsize_int(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 def index_dtype_validator(node: Node) -> bool:
     index = node.args[1]
     for ind in index:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/grid.py b/py/torch_tensorrt/dynamo/conversion/impl/grid.py
index 8981eca73c..63ff93b0c7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/grid.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/grid.py
@@ -1,13 +1,12 @@
 from typing import Optional

+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor

-import tensorrt as trt
-
 # nearest, linear, cubic
 GridSamplerInterpolationMode = {
     0: trt.InterpolationMode.NEAREST,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index 1d6dd7396f..b2a79af5cb 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -3,7 +3,7 @@
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor

@@ -17,7 +17,23 @@ def reshape(
     shape: Sequence[int],
 ) -> TRTTensor:
     layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = tuple(shape)
+    if all(isinstance(s, int) for s in shape):
+        layer.reshape_dims = tuple(shape)
+    else:
+        # Convert all the dimensions to trt Tensors.
+        trt_shape = []
+
+        for i, s in enumerate(shape):
+            if isinstance(s, TRTTensor):
+                trt_shape.append(s)
+            else:
+                a = get_trt_tensor(ctx, s, f"{name}_{i}")
+                trt_shape.append(a)
+        shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
+        shape_layer.axis = 0
+        shape_layer.name = f"{name}_output_shape"
+        layer.set_input(1, shape_layer.get_output(0))
+
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

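
Illustrative sketch (not part of the patch): a minimal module whose exported graph exercises the new aten.sym_size.int converter and the dynamic reshape path above. FlattenBatch and the Dim name are made up for this example; exporting with a dynamic batch dimension turns x.size(0) into sym_size.int nodes and hands view/reshape a shape containing a SymInt.

import torch

class FlattenBatch(torch.nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)  # x.size(0) lowers to aten.sym_size.int(x, 0)

batch = torch.export.Dim("batch", min=2, max=8)
exp_program = torch.export.export(
    FlattenBatch(), (torch.randn(4, 3, 4),), dynamic_shapes={"x": {0: batch}}
)
print(exp_program.graph)  # contains sym_size.int and a view whose target shape holds a SymInt
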
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 5f1db00f33..61d71fe9a0 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -69,7 +69,6 @@ def expand(
 ) -> TRTTensor:
     shape_rank = len(shape)
     initial_tensor_rank = len(input_t.shape)
-
     # If the rank of the input tensor is less than the shape's rank, pad with ones
     if initial_tensor_rank < shape_rank:
         input_t = prepend_ones(
@@ -99,6 +98,7 @@ def expand(
     stride = tuple(
         [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
+
     layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
index 31a55099c2..0ffc6d3c76 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List

 import torch

@@ -29,3 +29,24 @@ def get_tensor_placeholders(
     ]

     return placeholders
+
+
+def get_metadata(
+    gm: torch.fx.GraphModule, target_op: Any
+) -> List[torch._ops.OpOverload]:
+    """
+    Return the list which has the metadata of all the target_op nodes present in the graph.
+    """
+    return [node.meta for node in gm.graph.nodes if node.target == target_op]
+
+
+def set_metadata(
+    gm: torch.fx.GraphModule, target_op: Any, metadata: List[torch._ops.OpOverload]
+) -> None:
+    """
+    Set the metadata of all the target_op nodes present in the graph.
+    """
+    target_nodes = [node for node in gm.graph.nodes if node.target == target_op]
+    assert len(target_nodes) == len(metadata)
+    for idx, node in enumerate(target_nodes):
+        node.meta = metadata[idx]
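
Illustrative sketch (not part of the patch): a round trip through the new helpers. make_fx records node.meta["val"] while tracing; a graph rewrite can stash that metadata per target op with get_metadata and write it back with set_metadata.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch_tensorrt.dynamo.lowering.passes.pass_utils import get_metadata, set_metadata

def fn(x):
    return torch.ops.aten.add.Tensor(x, x)

gm = make_fx(fn, tracing_mode="fake")(torch.randn(2, 3))
saved = get_metadata(gm, torch.ops.aten.add.Tensor)  # one meta dict per matching node
set_metadata(gm, torch.ops.aten.add.Tensor, saved)   # writes it back unchanged here
print(saved[0]["val"].shape)  # FakeTensor shape recorded during tracing: torch.Size([2, 3])
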
+ """ + target_nodes = [node for node in gm.graph.nodes if node.target == target_op] + assert len(target_nodes) == len(metadata) + for idx, node in enumerate(target_nodes): + node.meta = metadata[idx] diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py index e2ef051f06..b2da354122 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py @@ -1,9 +1,11 @@ import logging -from typing import Callable, List, Sequence, Tuple +from typing import List, Sequence import torch from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, + get_metadata, + set_metadata, ) logger = logging.getLogger(__name__) @@ -13,27 +15,25 @@ def view_to_reshape( gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] ) -> torch.fx.GraphModule: """Replace aten.view with an equivalent implementation which avoids Tensor memory issues""" - orig, replacement = view_replacement() - - if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement): - gm = clean_up_graph_after_modifications(gm) - logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}") - - return gm - - -def view_replacement() -> Tuple[ - torch.fx.GraphModule, - Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor], -]: - """Constructs the original and replacement functions for view""" + orig_op = torch.ops.aten.view.default + replacement_op = torch.ops.aten.reshape.default # Original graph def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor: - return torch.ops.aten.view.default(input, shape) + return orig_op(input, shape) # Replacement graph def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor: - return torch.ops.aten.reshape.default(input, shape) + return replacement_op(input, shape) - return orig, replacement + # Store metadata of the orig_op + metadata = get_metadata(gm, orig_op) + + if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement): + gm = clean_up_graph_after_modifications(gm) + logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}") + + # Copy the orig_op's metadata to the replacement op + set_metadata(gm, replacement_op, metadata) + + return gm diff --git a/py/torch_tensorrt/dynamo/partitioning/__init__.py b/py/torch_tensorrt/dynamo/partitioning/__init__.py index 1a8cc94099..25487da065 100644 --- a/py/torch_tensorrt/dynamo/partitioning/__init__.py +++ b/py/torch_tensorrt/dynamo/partitioning/__init__.py @@ -1,3 +1,7 @@ from ._adjacency_partitioner import partition as fast_partition from ._global_partitioner import partition as global_partition -from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis +from .common import ( + construct_submodule_inputs, + get_graph_converter_support, + run_shape_analysis, +) diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py index 8348738afa..270973c8c3 100644 --- a/py/torch_tensorrt/dynamo/partitioning/common.py +++ b/py/torch_tensorrt/dynamo/partitioning/common.py @@ -4,11 +4,99 @@ import torch from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._defaults import DEBUG -from torch_tensorrt.dynamo.utils import get_torch_inputs, input_is_dynamic logger = logging.getLogger(__name__) +def contains_sym_int(tensor: torch.Tensor) -> bool: + """ + Returns true if the given tensor has symbolic shape. 
+ """ + for dim in tensor: + if isinstance(dim, torch.SymInt): + return True + return False + + +def construct_dynamic_input(input_shape: torch.Size, input_dtype: torch.dtype) -> Input: + """ + Constructs a torch_tensorrt.Input based on a symbolic input + Args: + input_shape: A symbolic shape / regular shape of a tensor (which can have a mix of SymInt nodes and static values) + Returns: + A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input. + """ + min_shape = [] + opt_shape = [] + max_shape = [] + for dim in input_shape: + if isinstance(dim, torch.SymInt): + node = dim.node + expr = node.expr + shape_env = node.shape_env + var_range = shape_env.var_to_range.get(expr, None) + var_val = shape_env.var_to_val.get(expr, None) + assert var_range, var_val + # Torchdynamo 0/1 specialization outlier + if var_range.lower == 2: + min_shape.append(1) + else: + min_shape.append(int(var_range.lower)) + opt_shape.append(int(var_val)) + max_shape.append(int(var_range.upper)) + else: + min_shape.append(dim) + opt_shape.append(dim) + max_shape.append(dim) + + return Input( + min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape, dtype=input_dtype + ) + + +def get_input(input_shape: torch.Size, input_dtype: torch.dtype) -> Input: + """ + Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs + """ + if contains_sym_int(input_shape): + return construct_dynamic_input(input_shape, input_dtype) + else: + return Input(shape=input_shape, dtype=input_dtype) + + +def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]: + """ + Construct torch_tensorrt Inputs based on the module inputs. + The module inputs will have meta data which has the shape and dtype info + Args: + module: Input FX GraphModule + Returns: + Sequence of torch_tensorrt.Input's representing inputs to given module + """ + torchtrt_inputs = [] + module_inputs = [node for node in module.graph.nodes if node.op == "placeholder"] + for input in module_inputs: + if input.meta: + if "val" in input.meta: + input_meta = input.meta["val"] + input_shape = input_meta.size() + torchtrt_inputs.append(get_input(input_shape, input_meta.dtype)) + elif "tensor_meta" in input.meta: + input_meta = input.meta["tensor_meta"] + input_shape = input_meta.shape + torchtrt_inputs.append(get_input(input_shape, input_meta.dtype)) + else: + raise AssertionError( + f"Input {input.name} does not contain val and tensor_meta fields in the metadata. Please ensure you have exported the graph correctly" + ) + else: + raise AssertionError( + f"Input {input.name} does not contain metadata. 
@@ -46,80 +134,6 @@ def get_submodule_io(

     return submod_inputs_shape_map, submod_outputs_shape_map


-def get_submod_inputs(
-    mod: torch.fx.GraphModule,
-    submod: torch.fx.GraphModule,
-    inputs: Sequence[Input],
-    device: torch.device,
-) -> Optional[Sequence[torch.Tensor]]:
-    """Helper function to get inputs to a Torch submodule
-
-    Args:
-        mod: Parent FX GraphModule
-        submod: Child FX GraphModule
-        inputs: Sample inputs to parent module
-    Returns:
-        Sequence of Tensors representing inputs to child module
-    """
-    acc_inputs: Any = None
-
-    def get_input(self: Any, inputs: Sequence[torch.Tensor]) -> None:
-        nonlocal acc_inputs
-        acc_inputs = inputs
-        return
-
-    # Register a hook to capture submodule input
-    handle = submod.register_forward_pre_hook(get_input)
-    # Iterate over min, opt, max shapes for dynamic inputs
-    inputs_map = {}
-
-    if input_is_dynamic(inputs):
-        for mode in ["min_shape", "opt_shape", "max_shape"]:
-            torch_inputs = get_torch_inputs(inputs, device, mode)
-            mod(*torch_inputs)
-            inputs_map[mode] = acc_inputs
-        handle.remove()
-    else:
-        torch_inputs = get_torch_inputs(inputs, device)
-        mod(*torch_inputs)
-        handle.remove()
-        assert isinstance(acc_inputs, tuple)
-        return [
-            Input(shape=acc_input.shape, dtype=acc_input.dtype)
-            for acc_input in acc_inputs
-        ]
-
-    num_submodule_inputs = (
-        len(inputs_map["min_shape"]) if inputs_map["min_shape"] else 0
-    )
-    submodule_inputs = []
-    for idx in range(num_submodule_inputs):
-        if not isinstance(inputs_map["min_shape"][idx], torch.Tensor):
-            input_val = torch.tensor(inputs_map["opt_shape"][idx], dtype=torch.int32)
-            logger.warning(
-                "Detected a zero-dimensional input. This might be a shape tensor input which is not currently supported. This might result in undefined behavior"
-            )
-            submodule_inputs.append(
-                Input(
-                    shape=[1],
-                    torch_tensor=input_val,
-                    dtype=input_val.dtype,
-                )
-            )
-        else:
-            submodule_inputs.append(
-                Input(
-                    min_shape=inputs_map["min_shape"][idx].shape,
-                    opt_shape=inputs_map["opt_shape"][idx].shape,
-                    max_shape=inputs_map["max_shape"][idx].shape,
-                    torch_tensor=inputs_map["opt_shape"][idx],
-                    dtype=inputs_map["opt_shape"][idx].dtype,
-                )
-            )
-
-    return submodule_inputs
-
-
 def get_graph_converter_support(
     graph_module: torch.fx.GraphModule,
     verbose: bool = DEBUG,
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 29b01990ce..6ea9503b84 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -88,7 +88,8 @@ def get_torch_inputs(
             if isinstance(input, Input)
         ]
     return [
-        input.torch_tensor.to(device) for input in inputs if isinstance(input, Input)
+        input.torch_tensor.to(device) if isinstance(input, Input) else input
+        for input in inputs
     ]

diff --git a/tests/py/dynamo/conversion/test_sym_size.py b/tests/py/dynamo/conversion/test_sym_size.py
new file mode 100644
index 0000000000..35bf75a509
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_sym_size.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestSymSizeConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4),),
+        ]
+    )
+    def test_sym_size_batch(self, input_shape):
+        class BatchDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sym_size.int(x, 0)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            BatchDim(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4),),
+        ]
+    )
+    def test_sym_size_non_batch(self, input_shape):
+        class NonBatchDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sym_size.int(x, 1)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            NonBatchDim(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index ceb4a6dd2c..822ee468a9 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -3,9 +3,8 @@
 import pytest
 import timm
 import torch
-from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
-
 import torch_tensorrt as torchtrt
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

 assertions = unittest.TestCase()

@@ -65,7 +64,7 @@ def forward(self, x):
 @pytest.mark.unit
 def test_base_dynamic_fallback(ir):
     """
-    Tests the model (which is fully convertible) with dynamic shapes
+    Tests the model with dynamic shapes where torch.abs op is forced to run in PyTorch
     """

     class MyModule(torch.nn.Module):
@@ -114,3 +113,53 @@ def forward(self, x):

     with torch.no_grad():
         torch.cuda.empty_cache()
+
+
+@pytest.mark.unit
+def test_view(ir):
+    """
+    Tests the model (which is fully convertible) with dynamic shapes
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            input_shape = x.size()
+            y = x.view(input_shape[0], -1)
+            return y
+
+    model = MyModule().eval().cuda()
+    input = torch.randn((6, 3, 4)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                min_shape=(1, 3, 4),
+                opt_shape=(4, 3, 4),
+                max_shape=(8, 3, 4),
+                dtype=torch.float32,
+                name="x",
+            )
+        ],
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 1,
+    }
+
+    trt_mod = torchtrt.compile(model, **compile_spec)
+    cos_sim = cosine_similarity(model(input), trt_mod(input))
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_view model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+    # Clean up model env
+    torch._dynamo.reset()
+
+    with torch.no_grad():
+        torch.cuda.empty_cache()
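
Illustrative follow-up (not part of the patch): once compiled against the dynamic Input above, the same TRT module should accept any batch size inside the declared [min, max] range. This sketch assumes the model and trt_mod objects from test_view; shapes outside (1..8, 3, 4) are expected to be rejected.

import torch
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

for batch in (1, 4, 8):
    x = torch.randn((batch, 3, 4)).cuda()
    # TRT and PyTorch outputs should stay close across the whole dynamic range
    assert cosine_similarity(model(x), trt_mod(x)) > COSINE_THRESHOLD
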