diff --git a/.circleci/config.yml b/.circleci/config.yml
index 3ff07e15cf..81b478372e 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -244,6 +244,9 @@ commands:
       trt-version-long:
         type: string
         default: "8.5.1.7"
+      cudnn-version-long:
+        type: string
+        default: "8.5.0.96"
     steps:
       - run:
           name: Set up python environment
@@ -252,7 +255,7 @@ commands:
             pip3 install wheel setuptools
             pip3 install nvidia-pyindex
             pip3 install tabulate
-            pip3 install tensorrt==<< parameters.trt-version-long >>
+            pip3 install tensorrt==<< parameters.trt-version-long >> nvidia-cudnn-cu11==<< parameters.cudnn-version-long >>
            pip3 install pytest parameterized expecttest nox

   install-torch-from-index:
@@ -260,7 +263,7 @@ commands:
     parameters:
       torch-build:
         type: string
-        default: "1.14.0.dev20221114+cu117"
+        default: "2.0.0.dev20230103+cu117"
       torch-build-index:
         type: string
         default: "https://download.pytorch.org/whl/nightly/cu117"
@@ -735,9 +738,9 @@ jobs:
      - create-py-env:
          trt-version-long: << parameters.trt-version-long >>
      - install-cudnn
-      - run:
-          name: "Set LD_LIBRARY_PATH path to include the installed CUDNN"
-          command: export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
+      # - run:
+      #     name: "Set LD_LIBRARY_PATH path to include the installed CUDNN"
+      #     command: export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
      - run:
          name: "Install torch-tensorrt"
          command: pip3 install --pre /tmp/dist/x86_64-linux/*
@@ -989,7 +992,7 @@ parameters:
   # Nightly platform config
   torch-build:
     type: string
-    default: "1.14.0.dev20221114+cu117"
+    default: "2.0.0.dev20230103+cu117"
   torch-build-index:
     type: string
     default: "https://download.pytorch.org/whl/nightly/cu117"
@@ -1146,4 +1149,4 @@ workflows:
           torch-build: << pipeline.parameters.torch-build >>
           torch-build-index: << pipeline.parameters.torch-build-index >>
           trt-version-short: << pipeline.parameters.trt-version-short >>
-          cudnn-version: << pipeline.parameters.cudnn-version >>
\ No newline at end of file
+          cudnn-version: << pipeline.parameters.cudnn-version >>
diff --git a/WORKSPACE b/WORKSPACE
index 41c13a3a90..7f4e951892 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -56,17 +56,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "28f352d980191492758fbeb6773eda509d3d1917c59a7f52829af1c6b521bbc7",
+    sha256 = "59b8b5e1954a86d50b79c13f06398d385b200da13e37a08ecf31d3c62e5ca127",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-1.14.0.dev20221114%2Bcu117.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.0.0.dev20230103%2Bcu117.zip"],
 )

 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "6e1ee2893cfd2ce3d69604d9988e0f5133aeee45da611edd07cb633f78966ce3",
+    sha256 = "e260fc7476be89d1650953e8643e9f7363845f5a52de4bab87ac0e619c1f6ad4",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-1.14.0.dev20221114%2Bcu117.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.0.0.dev20230103%2Bcu117.zip"],
 )

 # Download these tarballs manually from the NVIDIA website
diff --git a/core/partitioning/partitioninginfo/PartitioningInfo.h b/core/partitioning/partitioninginfo/PartitioningInfo.h
index ed7d2033c6..62ad46034f 100644
--- a/core/partitioning/partitioninginfo/PartitioningInfo.h
+++ b/core/partitioning/partitioninginfo/PartitioningInfo.h
@@ -17,6 +17,7 @@ struct PartitioningInfo {
   std::vector<std::string> forced_fallback_operators;
   bool truncate_long_and_double;
   ir::Device target_device;
+  bool cast_int8_inputs = false;

   std::string getGPUDeviceString() const {
     return "cuda:" + std::to_string(target_device.gpu_id);
diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp
index 4220764dd6..6a648f0063 100644
--- a/core/partitioning/shape_analysis.cpp
+++ b/core/partitioning/shape_analysis.cpp
@@ -99,18 +99,24 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
   return nullptr;
 }

-torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
+torch::jit::Node* createCastNode(
+    SegmentedBlock& seg_block,
+    size_t index,
+    bool is_input,
+    at::ScalarType dtype,
+    std::string device,
+    bool force_create_node = false) {
   auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
   auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
   torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
   auto g = seg_block.g();
   // if we can find upstream aten::to node, we use it's parameters for creating new cast node
-  if (cast_node) {
+  if (cast_node && !force_create_node) {
     std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
     value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
     if (!is_input) {
       // if this value is output, we need to cast it to int32
-      auto const_val = g->insertConstant(3);
+      auto const_val = g->insertConstant(dtype);
       if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
         value_map.insert({cast_node->inputs()[2], const_val});
       } else {
@@ -122,7 +128,7 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
     // auto cast_node = g->prependNode(g->createClone(cast_node, env));
   } else {
     // if there is no explicit cast aten::to operation, we need to create a node
-    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
+    auto const_type = g->insertConstant(dtype);
     auto const_zero = g->insertConstant(0);
     const_zero->setType(torch::jit::BoolType::get());
     auto cuda = g->insertConstant(device);
@@ -222,27 +228,56 @@ void getSegmentsOutputByRunning(

   auto target_device = partitioning_info.getGPUDeviceString();

-  // auto int64 <=> int32 conversion
-  if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
+  // auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
+  if (seg_block.target() == SegmentedBlock::kTorch) {
     // First, check if there is Int64 input
     for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
       if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
         auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-        if (t == at::kLong) {
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor input type during shape analysis, "
+              << "inserting aten::to cast to Long to ensure this Torch block receives "
+              << "a Long-type tensor input.");
           // we add a cast operation to cast the type to Int64
-          auto cast_node = createCastNode(seg_block, i, true, target_device);
+          auto cast_node = createCastNode(seg_block, i, true, at::kLong, target_device);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor input type during shape analysis, "
+              << "inserting aten::to cast to Byte to ensure this Torch block receives "
+              << "a Byte-type tensor input.");
+          // If the input has type Byte, ensure it is casted to the correct type
+          auto cast_node = createCastNode(seg_block, i, true, at::kByte, target_device, /*force_create_node=*/true);
           seg_block.g()->prependNode(cast_node);
           seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
         }
       }
     }
