diff --git a/.circleci/config.yml b/.circleci/config.yml index d1e36447d3..89d4ae168f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -109,7 +109,7 @@ commands: sudo docker run --rm --runtime=nvidia --gpus all nvidia/cuda:11.6.2-base-ubuntu20.04 nvidia-smi install-cudnn: - description: "Install CUDNN 8.8.1" + description: "Install CUDNN 8.9.5" parameters: os: type: string @@ -119,7 +119,7 @@ commands: default: "x86_64" cudnn-version: type: string - default: "8.8.1.3" + default: "8.9.5.30" cuda-version: type: string default: "cuda12.0" @@ -198,7 +198,7 @@ commands: default: "cuda12.0" cudnn-version: type: string - default: "8.8.1.3" + default: "8.9.5.30" trt-version-short: type: string default: "8.6.1" @@ -246,7 +246,7 @@ commands: default: "8.6.1" cudnn-version-long: type: string - default: "8.8.1.3" + default: "8.9.5.30" steps: - run: name: Set up python environment @@ -269,10 +269,10 @@ commands: default: "0.16.0.dev20230703+cu121" torch-build-index: type: string - default: "https://download.pytorch.org/whl/nightly/cu121" + default: "https://download.pytorch.org/whl/cu121" torchvision-build-index: type: string - default: "https://download.pytorch.org/whl/nightly/cu121" + default: "https://download.pytorch.org/whl/cu121" steps: - run: name: Install Torch @@ -1457,10 +1457,10 @@ parameters: default: "0.16.0.dev20230703+cu121" torch-build-index: type: string - default: "https://download.pytorch.org/whl/nightly/cu121" + default: "https://download.pytorch.org/whl/cu121" cudnn-version: type: string - default: "8.8.1.3" + default: "8.9.5.30" trt-version-short: type: string default: "8.6.1" @@ -1483,7 +1483,7 @@ parameters: default: "https://download.pytorch.org/whl/cu117" cudnn-version-legacy: type: string - default: "8.8.1.3" + default: "8.9.5.30" trt-version-short-legacy: type: string default: "8.6.1" diff --git a/.github/workflows/docgen.yml b/.github/workflows/docgen.yml index 34160ca463..2eb6bf3616 100644 --- a/.github/workflows/docgen.yml +++ b/.github/workflows/docgen.yml @@ -30,14 +30,14 @@ jobs: - name: Install base deps run: | python3 -m pip install pip --upgrade - python3 -m pip install pyyaml numpy torch --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 + python3 -m pip install pyyaml numpy torch --pre --extra-index-url https://download.pytorch.org/whl/cu121 ./packaging/pre_build_script.sh - name: Get HEAD SHA id: vars run: echo "sha=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT - name: Build Python Package run: | - python3 -m pip install --pre . --extra-index-url https://download.pytorch.org/whl/nightly/cu121 + python3 -m pip install --pre . --extra-index-url https://download.pytorch.org/whl/cu121 - name: Generate New Docs run: | cd docsrc diff --git a/.github/workflows/docker_builder.yml b/.github/workflows/docker_builder.yml index 817bc87c82..99b6efe53e 100644 --- a/.github/workflows/docker_builder.yml +++ b/.github/workflows/docker_builder.yml @@ -6,6 +6,7 @@ on: branches: - main - nightly + - release/2.1 # If pushes to main are made in rapid succession, # cancel existing docker builds and use newer commits diff --git a/README.md b/README.md index 5d54c1ca11..a7100af133 100644 --- a/README.md +++ b/README.md @@ -118,7 +118,7 @@ These are the following dependencies used to verify the testcases. 
Torch-TensorR - Bazel 6.2.1 - Libtorch 2.1.0 - CUDA 12.1 -- cuDNN 8.8.1 +- cuDNN 8.9.5 - TensorRT 8.6.1 ## Prebuilt Binaries and Wheel files diff --git a/WORKSPACE b/WORKSPACE index 668b46c779..2140deaf42 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -54,14 +54,14 @@ http_archive( name = "libtorch", build_file = "@//third_party/libtorch:BUILD", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/test/cu121/libtorch-cxx11-abi-shared-with-deps-latest.zip"], + urls = ["https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.1.0%2Bcu121.zip"], ) http_archive( name = "libtorch_pre_cxx11_abi", build_file = "@//third_party/libtorch:BUILD", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/test/cu121/libtorch-shared-with-deps-latest.zip"], + urls = ["https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.1.0%2Bcu121.zip"], ) # Download these tarballs manually from the NVIDIA website @@ -71,10 +71,10 @@ http_archive( http_archive( name = "cudnn", build_file = "@//third_party/cudnn/archive:BUILD", - sha256 = "79d77a769c7e7175abc7b5c2ed5c494148c0618a864138722c887f95c623777c", - strip_prefix = "cudnn-linux-x86_64-8.8.1.3_cuda12-archive", + sha256 = "2a2eb89a2ab51071151c6082f1e816c702167a711a9372f9f73a7b5c4b06e01a", + strip_prefix = "cudnn-linux-x86_64-8.9.5.30_cuda12-archive", urls = [ - "https://developer.nvidia.com/downloads/compute/cudnn/secure/8.8.1/local_installers/12.0/cudnn-linux-x86_64-8.8.1.3_cuda12-archive.tar.xz", + "https://developer.nvidia.com/downloads/compute/cudnn/secure/8.9.5/local_installers/12.x/cudnn-linux-x86_64-8.9.5.30_cuda12-archive.tar.xz", ], ) diff --git a/dev_dep_versions.yml b/dev_dep_versions.yml index aeecd4deb0..cdf9b92de8 100644 --- a/dev_dep_versions.yml +++ b/dev_dep_versions.yml @@ -1,3 +1,3 @@ __cuda_version__: "12.1" -__cudnn_version__: "8.8" +__cudnn_version__: "8.9" __tensorrt_version__: "8.6" diff --git a/docker/README.md b/docker/README.md index 527b7ae2b2..9f83f25134 100644 --- a/docker/README.md +++ b/docker/README.md @@ -17,14 +17,14 @@ Note: By default the container uses the `pre-cxx11-abi` version of Torch + Torch ### Instructions -- The example below uses CUDNN 8.8 and TensorRT 8.6 +- The example below uses CUDNN 8.9 and TensorRT 8.6 - See dependencies for a list of current default dependencies. > From root of Torch-TensorRT repo Build: ``` -DOCKER_BUILDKIT=1 docker build --build-arg TENSORRT_VERSION=8.6 --build-arg CUDNN_VERSION=8.8 -f docker/Dockerfile -t torch_tensorrt:latest . +DOCKER_BUILDKIT=1 docker build --build-arg TENSORRT_VERSION=8.6 --build-arg CUDNN_VERSION=8.9 -f docker/Dockerfile -t torch_tensorrt:latest . ``` Run: diff --git a/docker/dist-build.sh b/docker/dist-build.sh index 0bc2fc09f9..fab62f21cc 100755 --- a/docker/dist-build.sh +++ b/docker/dist-build.sh @@ -3,9 +3,9 @@ TOP_DIR=$(cd $(dirname $0); pwd)/.. if [[ -z "${USE_CXX11}" ]]; then - BUILD_CMD="python -m pip wheel . --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -w dist" + BUILD_CMD="python -m pip wheel . --extra-index-url https://download.pytorch.org/whl/cu121 -w dist" else - BUILD_CMD="python -m pip wheel . --config-setting="--build-option=--use-cxx11-abi" --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -w dist" + BUILD_CMD="python -m pip wheel . 
--config-setting="--build-option=--use-cxx11-abi" --extra-index-url https://download.pytorch.org/whl/cu121 -w dist" fi # TensorRT restricts our pip version diff --git a/docsrc/index.rst b/docsrc/index.rst index 97580541ea..8aba9822df 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -81,6 +81,7 @@ Tutorials tutorials/_rendered_examples/dynamo/torch_compile_resnet_example tutorials/_rendered_examples/dynamo/torch_compile_transformers_example tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage + tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion Python API Documenation ------------------------ diff --git a/docsrc/tutorials/images/majestic_castle.png b/docsrc/tutorials/images/majestic_castle.png new file mode 100644 index 0000000000..bac6073a90 Binary files /dev/null and b/docsrc/tutorials/images/majestic_castle.png differ diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst index fa863952e7..d895cc0113 100644 --- a/examples/dynamo/README.rst +++ b/examples/dynamo/README.rst @@ -9,3 +9,4 @@ a number of ways you can leverage this backend to accelerate inference. * :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile`` * :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile`` * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API +* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile`` diff --git a/examples/dynamo/torch_compile_stable_diffusion.py b/examples/dynamo/torch_compile_stable_diffusion.py new file mode 100644 index 0000000000..0511e5a363 --- /dev/null +++ b/examples/dynamo/torch_compile_stable_diffusion.py @@ -0,0 +1,55 @@ +""" +.. _torch_compile_stable_diffusion: + +Torch Compile Stable Diffusion +====================================================== + +This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a Stable Diffusion model. A sample output is featured below: + +.. 
image:: /tutorials/images/majestic_castle.png + :width: 512px + :height: 512px + :scale: 50 % + :align: right +""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +from diffusers import DiffusionPipeline + +import torch_tensorrt + +model_id = "CompVis/stable-diffusion-v1-4" +device = "cuda:0" + +# Instantiate Stable Diffusion Pipeline with FP16 weights +pipe = DiffusionPipeline.from_pretrained( + model_id, revision="fp16", torch_dtype=torch.float16 +) +pipe = pipe.to(device) + +backend = "torch_tensorrt" + +# Optimize the UNet portion with Torch-TensorRT +pipe.unet = torch.compile( + pipe.unet, + backend=backend, + options={ + "truncate_long_and_double": True, + "precision": torch.float16, + }, + dynamic=False, +) + +# %% +# Inference +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +prompt = "a majestic castle in the clouds" +image = pipe(prompt).images[0] + +image.save("images/majestic_castle.png") +image.show() diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index 3d69b407f1..b95715b5ae 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -153,6 +153,7 @@ aten.transpose.int, aten.tril.default, aten.triu.default, + aten.unbind, aten.unfold, aten.unfold_backward, aten.unfold_copy, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 2443e33d50..65f6cea1ed 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -2,22 +2,11 @@ from typing import Any, Sequence import torch -from torch_tensorrt._utils import sanitized_torch_version +from torch._inductor.freezing import ConstantFolder, replace_node_with_constant from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) -from packaging import version - -# Modify import location of utilities based on Torch version -if version.parse(sanitized_torch_version()) < version.parse("2.1.1"): - from torch._inductor.freezing import ConstantFolder, replace_node_with_constant -else: - from torch._inductor.constant_folding import ( - ConstantFolder, - replace_node_with_constant, - ) - logger = logging.getLogger(__name__) diff --git a/tools/perf/README.md b/tools/perf/README.md index 452621e939..4d4579efb4 100644 --- a/tools/perf/README.md +++ b/tools/perf/README.md @@ -3,9 +3,10 @@ This is a comprehensive Python benchmark suite to run perf runs using different supported backends. Following backends are supported: 1. Torch -2. Torch-TensorRT -3. FX-TRT -4. TensorRT +2. Torch-TensorRT [Torchscript] +3. Torch-TensorRT [Dynamo] +4. Torch-TensorRT [torch_compile] +5. TensorRT Note: Please note that for ONNX models, user can convert the ONNX model to TensorRT serialized engine and then use this package. @@ -22,9 +23,6 @@ Benchmark scripts depends on following Python packages in addition to requiremen ``` ./ -├── config -│ ├── vgg16_trt.yml -│ └── vgg16.yml ├── models ├── perf_run.py ├── hub.py @@ -35,10 +33,8 @@ Benchmark scripts depends on following Python packages in addition to requiremen ``` - -* `config` - Directory which contains sample yaml configuration files for VGG network. 
* `models` - Model directory -* `perf_run.py` - Performance benchmarking script which supports torch, torch_tensorrt, fx2trt, tensorrt backends +* `perf_run.py` - Performance benchmarking script which supports torch, ts_trt, torch_compile, dynamo, tensorrt backends * `hub.py` - Script to download torchscript models for VGG16, Resnet50, EfficientNet-B0, VIT, HF-BERT * `custom_models.py` - Script which includes custom models other than torchvision and timm (eg: HF BERT) * `utils.py` - utility functions script @@ -46,76 +42,11 @@ Benchmark scripts depends on following Python packages in addition to requiremen ## Usage -There are two ways you can run a performance benchmark. - -### Using YAML config files - -To run the benchmark for a given configuration file: - -```python -python perf_run.py --config=config/vgg16.yml -``` - -There are two sample configuration files added. - -* vgg16.yml demonstrates a configuration with all the supported backends (Torch, Torch-TensorRT, TensorRT) -* vgg16_trt.yml demonstrates how to use an external TensorRT serialized engine file directly. - - -### Supported fields - -| Name | Supported Values | Description | -| ----------------- | ------------------------------------ | ------------------------------------------------------------ | -| backend | all, torchscript, fx2trt, torch, torch_tensorrt, tensorrt | Supported backends for inference. "all" implies the last four methods in the list at left, and "torchscript" implies the last three (excludes fx path) | -| input | - | Input binding names. Expected to list shapes of each input bindings | -| model | - | Configure the model filename and name | -| model_torch | - | Name of torch model file and name (used for fx2trt) (optional) | -| filename | - | Model file name to load from disk. | -| name | - | Model name | -| runtime | - | Runtime configurations | -| device | 0 | Target device ID to run inference. Range depends on available GPUs | -| precision | fp32, fp16 or half, int8 | Target precision to run inference. int8 cannot be used with 'all' backend | -| calibration_cache | - | Calibration cache file expected for torch_tensorrt runtime in int8 precision | - -Additional sample use case: - -``` -backend: - - torch - - torch_tensorrt - - tensorrt - - fx2trt -input: - input0: - - 3 - - 224 - - 224 - num_inputs: 1 -model: - filename: model.plan - name: vgg16 -model_torch: - filename: model_torch.pt - name: vgg16 -runtime: - device: 0 - precision: - - fp32 - - fp16 -``` - -Note: - -1. Please note that measuring INT8 performance is only supported via a `calibration cache` file or QAT mode for `torch_tensorrt` backend. -2. TensorRT engine filename should end with `.plan` otherwise it will be treated as Torchscript module. - -### Using CompileSpec options via CLI - Here are the list of `CompileSpec` options that can be provided directly to compile the pytorch module -* `--backends` : Comma separated string of backends. Eg: torch,torch_tensorrt,tensorrt,fx2trt -* `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `fx2trt`, the input should be a Pytorch module (instead of a torchscript module) and the options for model are (`vgg16` | `resnet50` | `efficientnet_b0`) -* `--model_torch` : Name of the PyTorch model file (optional, only necessary if fx2trt is a chosen backend) +* `--backends` : Comma separated string of backends. 
Eg: torch, torch_compile, dynamo, tensorrt +* `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `dynamo` or `torch_compile`, the input should be a Pytorch module (instead of a torchscript module). +* `--model_torch` : Name of the PyTorch model file (optional, only necessary if `dynamo` or `torch_compile` is a chosen backend) * `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT * `--batch_size` : Batch size * `--precision` : Comma separated list of precisions to build TensorRT engine Eg: fp32,fp16 @@ -131,10 +62,15 @@ Eg: --model_torch ${MODELS_DIR}/vgg16_torch.pt \ --precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \ --batch_size 1 \ - --backends torch,torch_tensorrt,tensorrt,fx2trt \ + --backends torch,ts_trt,dynamo,torch_compile,tensorrt \ --report "vgg_perf_bs1.txt" ``` +Note: + +1. Please note that measuring INT8 performance is only supported via a `calibration cache` file or QAT mode for `torch_tensorrt` backend. +2. TensorRT engine filename should end with `.plan` otherwise it will be treated as Torchscript module. + ### Example models This tool benchmarks any pytorch model or torchscript module. As an example, we provide VGG16, Resnet50, EfficientNet-B0, VIT, HF-BERT models in `hub.py` that we internally test for performance. diff --git a/tools/perf/benchmark.sh b/tools/perf/benchmark.sh index fe7400945d..319c779bf5 100644 --- a/tools/perf/benchmark.sh +++ b/tools/perf/benchmark.sh @@ -6,17 +6,33 @@ MODELS_DIR="models" python hub.py batch_sizes=(1 2 4 8 16 32 64 128 256) +large_model_batch_sizes=(1 2 4 8 16 32 64) -#Benchmark VGG16 model + +# Benchmark VGG16 model echo "Benchmarking VGG16 model" for bs in ${batch_sizes[@]} do python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \ - --model_torch ${MODELS_DIR}/vgg16_pytorch.pt \ + --model_torch vgg16 \ --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ --batch_size ${bs} \ - --backends torch,torch_tensorrt,tensorrt,fx2trt \ - --report "vgg_perf_bs${bs}.txt" + --truncate \ + --backends torch,ts_trt,dynamo,torch_compile,inductor \ + --report "vgg16_perf_bs${bs}.txt" +done + +# Benchmark AlexNet model +echo "Benchmarking AlexNet model" +for bs in ${batch_sizes[@]} +do + python perf_run.py --model ${MODELS_DIR}/alexnet_scripted.jit.pt \ + --model_torch alexnet \ + --precision fp32,fp16 --inputs="(${bs}, 3, 227, 227)" \ + --batch_size ${bs} \ + --truncate \ + --backends torch,ts_trt,dynamo,torch_compile,inductor \ + --report "alexnet_perf_bs${bs}.txt" done # Benchmark Resnet50 model @@ -24,11 +40,12 @@ echo "Benchmarking Resnet50 model" for bs in ${batch_sizes[@]} do python perf_run.py --model ${MODELS_DIR}/resnet50_scripted.jit.pt \ - --model_torch ${MODELS_DIR}/resnet50_pytorch.pt \ + --model_torch resnet50 \ --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ --batch_size ${bs} \ - --backends torch,torch_tensorrt,tensorrt,fx2trt \ - --report "rn50_perf_bs${bs}.txt" + --truncate \ + --backends torch,ts_trt,dynamo,torch_compile,inductor \ + --report "resnet50_perf_bs${bs}.txt" done # Benchmark VIT model @@ -36,22 +53,50 @@ echo "Benchmarking VIT model" for bs in ${batch_sizes[@]} do python perf_run.py --model ${MODELS_DIR}/vit_scripted.jit.pt \ + --model_torch vit \ --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ --batch_size ${bs} \ - --backends torch,torch_tensorrt,tensorrt \ + --truncate \ + --backends torch,ts_trt,dynamo,torch_compile,inductor \ 
--report "vit_perf_bs${bs}.txt" done +# Benchmark VIT Large model +echo "Benchmarking VIT Large model" +for bs in ${large_model_batch_sizes[@]} +do + python perf_run.py --model ${MODELS_DIR}/vit_large_scripted.jit.pt \ + --model_torch vit_large \ + --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ + --truncate \ + --batch_size ${bs} \ + --backends torch,ts_trt,dynamo,torch_compile,inductor \ + --report "vit_large_perf_bs${bs}.txt" +done + # Benchmark EfficientNet-B0 model echo "Benchmarking EfficientNet-B0 model" for bs in ${batch_sizes[@]} do python perf_run.py --model ${MODELS_DIR}/efficientnet_b0_scripted.jit.pt \ - --model_torch ${MODELS_DIR}/efficientnet_b0_pytorch.pt \ + --model_torch efficientnet_b0 \ --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \ --batch_size ${bs} \ - --backends torch,torch_tensorrt,tensorrt,fx2trt \ - --report "eff_b0_perf_bs${bs}.txt" + --truncate \ + --backends torch,ts_trt,dynamo,torch_compile,inductor \ + --report "efficientnet_b0_perf_bs${bs}.txt" +done + +# Benchmark Stable Diffusion UNet model +echo "Benchmarking SD UNet model" +for bs in ${large_model_batch_sizes[@]} +do + python perf_run.py --model_torch sd_unet \ + --precision fp32,fp16 --inputs="(${bs}, 4, 128, 128)@fp16;(${bs})@fp16;(${bs}, 1, 768)@fp16" \ + --batch_size ${bs} \ + --backends torch,dynamo,torch_compile,inductor \ + --truncate \ + --report "sd_unet_perf_bs${bs}.txt" done # Benchmark BERT model @@ -59,9 +104,10 @@ echo "Benchmarking Huggingface BERT base model" for bs in ${batch_sizes[@]} do python perf_run.py --model ${MODELS_DIR}/bert_base_uncased_traced.jit.pt \ - --precision fp32 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \ + --model_torch "bert_base_uncased" \ + --precision fp32,fp16 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \ --batch_size ${bs} \ - --backends torch,torch_tensorrt \ + --backends torch,ts_trt,dynamo,torch_compile,inductor \ --truncate \ --report "bert_base_perf_bs${bs}.txt" done diff --git a/tools/perf/config/vgg16.yml b/tools/perf/config/vgg16.yml deleted file mode 100755 index d88d489458..0000000000 --- a/tools/perf/config/vgg16.yml +++ /dev/null @@ -1,19 +0,0 @@ ---- -backend: - - all -input: - input0: - - 1 - - 3 - - 224 - - 224 - num_inputs: 1 - batch_size: 1 -model: - filename: models/vgg16_scripted.jit.pt - name: vgg16 -runtime: - device: 0 - precision: - - fp32 - - fp16 diff --git a/tools/perf/config/vgg16_trt.yml b/tools/perf/config/vgg16_trt.yml deleted file mode 100755 index 9392e5c24a..0000000000 --- a/tools/perf/config/vgg16_trt.yml +++ /dev/null @@ -1,20 +0,0 @@ ---- -backend: - - tensorrt -input: - input0: - - 3 - - 224 - - 224 - num_inputs: 1 -model: - filename: models/model.plan - name: vgg16 -calibration_cache: - - vgg16.cache -runtime: - device: 0 - precision: - - fp32 - - fp16 - - int8 diff --git a/tools/perf/custom_models.py b/tools/perf/custom_models.py index a8b8a5dae0..0f85957e1e 100644 --- a/tools/perf/custom_models.py +++ b/tools/perf/custom_models.py @@ -1,10 +1,18 @@ import torch -import torch.nn as nn -from transformers import BertModel, BertTokenizer, BertConfig -import torch.nn.functional as F def BertModule(): + from transformers import BertModel + + model_name = "bert-base-uncased" + model = BertModel.from_pretrained(model_name, torchscript=True) + model.eval() + return model + + +def BertInputs(): + from transformers import BertTokenizer + model_name = "bert-base-uncased" enc = BertTokenizer.from_pretrained(model_name) text = "[CLS] Who was Jim Henson ? 
[SEP] Jim Henson was a puppeteer [SEP]" @@ -15,16 +23,13 @@ def BertModule(): segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] tokens_tensor = torch.tensor([indexed_tokens]) segments_tensors = torch.tensor([segments_ids]) - config = BertConfig( - vocab_size_or_config_json_file=32000, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - torchscript=True, + return [tokens_tensor, segments_tensors] + + +def StableDiffusionUnet(): + from diffusers import DiffusionPipeline + + pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16 ) - model = BertModel(config) - model.eval() - model = BertModel.from_pretrained(model_name, torchscript=True) - traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) - return traced_model + return pipe.unet diff --git a/tools/perf/hub.py b/tools/perf/hub.py index 4cedaabded..22b685d1fc 100644 --- a/tools/perf/hub.py +++ b/tools/perf/hub.py @@ -1,12 +1,7 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.models as models -import timm -from transformers import BertModel, BertTokenizer, BertConfig -import os import json -import custom_models as cm +import os + +import torch torch.hub._validate_not_a_forked_repo = lambda a, b, c: True @@ -25,25 +20,7 @@ VALID_PATHS = ("script", "trace", "torchscript", "pytorch", "all") # Key models selected for benchmarking with their respective paths -BENCHMARK_MODELS = { - "vgg16": { - "model": models.vgg16(weights=models.VGG16_Weights.DEFAULT), - "path": ["script", "pytorch"], - }, - "resnet50": { - "model": models.resnet50(weights=None), - "path": ["script", "pytorch"], - }, - "efficientnet_b0": { - "model": timm.create_model("efficientnet_b0", pretrained=True), - "path": ["script", "pytorch"], - }, - "vit": { - "model": timm.create_model("vit_base_patch16_224", pretrained=True), - "path": "script", - }, - "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, -} +from utils import BENCHMARK_MODELS def get(n, m, manifest): @@ -51,42 +28,38 @@ def get(n, m, manifest): traced_filename = "models/" + n + "_traced.jit.pt" script_filename = "models/" + n + "_scripted.jit.pt" pytorch_filename = "models/" + n + "_pytorch.pt" - x = torch.ones((1, 3, 300, 300)).cuda() - if n == "bert_base_uncased": - traced_model = m["model"] - torch.jit.save(traced_model, traced_filename) + + m["model"] = m["model"].eval().cuda() + + # Get all desired model save specifications as list + paths = [m["path"]] if isinstance(m["path"], str) else m["path"] + + # Depending on specified model save specifications, save desired model formats + if any(path in ("all", "torchscript", "trace") for path in paths): + # (TorchScript) Traced model + trace_model = torch.jit.trace(m["model"], [inp.cuda() for inp in m["inputs"]]) + torch.jit.save(trace_model, traced_filename) manifest.update({n: [traced_filename]}) - else: - m["model"] = m["model"].eval().cuda() - - # Get all desired model save specifications as list - paths = [m["path"]] if isinstance(m["path"], str) else m["path"] - - # Depending on specified model save specifications, save desired model formats - if any(path in ("all", "torchscript", "trace") for path in paths): - # (TorchScript) Traced model - trace_model = torch.jit.trace(m["model"], [x]) - torch.jit.save(trace_model, traced_filename) - manifest.update({n: [traced_filename]}) - if any(path in ("all", "torchscript", "script") for path in paths): - # (TorchScript) Scripted 
model - script_model = torch.jit.script(m["model"]) - torch.jit.save(script_model, script_filename) - if n in manifest.keys(): - files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] - files.append(script_filename) - manifest.update({n: files}) - else: - manifest.update({n: [script_filename]}) - if any(path in ("all", "pytorch") for path in paths): - # (PyTorch Module) model - torch.save(m["model"], pytorch_filename) - if n in manifest.keys(): - files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] - files.append(script_filename) - manifest.update({n: files}) - else: - manifest.update({n: [script_filename]}) + if any(path in ("all", "torchscript", "script") for path in paths): + # (TorchScript) Scripted model + script_model = torch.jit.script(m["model"]) + torch.jit.save(script_model, script_filename) + if n in manifest.keys(): + files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] + files.append(script_filename) + manifest.update({n: files}) + else: + manifest.update({n: [script_filename]}) + if any(path in ("all", "pytorch") for path in paths): + # (PyTorch Module) model + torch.save(m["model"], pytorch_filename) + if n in manifest.keys(): + files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] + files.append(script_filename) + manifest.update({n: files}) + else: + manifest.update({n: [script_filename]}) + return manifest @@ -94,9 +67,25 @@ def download_models(version_matches, manifest): # Download all models if torch version is different than model version if not version_matches: for n, m in BENCHMARK_MODELS.items(): + # Ensure all specified desired model formats exist and are valid + assert all( + (path in VALID_PATHS) + for path in ( + m["path"] if isinstance(m["path"], (list, tuple)) else [m["path"]] + ) + ), "Not all 'path' attributes in BENCHMARK_MODELS are valid" + manifest = get(n, m, manifest) else: for n, m in BENCHMARK_MODELS.items(): + # Ensure all specified desired model formats exist and are valid + assert all( + (path in VALID_PATHS) + for path in ( + m["path"] if isinstance(m["path"], (list, tuple)) else [m["path"]] + ) + ), "Not all 'path' attributes in BENCHMARK_MODELS are valid" + scripted_filename = "models/" + n + "_scripted.jit.pt" traced_filename = "models/" + n + "_traced.jit.pt" pytorch_filename = "models/" + n + "_pytorch.pt" @@ -174,12 +163,4 @@ def main(): if __name__ == "__main__": - # Ensure all specified desired model formats exist and are valid - paths = [ - [m["path"]] if isinstance(m["path"], str) else m["path"] - for m in BENCHMARK_MODELS.values() - ] - assert all( - (path in VALID_PATHS) for path_list in paths for path in path_list - ), "Not all 'path' attributes in BENCHMARK_MODELS are valid" main() diff --git a/tools/perf/models/.gitkeep b/tools/perf/models/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tools/perf/perf_run.py b/tools/perf/perf_run.py index 65729008aa..58d03de4d1 100644 --- a/tools/perf/perf_run.py +++ b/tools/perf/perf_run.py @@ -1,68 +1,49 @@ -from __future__ import print_function -from __future__ import absolute_import -from __future__ import division +from __future__ import absolute_import, division, print_function +# Config parsers and report generations +import argparse +import logging +import os import time import timeit import warnings -import numpy as np -import torch.backends.cudnn as cudnn +from functools import wraps -# Config parsers and report generations -import argparse -import yaml -import os +import numpy as np 
import pandas as pd +import tensorrt as trt # Importing supported Backends import torch -import torch_tensorrt as torchtrt - -# from torch_tensorrt.fx.lower import compile -# from torch_tensorrt.fx.utils import LowerPrecision - -import tensorrt as trt from utils import ( - parse_inputs, + BENCHMARK_MODELS, parse_backends, - precision_to_dtype, + parse_inputs, parse_precisions, - BENCHMARK_MODELS, + precision_to_dtype, ) +import torch_tensorrt as torchtrt + WARMUP_ITER = 10 results = [] -# YAML Parser class for parsing the run configurations -class ConfigParser: - def __init__(self, config_file): - self.parser = None - self.config = config_file - self.params = None - - # Reads and loads the yaml file - def read_config(self): - with open(self.config, "r") as stream: - try: - self.params = yaml.safe_load(stream) - except yaml.YAMLError as exc: - print(exc) - return self.params - - # Retrieves the value from the configuration else uses default values - def get(self, key, default_value=None): - if not key in self.params: - if not default_value: - raise ValueError( - "Key {} is not present and default_value is not configured. Please run it with default value", - key, - ) - self.params[key] = default_value - return self.params[key] +def run_with_try_except(func): + @wraps(func) + def wrapper_func(*args, **kwargs): + try: + return func(*args, **kwargs) + except KeyboardInterrupt: + raise + except: + logging.warning(f"Running {func} failed", exc_info=True) + + return wrapper_func # Runs inference using Torch backend +@run_with_try_except def run_torch(model, input_tensors, params, precision, batch_size): print("Running Torch for precision: ", precision, " batch_size : ", batch_size) iters = params.get("iterations", 20) @@ -88,9 +69,8 @@ def run_torch(model, input_tensors, params, precision, batch_size): # Runs inference using Torch-TensorRT backend -def run_torch_tensorrt( - model, input_tensors, params, precision, truncate_long_and_double, batch_size -): +@run_with_try_except +def run_ts_trt(model, input_tensors, params, precision, batch_size): print( "Running Torch-TensorRT for precision: ", precision, @@ -101,16 +81,16 @@ def run_torch_tensorrt( compile_settings = { "inputs": input_tensors, "enabled_precisions": {precision_to_dtype(precision)}, - "truncate_long_and_double": truncate_long_and_double, + "truncate_long_and_double": params.get("truncate", False), } if precision == "int8": compile_settings.update({"calib": params.get("calibration_cache")}) start_compile = time.time_ns() - model = torchtrt.compile(model, **compile_settings) + model = torchtrt.compile(model, ir="ts", **compile_settings) end_compile = time.time_ns() - compile_time_ms = (end_compile - start_compile) / 1e6 + compile_time_s = (end_compile - start_compile) / 1e9 iters = params.get("iterations", 20) # Warm up @@ -130,27 +110,83 @@ def run_torch_tensorrt( meas_time = end_time - start_time timings.append(meas_time) - recordStats("Torch-TensorRT", timings, precision, batch_size, compile_time_ms) - + recordStats( + "Torch-TensorRT [Torchscript]", timings, precision, batch_size, compile_time_s + ) -# Runs inference using FX2TRT backend -def run_fx2trt(model, input_tensors, params, precision, batch_size): - print("Running FX2TRT for precision: ", precision, " batch_size : ", batch_size) - if precision == "fp16": - model.half() - input_tensors = [tensor.half() for tensor in input_tensors] - # Run lowering eager mode benchmark +@run_with_try_except +def run_dynamo(model, input_tensors, params, precision, batch_size): + """ + Compile 
the given model using Torch-TensorRT dynamo frontend and record performance stats + """ + print( + "Running Torch-TensorRT [dynamo] for precision: ", + precision, + " batch_size : ", + batch_size, + ) start_compile = time.time_ns() model = torchtrt.compile( model, - ir="fx", inputs=input_tensors, - enabled_precisions={torch.float16 if precision == "fp16" else torch.float32}, + ir="dynamo", + enabled_precisions={precision_to_dtype(precision)}, + min_block_size=params.get("min_block_size", 1), + debug=False, + truncate_long_and_double=params.get("truncate", False), ) end_compile = time.time_ns() - compile_time_ms = (end_compile - start_compile) / 1e6 + compile_time_s = (end_compile - start_compile) / 1e9 + iters = params.get("iterations", 20) + # Warm up + with torch.no_grad(): + for _ in range(WARMUP_ITER): + features = model(*input_tensors) + + torch.cuda.synchronize() + + timings = [] + with torch.no_grad(): + for i in range(iters): + start_time = timeit.default_timer() + features = model(*input_tensors) + torch.cuda.synchronize() + end_time = timeit.default_timer() + meas_time = end_time - start_time + timings.append(meas_time) + + recordStats( + "Torch-TensorRT [Dynamo]", timings, precision, batch_size, compile_time_s + ) + +@run_with_try_except +def run_torch_compile(model, input_tensors, params, precision, batch_size): + """ + Compile the given model using Torch-TensorRT torch.compile frontend and record performance stats + """ + torch._dynamo.reset() + + print( + "Running Torch-TensorRT [torch_compile] for precision: ", + precision, + " batch_size : ", + batch_size, + ) + compile_spec = { + "inputs": input_tensors, + "enabled_precisions": {precision_to_dtype(precision)}, + "truncate_long_and_double": params.get("truncate", False), + "min_block_size": params.get("min_block_size", 1), + } + start_compile = time.time_ns() + model = torch.compile( + model, backend="tensorrt", dynamic=False, options=compile_spec + ) + model(*input_tensors) + end_compile = time.time_ns() + compile_time_s = (end_compile - start_compile) / 1e9 iters = params.get("iterations", 20) # Warm up with torch.no_grad(): @@ -168,70 +204,63 @@ def run_fx2trt(model, input_tensors, params, precision, batch_size): end_time = timeit.default_timer() meas_time = end_time - start_time timings.append(meas_time) + # Reset torch dynamo cache + torch._dynamo.reset() - recordStats("FX-TensorRT", timings, precision, batch_size, compile_time_ms) + recordStats( + "Torch-TensorRT [torch_compile]", + timings, + precision, + batch_size, + compile_time_s, + ) -def run_dynamo(model, input_tensors, params, precision, batch_size): - dynamo_backend = params["dynamo_backend"] +@run_with_try_except +def run_inductor(model, input_tensors, params, precision, batch_size): + """ + Compile the given model using torch inductor and record performance stats + """ + torch._dynamo.reset() + print( - "Running Dynamo with backend: ", - dynamo_backend, - " for precision: ", + "Running Torch [inductor] for precision: ", precision, " batch_size : ", batch_size, ) - if precision == "fp16": - input_tensors = [tensor.half() for tensor in input_tensors] - - fp16_mode = True if precision == "fp16" else False - # dynamo_backend_params = {"fp16_mode" : fp16_mode} - # model = torch.compile( - # model, - # mode="default", - # dynamic=False, - # fullgraph=False, - # backend=dynamo_backend, - # # **dynamo_backend_params - # ) - import torch._dynamo as dynamo - - model = dynamo.optimize(dynamo_backend, nopython=True)(model) - # Compile and measure the time + 
start_compile = time.time_ns() + model = torch.compile(model, backend="inductor", dynamic=False, mode="max-autotune") + model(*input_tensors) + end_compile = time.time_ns() + compile_time_s = (end_compile - start_compile) / 1e9 + iters = params.get("iterations", 20) + # Warm up with torch.no_grad(): - start_compile = time.time_ns() - features = model(*input_tensors) - end_compile = time.time_ns() - compile_time_ms = (end_compile - start_compile) / 1e6 - iters = params.get("iterations", 20) - # import pdb; pdb.set_trace() - print("============= DONE 0 ==================") - - print("============= DONE 1 ==================") - # Warm up - model = torch._dynamo.run(model) - # import pdb; pdb.set_trace() - - exported_model, _ = torch._dynamo.export(model, *input_tensors) - for i in range(WARMUP_ITER): - print("==== ITER: ", i) - features = exported_model(*input_tensors) + for _ in range(WARMUP_ITER): + features = model(*input_tensors) + + torch.cuda.synchronize() - torch.cuda.synchronize() - print("============= DONE 2 ==================") - timings = [] + timings = [] + with torch.no_grad(): for i in range(iters): start_time = timeit.default_timer() - features = exported_model(*input_tensors) + features = model(*input_tensors) torch.cuda.synchronize() end_time = timeit.default_timer() meas_time = end_time - start_time timings.append(meas_time) + # Reset torch dynamo cache + torch._dynamo.reset() recordStats( - "Dynamo-" + dynamo_backend, timings, precision, batch_size, compile_time_ms + "Torch [inductor]", + timings, + precision, + batch_size, + compile_time_s, ) @@ -259,12 +288,12 @@ def torch_device_from_trt(device): return TypeError("%s is not supported by torch" % device) +@run_with_try_except def run_tensorrt( model, input_tensors, params, precision, - truncate_long_and_double=False, is_trt_engine=False, batch_size=1, ): @@ -276,7 +305,7 @@ def run_tensorrt( compile_settings = { "inputs": input_tensors, "enabled_precisions": {precision_to_dtype(precision)}, - "truncate_long_and_double": truncate_long_and_double, + "truncate_long_and_double": params.get("truncate", False), } print("Converting method to TensorRT engine...") @@ -332,7 +361,6 @@ def run( input_tensors, params, precision, - truncate_long_and_double=False, batch_size=1, is_trt_engine=False, model_torch=None, @@ -347,13 +375,13 @@ def run( if ( backend == "all" - or backend == "torch_tensorrt" + or backend == "ts_trt" or params.get("calibration_cache", None) == None ): print("int8 precision expects calibration cache file for inference") return False - if (model is None) and (backend != "fx2trt"): + if (model is None) and (backend in ("tensorrt", "ts_trt", "all")): warnings.warn( f"Requested backend {backend} without specifying a TorchScript Model, " + "skipping this backend" @@ -369,12 +397,11 @@ def run( if backend == "all": run_torch(model, input_tensors, params, precision, batch_size) - run_torch_tensorrt( + run_ts_trt( model, input_tensors, params, precision, - truncate_long_and_double, batch_size, ) run_tensorrt( @@ -382,65 +409,46 @@ def run( input_tensors, params, precision, - truncate_long_and_double, is_trt_engine, batch_size, ) - run_fx2trt(model_torch, input_tensors, params, precision, batch_size) run_dynamo(model_torch, input_tensors, params, precision, batch_size) - elif backend == "torchscript": - run_torch(model, input_tensors, params, precision, batch_size) - run_torch_tensorrt( - model, - input_tensors, - params, - precision, - truncate_long_and_double, - batch_size, - ) - run_tensorrt( - model, - 
input_tensors, - params, - precision, - truncate_long_and_double, - is_trt_engine, - batch_size, - ) - elif backend == "torch": - run_torch(model, input_tensors, params, precision, batch_size) + run_torch(model_torch, input_tensors, params, precision, batch_size) - elif backend == "torch_tensorrt": - run_torch_tensorrt( + elif backend == "ts_trt": + run_ts_trt( model, input_tensors, params, precision, - truncate_long_and_double, batch_size, ) - - elif backend == "fx2trt": - run_fx2trt(model_torch, input_tensors, params, precision, batch_size) - elif backend == "tensorrt": run_tensorrt( model, input_tensors, params, precision, - truncate_long_and_double, is_trt_engine, batch_size, ) elif backend == "dynamo": run_dynamo(model_torch, input_tensors, params, precision, batch_size) + elif backend == "torch_compile": + run_torch_compile(model_torch, input_tensors, params, precision, batch_size) + + elif backend == "torch_compile": + run_torch_compile(model_torch, input_tensors, params, precision, batch_size) + + elif backend == "inductor": + run_inductor(model_torch, input_tensors, params, precision, batch_size) + # Generate report -def recordStats(backend, timings, precision, batch_size=1, compile_time_ms=None): +def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None): times = np.array(timings) steps = len(times) speeds = batch_size / times @@ -459,72 +467,21 @@ def recordStats(backend, timings, precision, batch_size=1, compile_time_ms=None) "Mean(FPS)": speed_mean, "Median-Latency(ms)": time_med * 1000, "Mean-Latency(ms)": time_mean * 1000, - "Compile Time(ms)": compile_time_ms, + "Latency-StdDev(ms)": time_std * 1000, + "Compile Time(s)": compile_time_s, } results.append(stats) -def load_ts_model(params): - model = None - is_trt_engine = False - - # No TorchScript Model Specified - if len(params.get("model", "")) == 0: - return None, None, is_trt_engine - - # Load torch model traced/scripted - model_file = params.get("model").get("filename") - try: - model_name = params.get("model").get("name") - except: - model_name = model_file - - print("Loading model: ", model_file) - if model_file.endswith(".plan"): - is_trt_engine = True - # Read the TensorRT engine file - with open(model_file, "rb") as fin: - model = fin.read() - else: - model = torch.jit.load(model_file).cuda() - - return model, model_name, is_trt_engine - - -def load_torch_model(params): - model = None - - # No Torch Model Specified - if len(params.get("model_torch", "")) == 0: - return None, None - - # Load torch model - model_file = params.get("model_torch").get("filename") - try: - model_name = params.get("model_torch").get("name") - except: - model_name = model_file - - print("Loading Torch model: ", model_file) - model = torch.load(model_file).cuda() - - return model, model_name - - if __name__ == "__main__": arg_parser = argparse.ArgumentParser( description="Run inference on a model with random input values" ) - arg_parser.add_argument( - "--config", - type=str, - help="Load YAML based configuration file to run the inference. If this is used other params will be ignored", - ) # The following options are manual user provided settings arg_parser.add_argument( "--backends", type=str, - help="Comma separated string of backends. Eg: torch,torch_tensorrt,fx2trt,tensorrt", + help="Comma separated string of backends. 
Eg: torch, ts_trt, dynamo, torch_compile, inductor, tensorrt", ) arg_parser.add_argument( "--model", type=str, default="", help="Name of torchscript model file" @@ -533,7 +490,7 @@ def load_torch_model(params): "--model_torch", type=str, default="", - help="Name of torch model file (used for fx2trt)", + help="Name of torch model file", ) arg_parser.add_argument( "--inputs", @@ -553,22 +510,17 @@ def load_torch_model(params): "--calibration_cache", type=str, help="Name of the calibration cache file" ) arg_parser.add_argument("--device", type=int, help="device id") + arg_parser.add_argument("--min_block_size", type=int, default=1, help="device id") arg_parser.add_argument( "--truncate", action="store_true", - help="Truncate long and double weights in the network in Torch-TensorRT", + help="Truncate long and double weights in the network in Torch-TensorRT", ) arg_parser.add_argument( "--is_trt_engine", action="store_true", help="Boolean flag to determine if the user provided model is a TRT engine or not", ) - arg_parser.add_argument( - "--dynamo_backend", - type=str, - default="fx2trt", - help="List of backends to use in Torchdynamo. Select options: inductor|fx2trt", - ) arg_parser.add_argument( "--report", type=str, @@ -576,132 +528,66 @@ def load_torch_model(params): ) args = arg_parser.parse_args() - cudnn.benchmark = True # Create random input tensor of certain size torch.manual_seed(12345) model_name = "Model" - if args.config: - parser = ConfigParser(args.config) - # Load YAML params - params = parser.read_config() - model, model_name, is_trt_engine = load_ts_model(params) - model_torch, model_name_torch = load_torch_model(params) - - # If neither model type was provided - if (model is None) and (model_torch is None): - raise ValueError( - "No valid models specified. Please provide a torchscript model file or model name " - + "(among the following options vgg16|resnet50|efficientnet_b0|vit) " - + "or provide a torch model file" - ) - - # Default device is set to 0. Configurable using yaml config file. - torch.cuda.set_device(params.get("runtime").get("device", 0)) + params = vars(args) + model_name = params["model"] + model = None - num_input = params.get("input").get("num_inputs") - truncate_long_and_double = params.get("runtime").get( - "truncate_long_and_double", False + model_name_torch = params["model_torch"] + model_torch = None + + # Load TorchScript model, if provided + if os.path.exists(model_name): + print("Loading user provided torchscript model: ", model_name) + model = torch.jit.load(model_name).cuda().eval() + + # Load PyTorch Model, if provided + if len(model_name_torch) > 0 and os.path.exists(model_name_torch): + print("Loading user provided torch model: ", model_name_torch) + model_torch = torch.load(model_name_torch).eval().cuda() + elif model_name_torch in BENCHMARK_MODELS: + model_torch = BENCHMARK_MODELS[model_name_torch]["model"].eval().cuda() + + # If neither model type was provided + if (model is None) and (model_torch is None): + raise ValueError( + "No valid models specified. 
Please provide a torchscript model file or model name " + + "(among the following options vgg16|resnet50|efficientnet_b0|vit) " + + "or provide a torch model file" ) - batch_size = params.get("input").get("batch_size", 1) - for precision in params.get("runtime").get("precision", "fp32"): - input_tensors = [] - num_input = params.get("input").get("num_inputs", 1) - for i in range(num_input): - inp_tensor = params.get("input").get("input" + str(i)) - input_tensors.append( - torch.randint( - 0, - 2, - tuple(d for d in inp_tensor), - dtype=precision_to_dtype(precision), - ).cuda() - ) - - if is_trt_engine: - print( - "Warning, TensorRT engine file is configured. Please make sure the precision matches with the TRT engine for reliable results" - ) - - if not is_trt_engine and (precision == "fp16" or precision == "half"): - # If model is TensorRT serialized engine then model.half will report failure - if model is not None: - model = model.half() - if model_torch is not None: - model_torch = model_torch.half() - backends = params.get("backend") - # Run inference - status = run( - model, - backends, - input_tensors, - params, - precision, - truncate_long_and_double, - batch_size, - is_trt_engine, - model_torch, - ) - else: - params = vars(args) - model_name = params["model"] - model = None - - model_name_torch = params["model_torch"] - model_torch = None - - # Load TorchScript model, if provided - if os.path.exists(model_name): - print("Loading user provided torchscript model: ", model_name) - model = torch.jit.load(model_name).cuda().eval() - elif model_name in BENCHMARK_MODELS: - print("Loading torchscript model from BENCHMARK_MODELS for: ", model_name) - model = BENCHMARK_MODELS[model_name]["model"].eval().cuda() - - # Load PyTorch Model, if provided - if len(model_name_torch) > 0 and os.path.exists(model_name_torch): - print("Loading user provided torch model: ", model_name_torch) - model_torch = torch.load(model_name_torch).eval().cuda() - - # If neither model type was provided - if (model is None) and (model_torch is None): - raise ValueError( - "No valid models specified. Please provide a torchscript model file or model name " - + "(among the following options vgg16|resnet50|efficientnet_b0|vit) " - + "or provide a torch model file" - ) - - backends = parse_backends(params["backends"]) - if "dynamo" in backends and (model_torch is None): - raise ValueError( - "No Pytorch model (nn.Module) is provided for torchdynamo compilation. Please provide a pytorch model using --model_torch argument" - ) + backends = parse_backends(params["backends"]) + if ("dynamo" in backends or "torch_compile" in backends) and (model_torch is None): + raise ValueError( + "No Pytorch model (nn.Module) is provided for torchdynamo compilation. 
Please provide a pytorch model using --model_torch argument" + ) - truncate_long_and_double = params["truncate"] - batch_size = params["batch_size"] - is_trt_engine = params["is_trt_engine"] - precisions = parse_precisions(params["precision"]) + batch_size = params["batch_size"] + is_trt_engine = params["is_trt_engine"] + precisions = parse_precisions(params["precision"]) - for precision in precisions: - input_tensors = parse_inputs( - params["inputs"], precision_to_dtype(precision) - ) + for precision in precisions: + input_tensors = parse_inputs(params["inputs"], precision_to_dtype(precision)) - if not is_trt_engine and (precision == "fp16" or precision == "half"): - # If model is TensorRT serialized engine then model.half will report failure + if not is_trt_engine and (precision == "fp16" or precision == "half"): + # If model is TensorRT serialized engine then model.half will report failure + if model is not None: model = model.half() - - status = run( - model, - backends, - input_tensors, - params, - precision, - truncate_long_and_double, - batch_size, - is_trt_engine, - model_torch=model_torch, - ) + if model_torch is not None: + model_torch = model_torch.half() + + status = run( + model, + backends, + input_tensors, + params, + precision, + batch_size, + is_trt_engine, + model_torch=model_torch, + ) # Generate report print("Model Summary: ", model_name) diff --git a/tools/perf/requirements.txt b/tools/perf/requirements.txt index f9f8813feb..d204d3c335 100644 --- a/tools/perf/requirements.txt +++ b/tools/perf/requirements.txt @@ -1,5 +1,7 @@ -timeit numpy argparse -yaml +pyyaml +transformers==4.33.2 +diffusers==0.21.4 pandas==2.0.1 +timm==0.9.8 diff --git a/tools/perf/utils.py b/tools/perf/utils.py index 96a13ffbc2..a6f8ba236d 100644 --- a/tools/perf/utils.py +++ b/tools/perf/utils.py @@ -1,30 +1,79 @@ -import torch -import torch_tensorrt import custom_models as cm -import torchvision.models as models import timm +import torch +import torchvision.models as models -BENCHMARK_MODELS = { - "vgg16": { - "model": models.vgg16(weights=models.VGG16_Weights.DEFAULT), - "path": ["script", "pytorch"], - }, - "resnet50": { - "model": models.resnet50(weights=None), - "path": ["script", "pytorch"], - }, - "efficientnet_b0": { - "model": timm.create_model("efficientnet_b0", pretrained=True), - "path": ["script", "pytorch"], - }, - "vit": { - "model": timm.create_model("vit_base_patch16_224", pretrained=True), - "path": "script", - }, - "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, +BENCHMARK_MODEL_NAMES = { + "vgg16", + "alexnet", + "resnet50", + "efficientnet_b0", + "vit", + "vit_large", + "bert_base_uncased", + "sd_unet", } +class ModelStorage: + def __contains__(self, name: str): + return name in BENCHMARK_MODEL_NAMES + + def __getitem__(self, name: str): + assert name in BENCHMARK_MODEL_NAMES + + if name == "vgg16": + return { + "model": models.vgg16(weights=models.VGG16_Weights.DEFAULT), + "path": ["script", "pytorch"], + } + elif name == "alexnet": + return { + "model": models.alexnet(weights=models.AlexNet_Weights.DEFAULT), + "path": ["script", "pytorch"], + } + elif name == "resnet50": + return { + "model": models.resnet50(weights=None), + "path": ["script", "pytorch"], + } + elif name == "efficientnet_b0": + return { + "model": timm.create_model("efficientnet_b0", pretrained=True), + "path": ["script", "pytorch"], + } + elif name == "vit": + return { + "model": timm.create_model("vit_base_patch16_224", pretrained=True), + "path": ["script", "pytorch"], + } + elif name 
== "vit_large": + return { + "model": timm.create_model("vit_giant_patch14_224", pretrained=False), + "path": ["script", "pytorch"], + } + elif name == "bert_base_uncased": + return { + "model": cm.BertModule(), + "inputs": cm.BertInputs(), + "path": ["trace", "pytorch"], + } + elif name == "sd_unet": + return { + "model": cm.StableDiffusionUnet(), + "path": "pytorch", + } + else: + raise AssertionError(f"Invalid model name {name}") + + def items(self): + for name in BENCHMARK_MODEL_NAMES: + yield name, self.__getitem__(name) + + +BENCHMARK_MODELS = ModelStorage() + + def precision_to_dtype(pr): if pr == "fp32": return torch.float @@ -51,7 +100,16 @@ def parse_inputs(user_inputs, dtype): ) for input_dim in input_shape_and_dtype[0][1:-1].split(","): input_shape.append(int(input_dim)) - torchtrt_inputs.append(torch.randint(0, 5, input_shape, dtype=dtype).cuda()) + + if input_shape != [1]: + if dtype == torch.int32: + torchtrt_inputs.append( + torch.randint(0, 5, input_shape, dtype=dtype).cuda() + ) + else: + torchtrt_inputs.append(torch.randn(input_shape, dtype=dtype).cuda()) + else: + torchtrt_inputs.append(torch.Tensor([1.0]).cuda()) return torchtrt_inputs