diff --git a/.circleci/config.yml b/.circleci/config.yml
index 347dd77294..5649228384 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -263,7 +263,7 @@ commands:
     parameters:
       torch-build:
         type: string
-        default: "2.0.0.dev20230129+cu117"
+        default: "2.0.0.dev20230219+cu117"
       torch-build-index:
         type: string
         default: "https://download.pytorch.org/whl/nightly/cu117"
@@ -1026,7 +1026,7 @@ parameters:
   # Nightly platform config
   torch-build:
     type: string
-    default: "2.0.0.dev20230129+cu117"
+    default: "2.0.0.dev20230219+cu117"
   torch-build-index:
     type: string
     default: "https://download.pytorch.org/whl/nightly/cu117"
diff --git a/examples/fx/lower_example.py b/examples/fx/lower_example.py
index cd9215712b..81c1cd28bc 100644
--- a/examples/fx/lower_example.py
+++ b/examples/fx/lower_example.py
@@ -188,6 +188,7 @@ def run_configuration_benchmark(
             input,
             max_batch_size=conf.batch_size,
             lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,
+            explicit_batch_dimension=True,
         )
         time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
     else:
diff --git a/examples/fx/lower_example_aten.py b/examples/fx/lower_example_aten.py
new file mode 100644
index 0000000000..09a8e7cb85
--- /dev/null
+++ b/examples/fx/lower_example_aten.py
@@ -0,0 +1,196 @@
+import typing as t
+from copy import deepcopy
+from dataclasses import dataclass, field, replace
+
+import torch
+import torchvision
+from torch_tensorrt.fx import compile
+from torch_tensorrt.fx.utils import LowerPrecision
+
+
+"""
+The purpose of this example is to demonstrate the overall flow of lowering a PyTorch model
+to TensorRT conveniently with lower.py.
+"""
+
+
+@dataclass
+class Configuration:
+    """
+    Specify the configuration used for fx2trt lowering and benchmark.
+
+    To extend, add a new configuration field to this class, and modify the
+    lowering or benchmark behavior in `run_configuration_benchmark()`
+    correspondingly.
+
+    It automatically prints all its values thanks to being a dataclass.
+    """
+
+    # Number of inferences to run
+    batch_iter: int
+
+    # Input batch size
+    batch_size: int
+
+    # Friendly name of the configuration
+    name: str = ""
+
+    # Whether to apply TRT lowering to the model before benchmarking
+    trt: bool = False
+
+    # Whether to apply engine holder to the lowered model
+    jit: bool = False
+
+    # Whether to enable FP16 mode for TRT lowering
+    fp16: bool = False
+
+    # Relative tolerance for accuracy check after lowering. -1 means do not
+    # check accuracy.
+    accuracy_rtol: float = -1  # disable
+
+
+@dataclass
+class Result:
+    """Holds and computes the benchmark results.
+
+    Holds raw essential benchmark result values like duration.
+    Also computes results that can be derived from the raw essential values
+    (QPS), in the form of auto properties.
+
+    """
+
+    module: torch.nn.Module = field(repr=False)
+    input: t.Any = field(repr=False)
+    conf: Configuration
+    time_sec: float
+    accuracy_res: t.Optional[bool] = None
+
+    @property
+    def time_per_iter_ms(self) -> float:
+        return self.time_sec * 1.0e3
+
+    @property
+    def qps(self) -> float:
+        return self.conf.batch_size / self.time_sec
+
+    def format(self) -> str:
+        return (
+            f"== Benchmark Result for: {self.conf}\n"
+            f"BS: {self.conf.batch_size}, "
+            f"Time per iter: {self.time_per_iter_ms:.2f}ms, "
+            f"QPS: {self.qps:.2f}, "
+            f"Accuracy: {self.accuracy_res} (rtol={self.conf.accuracy_rtol})"
+        )
+
+
+@torch.inference_mode()
+def benchmark(
+    model,
+    inputs,
+    batch_iter: int,
+    batch_size: int,
+) -> None:
+    """
+    Run fx2trt lowering and benchmark the given model according to the
+    specified benchmark configuration. Prints the benchmark result for each
+    configuration at the end of the run.
+    """
+
+    model = model.cuda().eval()
+    inputs = [x.cuda() for x in inputs]
+
+    # Benchmark base configuration
+    conf = Configuration(batch_iter=batch_iter, batch_size=batch_size)
+
+    configurations = [
+        # Baseline
+        replace(conf, name="CUDA Eager", trt=False),
+        # FP16
+        replace(
+            conf,
+            name="TRT FP16 Eager",
+            trt=True,
+            jit=False,
+            fp16=True,
+            accuracy_rtol=1e-2,
+        ),
+    ]
+
+    results = [
+        run_configuration_benchmark(deepcopy(model), inputs, conf_)
+        for conf_ in configurations
+    ]
+
+    for res in results:
+        print(res.format())
+
+
+def benchmark_torch_function(iters: int, f, *args) -> float:
+    """Estimates the average duration in seconds of a single inference call.
+
+    If the input is batched, then the estimate is for the batched inference call.
+
+    Args:
+        iters: number of inference iterations to run
+        f: a function to perform a single inference call
+
+    Returns:
+        estimated average duration in seconds of a single inference call
+    """
+    with torch.inference_mode():
+        f(*args)
+    torch.cuda.synchronize()
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    print("== Start benchmark iterations")
+    with torch.inference_mode():
+        start_event.record()
+        for _ in range(iters):
+            f(*args)
+        end_event.record()
+    torch.cuda.synchronize()
+    print("== End benchmark iterations")
+    return (start_event.elapsed_time(end_event) * 1.0e-3) / iters
+
+
+def run_configuration_benchmark(
+    module,
+    input,
+    conf: Configuration,
+) -> Result:
+    """
+    Runs `module` through the lowering logic and benchmarks the module before
+    and after lowering.
+    """
+    print(f"=== Running benchmark for: {conf}", "green")
+    time = -1.0
+
+    if conf.fp16:
+        module = module.half()
+        input = [i.half() for i in input]
+
+    if not conf.trt:
+        # Run eager mode benchmark
+        time = benchmark_torch_function(conf.batch_iter, lambda: module(*input))
+    elif not conf.jit:
+        # Run lowering eager mode benchmark
+        lowered_module = compile(
+            module,
+            input,
+            max_batch_size=conf.batch_size,
+            lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,
+            explicit_batch_dimension=True,
+            is_aten=True,
+        )
+        time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
+    else:
+        print("Lowering with JIT is not available!", "red")
+
+    result = Result(module=module, input=input, conf=conf, time_sec=time)
+    return result
+
+
+if __name__ == "__main__":
+    test_model = torchvision.models.resnet18(pretrained=True)
+    input = [torch.rand(128, 3, 224, 224)]  # type: ignore[attr-defined]
+    benchmark(test_model, input, 50, 128)
diff --git a/py/setup.py b/py/setup.py
index f8a64f6571..bf7f8d0f0b 100644
--- a/py/setup.py
+++ b/py/setup.py
@@ -353,6 +353,7 @@ def run(self):
     "torch_tensorrt.fx.passes",
     "torch_tensorrt.fx.tools",
     "torch_tensorrt.fx.tracer.acc_tracer",