+
     for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
       if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
         auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-        if (t == at::kLong) {
-          auto cast_node = createCastNode(seg_block, i, false, target_device);
+
+        // If the output has type Long and truncation was requested, insert truncate
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device);
+          seg_block.g()->appendNode(cast_node);
+          seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          // If the output has type Byte and casting was requested, insert Integer cast
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device, /*force_create_node=*/true);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         }
@@ -254,11 +289,13 @@ void getSegmentsOutputByRunning(
   std::vector<std::vector<int64_t>> input_shapes;
   std::vector<at::ScalarType> input_types;
   for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
-    if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+    auto current_input = seg_block.raw_inputs()[i];
+
+    if (ivalues_maps[current_input].isTensor()) {
       // set the input_shape and data_type
       // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
       // shape inference
-      auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+      auto cur_ivalue = ivalues_maps[current_input];
       at::ScalarType t = cur_ivalue.toTensor().scalar_type();

       if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
@@ -271,10 +308,16 @@ void getSegmentsOutputByRunning(
         cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
         LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
       }
+
       c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
       if (dtype == c10::nullopt) {
         TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
+      } else if (dtype && dtype.value() == nvinfer1::DataType::kINT8 && partitioning_info.cast_int8_inputs) {
+        // Special case to ensure input IValues to TensorRT engine are not Int8 type if the
+        // model itself is not quantized
+        cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
       }
+
       if (cur_ivalue.toTensor().sizes().size() == 0) {
         // handle Scalar types, which has sizes of []
         input_shapes.push_back(util::toVec(util::toDims(c10::List<int64_t>({1}))));
@@ -297,6 +340,7 @@ void runShapeAnalysis(
     const ir::ShapeMode& shape_mode) {
   // register every segment's input shape, and it's running output IValues
   for (auto& seg_block : ctx->partitioned_blocks[block]) {
+    LOG_GRAPH("Running shape analysis on block " << seg_block);
     torch::jit::ConstantPooling(seg_block.g());
     getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings, shape_mode);
   }
diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp
index d320992a70..835faaed68 100644
--- a/core/util/trt_util.cpp
+++ b/core/util/trt_util.cpp
@@ -252,6 +252,7 @@ const std::unordered_map& get_at_trt_type_ma
       {at::kHalf, nvinfer1::DataType::kHALF},
       {at::kInt, nvinfer1::DataType::kINT32},
       {at::kChar, nvinfer1::DataType::kINT8},
+      {at::kByte, nvinfer1::DataType::kINT8},
       {at::kBool, nvinfer1::DataType::kBOOL}};
   return at_trt_type_map;
 }
diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp
index 24aba31515..3de2daa14a 100644
--- a/cpp/src/compile_spec.cpp
+++ b/cpp/src/compile_spec.cpp
@@ -167,8 +167,11 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.dla_local_dram_size = external.dla_local_dram_size;
   internal.convert_info.engine_settings.dla_global_dram_size = external.dla_global_dram_size;

+  internal.partitioning_info.cast_int8_inputs = true;
+
   if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
       internal.convert_info.engine_settings.enabled_precisions.end()) {
+    internal.partitioning_info.cast_int8_inputs = false;
     if (external.ptq_calibrator) {
       internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
     } else {
diff --git a/py/requirements.txt b/py/requirements.txt
index 155cb34819..6d0916af9d 100644
--- a/py/requirements.txt
+++ b/py/requirements.txt
@@ -1,7 +1,7 @@
 numpy
 pybind11==2.6.2
 --extra-index-url https://download.pytorch.org/whl/nightly/cu117
-torch==1.14.0.dev20221114+cu117
-torchvision==0.15.0.dev20221114+cu117
+torch==2.0.0.dev20230103+cu117
+torchvision==0.15.0.dev20230103+cu117
 --extra-index-url https://pypi.ngc.nvidia.com
 tensorrt==8.5.1.7
diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
index 489da576e2..9822f47f3b 100644
--- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp
+++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -300,11 +300,15 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
     info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
   }

+  info.partitioning_info.cast_int8_inputs = true;
+
   if (ptq_calibrator) {
     info.convert_info.engine_settings.calibrator = ptq_calibrator;
+    info.partitioning_info.cast_int8_inputs = false;
   } else {
     if (info.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
         info.convert_info.engine_settings.enabled_precisions.end()) {
+      info.partitioning_info.cast_int8_inputs = false;
       info.lower_info.unfreeze_module = true;
       info.lower_info.disable_cse = true;
     }
@@ -313,10 +317,23 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.convert_info.engine_settings.disable_tf32 = disable_tf32;
   info.convert_info.engine_settings.refit = refit;
   info.convert_info.engine_settings.debug = debug;
+
+  // Specify + replicate device settings for phases requiring it
   info.convert_info.engine_settings.device.device_type = toTRTDeviceType(device.device_type);
   info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
   info.convert_info.engine_settings.device.dla_core = device.dla_core;
   info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
+
+  info.lower_info.target_device.device_type = toTRTDeviceType(device.device_type);
+  info.lower_info.target_device.gpu_id = device.gpu_id;
+  info.lower_info.target_device.dla_core = device.dla_core;
+  info.lower_info.target_device.allow_gpu_fallback = device.allow_gpu_fallback;
+
+  info.partitioning_info.target_device.device_type = toTRTDeviceType(device.device_type);
+  info.partitioning_info.target_device.gpu_id = device.gpu_id;
+  info.partitioning_info.target_device.dla_core = device.dla_core;
+  info.partitioning_info.target_device.allow_gpu_fallback = device.allow_gpu_fallback;
+
   info.partitioning_info.enabled = torch_fallback.enabled;
   info.partitioning_info.min_block_size = torch_fallback.min_block_size;
   info.partitioning_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
index 9c42963b51..01b15aa533 100644
--- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -3630,7 +3630,7 @@ def acc_ops_interpolate(
     else:
         layer.resize_mode = trt.ResizeMode.NEAREST

-    if align_corners != None:
+    if (align_corners is not None) and align_corners:
         layer.coordinate_transformation = (
             trt.ResizeCoordinateTransformation.ALIGN_CORNERS
         )
diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py
index f0054e5cb7..1f4b37f4b4 100644
--- a/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py
+++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py
@@ -43,6 +43,14 @@ class TestInterpolateConverter(AccTestCase):
             ("bilinear"),
             (None),
         ),  # linear for 4D only
+        (
+            "4d_dim_scale_bilinear_align_corners_bool",
+            (2, 3, 4, 5),
+            (None),
+            (2),
+            ("bilinear"),
+            (False),
+        ),  # linear for 4D only
         (
             "4d_dim_scale_align",
             (2, 3, 4, 5),
diff --git a/tests/core/partitioning/test_type_auto_conversion.cpp b/tests/core/partitioning/test_type_auto_conversion.cpp
index 28f620b843..03c7b70e38 100644
--- a/tests/core/partitioning/test_type_auto_conversion.cpp
+++ b/tests/core/partitioning/test_type_auto_conversion.cpp
@@ -107,3 +107,63 @@ TEST(Partitioning, ImplicitAutoConversionCorrectly) {
   }
   ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
 }
+
+TEST(Partitioning, ExplicitNodeAutoInt8ConversionCorrectly) {
+  const auto graph = R"IR(
+        graph(%x.1 : Tensor,
+              %y.1 : Tensor):
+
+          %26 : int = prim::Constant[value=1]()
+          %21 : bool = prim::Constant[value=0]()
+          %60 : Device = prim::Constant[value="cuda"]()
+          %14 : NoneType = prim::Constant()
+          %3 : int = prim::Constant[value=5]()
+          %19 : int = prim::Constant[value=0]()
+          %29 : int = prim::Constant[value=2]()
+          %13 : int[] = prim::ListConstruct(%3, %3)
+          %k_.1 : Tensor = aten::ones(%13, %19, %14, %60, %14)
+          %20 : int[] = prim::ListConstruct(%19)
+          %k.1 : Tensor = aten::sum(%k_.1, %20, %21, %14)
+          %x.5 : Tensor = aten::add_(%x.1, %y.1, %26)
+          %31 : Tensor = aten::mul(%y.1, %29)
+          %x.9 : Tensor = aten::add_(%x.5, %31, %26)
+          %x.13 : Tensor = aten::add_(%x.9, %k.1, %26)
+          %x.17 : Tensor = aten::sub_(%x.13, %k.1, %26)
+          %x.21 : Tensor = aten::add_(%x.17, %k.1, %26)
+          %x.25 : Tensor = aten::sub_(%x.21, %k.1, %26)
+
+          return (%x.25))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.cast_int8_inputs = true;
+  partitioning_info.forced_fallback_operators = {"aten::ones"};
+  partitioning_info.truncate_long_and_double = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+
+  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  inputs_map.insert({g->inputs()[0], {inputs[0]}});
+  input_types.insert({g->inputs()[0], {{at::kFloat}}});
+  inputs_map.insert({g->inputs()[1], {inputs[1]}});
+  input_types.insert({g->inputs()[1], {{at::kInt}}});
+
+  partitioning_info.collection_input_spec_map = inputs_map;
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  ctx.input_types_map = input_types;
+  torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
+  torch_tensorrt::core::partitioning::partition(&ctx);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
+
+  for (auto& seg_block : segmented_blocks) {
+    LOG_DEBUG(seg_block << " cur seg block");
+  }
+
+  // Seeking 1 inserted aten::to converting Byte to Int (%k_.1 is a Byte Tensor)
+  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[0], 1));
+}
diff --git a/tests/modules/requirements.txt b/tests/modules/requirements.txt
index 0b494ae33e..8ad4e47be6 100644
--- a/tests/modules/requirements.txt
+++ b/tests/modules/requirements.txt
@@ -1,4 +1,4 @@
 timm==v0.4.12
 transformers==4.17.0
 --extra-index-url https://download.pytorch.org/whl/nightly/cu117
-torchvision==0.15.0.dev20221114+cu117
+torchvision==0.15.0.dev20230103+cu117
diff --git a/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel b/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel
index 2ea1c8cef4..608efac78b 100644
--- a/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel
+++ b/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel
@@ -56,17 +56,17 @@ new_local_repository(
 http_archive(
"libtorch", build_file = "@//third_party/libtorch:BUILD", - sha256 = "28f352d980191492758fbeb6773eda509d3d1917c59a7f52829af1c6b521bbc7", + sha256 = "59b8b5e1954a86d50b79c13f06398d385b200da13e37a08ecf31d3c62e5ca127", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-1.14.0.dev20221114%2Bcu117.zip"], + urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.0.0.dev20230103%2Bcu117.zip"], ) http_archive( name = "libtorch_pre_cxx11_abi", build_file = "@//third_party/libtorch:BUILD", - sha256 = "6e1ee2893cfd2ce3d69604d9988e0f5133aeee45da611edd07cb633f78966ce3", + sha256 = "e260fc7476be89d1650953e8643e9f7363845f5a52de4bab87ac0e619c1f6ad4", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-1.14.0.dev20221114%2Bcu117.zip"], + urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.0.0.dev20230103%2Bcu117.zip"], ) #################################################################################### diff --git a/toolchains/ci_workspaces/WORKSPACE.x86_64.release.ubuntu b/toolchains/ci_workspaces/WORKSPACE.x86_64.release.ubuntu index 6af670ec20..da3f766240 100644 --- a/toolchains/ci_workspaces/WORKSPACE.x86_64.release.ubuntu +++ b/toolchains/ci_workspaces/WORKSPACE.x86_64.release.ubuntu @@ -56,17 +56,17 @@ new_local_repository( http_archive( name = "libtorch", build_file = "@//third_party/libtorch:BUILD", - sha256 = "28f352d980191492758fbeb6773eda509d3d1917c59a7f52829af1c6b521bbc7", + sha256 = "59b8b5e1954a86d50b79c13f06398d385b200da13e37a08ecf31d3c62e5ca127", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-1.14.0.dev20221114%2Bcu117.zip"], + urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.0.0.dev20230103%2Bcu117.zip"], ) http_archive( name = "libtorch_pre_cxx11_abi", build_file = "@//third_party/libtorch:BUILD", - sha256 = "6e1ee2893cfd2ce3d69604d9988e0f5133aeee45da611edd07cb633f78966ce3", + sha256 = "e260fc7476be89d1650953e8643e9f7363845f5a52de4bab87ac0e619c1f6ad4", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-1.14.0.dev20221114%2Bcu117.zip"], + urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.0.0.dev20230103%2Bcu117.zip"], ) #################################################################################### diff --git a/tools/perf/perf_run.py b/tools/perf/perf_run.py index 78a3be5dd5..00ddbabd22 100644 --- a/tools/perf/perf_run.py +++ b/tools/perf/perf_run.py @@ -17,8 +17,9 @@ # Importing supported Backends import torch import torch_tensorrt as torchtrt -from torch_tensorrt.fx.lower import compile -from torch_tensorrt.fx.utils import LowerPrecision + +# from torch_tensorrt.fx.lower import compile +# from torch_tensorrt.fx.utils import LowerPrecision import tensorrt as trt from utils import ( @@ -134,21 +135,17 @@ def run_torch_tensorrt( # Runs inference using FX2TRT backend def run_fx2trt(model, input_tensors, params, precision, batch_size): print("Running FX2TRT for precision: ", precision, " batch_size : ", batch_size) - if precision == "fp32": - precision = LowerPrecision.FP32 - elif precision == "fp16": - precision = LowerPrecision.FP16 + if precision == "fp16": model.half() input_tensors = [tensor.half() for tensor in input_tensors] + # Run 
     # Run lowering eager mode benchmark
     start_compile = time.time_ns()