+    "torch_tensorrt.fx.tracer.dispatch_tracer",
 ]
 package_dir = {
     "torch_tensorrt.fx": "torch_tensorrt/fx",
@@ -360,6 +361,7 @@ def run(self):
     "torch_tensorrt.fx.passes": "torch_tensorrt/fx/passes",
     "torch_tensorrt.fx.tools": "torch_tensorrt/fx/tools",
     "torch_tensorrt.fx.tracer.acc_tracer": "torch_tensorrt/fx/tracer/acc_tracer",
+    "torch_tensorrt.fx.tracer.dispatch_tracer": "torch_tensorrt/fx/tracer/dispatch_tracer",
 }
 
 with open("README.md", "r", encoding="utf-8") as fh:
diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
index 0ca4383f6e..00063c3e21 100644
--- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
+++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
@@ -165,6 +165,7 @@ def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.Graph
             torch.ops.aten.max_pool3d_with_indices.default,
             torch.ops.aten.native_batch_norm.default,
             torch.ops.aten._native_batch_norm_legit.default,
+            torch.ops.aten._native_batch_norm_legit_no_training.default,
         ):
             modified = True
             if len(n.users) != 1:
@@ -185,6 +186,16 @@ def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.Graph
                 new_args = list(n.args)
                 new_args.append(False)
                 new_args = tuple(new_args)
+            elif (
+                n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
+            ):
+                new_op = torch.ops.aten.batch_norm
+                new_args = list(n.args)
+                new_args.append(False)
+                # _native_batch_norm_legit_no_training doesn't take in a training arg (assumed to be False),
+                # but batch_norm takes in a training arg at position 5.
+                new_args.insert(5, False)
+                new_args = tuple(new_args)
 
             getitem_node = next(iter(n.users))
             with module.graph.inserting_after(getitem_node):