-    model = compile(
+    model = torchtrt.compile(
         model,
-        input_tensors,
-        max_batch_size=batch_size,
-        lower_precision=precision,
-        verbose_log=False,
-        explicit_batch_dimension=True,
+        ir="fx",
+        inputs=input_tensors,
+        enabled_precisions={torch.float16 if precision == "fp16" else torch.float32},
     )
     end_compile = time.time_ns()
     compile_time_ms = (end_compile - start_compile) / 1e6
@@ -174,6 +171,69 @@ def run_fx2trt(model, input_tensors, params, precision, batch_size):
     recordStats("FX-TensorRT", timings, precision, batch_size, compile_time_ms)


+def run_dynamo(model, input_tensors, params, precision, batch_size):
+    dynamo_backend = params["dynamo_backend"]
+    print(
+        "Running Dynamo with backend: ",
+        dynamo_backend,
+        " for precision: ",
+        precision,
+        " batch_size : ",
+        batch_size,
+    )
+
+    if precision == "fp16":
+        input_tensors = [tensor.half() for tensor in input_tensors]
+
+    fp16_mode = True if precision == "fp16" else False
+    # dynamo_backend_params = {"fp16_mode" : fp16_mode}
+    # model = torch.compile(
+    #     model,
+    #     mode="default",
+    #     dynamic=False,
+    #     fullgraph=False,
+    #     backend=dynamo_backend,
+    #     # **dynamo_backend_params
+    # )
+    import torch._dynamo as dynamo
+
+    model = dynamo.optimize(dynamo_backend, nopython=True)(model)
+    # Compile and measure the time
+    with torch.no_grad():
+        start_compile = time.time_ns()
+        features = model(*input_tensors)
+        end_compile = time.time_ns()
+        compile_time_ms = (end_compile - start_compile) / 1e6
+        iters = params.get("iterations", 20)
+        # import pdb; pdb.set_trace()
+        print("============= DONE 0 ==================")
+
+    print("============= DONE 1 ==================")
+    # Warm up
+    model = torch._dynamo.run(model)
+    # import pdb; pdb.set_trace()
+
+    exported_model, _ = torch._dynamo.export(model, *input_tensors)
+    for i in range(WARMUP_ITER):
+        print("==== ITER: ", i)
+        features = exported_model(*input_tensors)
+
+    torch.cuda.synchronize()
+    print("============= DONE 2 ==================")
+    timings = []
+    for i in range(iters):
+        start_time = timeit.default_timer()
+        features = exported_model(*input_tensors)
+        torch.cuda.synchronize()
+        end_time = timeit.default_timer()
+        meas_time = end_time - start_time
+        timings.append(meas_time)
+
+    recordStats(
+        "Dynamo-" + dynamo_backend, timings, precision, batch_size, compile_time_ms
+    )
+
+
 def torch_dtype_from_trt(dtype):
     if dtype == trt.int8:
         return torch.int8
@@ -326,6 +386,7 @@ def run(
             batch_size,
         )
         run_fx2trt(model_torch, input_tensors, params, precision, batch_size)
+        run_dynamo(model_torch, input_tensors, params, precision, batch_size)

     elif backend == "torchscript":
         run_torch(model, input_tensors, params, precision, batch_size)
@@ -373,6 +434,8 @@ def run(
             is_trt_engine,
             batch_size,
         )
+    elif backend == "dynamo":
+        run_dynamo(model_torch, input_tensors, params, precision, batch_size)


 # Generate report
@@ -499,6 +562,12 @@ def load_torch_model(params):
         action="store_true",
         help="Boolean flag to determine if the user provided model is a TRT engine or not",
     )
+    arg_parser.add_argument(
+        "--dynamo_backend",
+        type=str,
+        default="fx2trt",
+        help="List of backends to use in Torchdynamo. Select options: inductor|fx2trt",
+    )
     arg_parser.add_argument(
         "--report",
         type=str,
@@ -602,6 +671,11 @@ def load_torch_model(params):
         )

     backends = parse_backends(params["backends"])
+    if "dynamo" in backends and (model_torch is None):
+        raise ValueError(
+            "No Pytorch model (nn.Module) is provided for torchdynamo compilation. Please provide a pytorch model using --model_torch argument"
+        )
+
     truncate_long_and_double = params["truncate"]
     batch_size = params["batch_size"]
     is_trt_engine = params["is_trt_engine"]
@@ -611,9 +685,11 @@ def load_torch_model(params):
         input_tensors = parse_inputs(
             params["inputs"], precision_to_dtype(precision)
         )
+
         if not is_trt_engine and (precision == "fp16" or precision == "half"):
             # If model is TensorRT serialized engine then model.half will report failure
             model = model.half()
+
         status = run(
             model,
             backends,