diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp
index 910b8f7d6d..c569a6088e 100644
--- a/core/conversion/converters/impl/select.cpp
+++ b/core/conversion/converters/impl/select.cpp
@@ -149,8 +149,26 @@ auto select_registrations TORCHTRT_UNUSED =
               // IShuffleLayer removes redundant dimensions
               auto shuffle_layer = ctx->net->addShuffle(*out);
               TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
-              shuffle_layer->setReshapeDimensions(
-                  util::squeezeDims(out->getDimensions(), dim, !ctx->input_is_dynamic));
+
+              auto num_zero_dimensions =
+                  util::validateInputDimsForShuffle(out->getDimensions(), ctx->input_is_dynamic);
+              TORCHTRT_CHECK(
+                  num_zero_dimensions >= 0,
+                  "Detected multiple zero dimensions and dynamic shape in aten::select, "
+                      << "which is not currently supported in TensorRT");
+
+              // If the input is not dynamic and the tensor is empty (has some dimension 0),
+              // then 0 is no longer a placeholder for inherited dimensions
+              if (!ctx->input_is_dynamic && (num_zero_dimensions > 0)) {
+                LOG_DEBUG("Setting zero as a true dimension (not placeholder) in aten::select");
+                shuffle_layer->setZeroIsPlaceholder(false);
+              }
+
+              shuffle_layer->setReshapeDimensions(util::squeezeDims(
+                  out->getDimensions(),
+                  dim,
+                  ctx->input_is_dynamic,
+                  ctx->input_is_dynamic && (num_zero_dimensions > 0)));
               shuffle_layer->setName(util::node_info(n).c_str());
               out = shuffle_layer->getOutput(0);
             }
diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp
index 85ea449672..7982ffb846 100644
--- a/core/util/trt_util.cpp
+++ b/core/util/trt_util.cpp
@@ -180,7 +180,25 @@ nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val, bool use
   return dims;
 }
 
-nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros) {
+int validateInputDimsForShuffle(const nvinfer1::Dims& d, bool input_is_dynamic) {
+  int num_zeros_detected = 0;
+
+  // For each dimension, increment the counter if that dimension has value 0
+  for (int i = 0; i < d.nbDims; i++) {
+    if (d.d[i] == 0) {
+      num_zeros_detected++;
+    }
+  }
+
+  // If the tensor from which the dimensions originate has a dynamic shape and more
+  // than one zero dimension is detected, this constitutes an invalid shape for the
+  // TRT Shuffle Layer: dynamic dimensions in Shuffle Layers are generally represented
+  // with a 0 denoting "inherit this dimension from the input tensor", which would
+  // overload the meaning of the "0" dimension
+  return (input_is_dynamic && num_zeros_detected > 1) ? -1 : num_zeros_detected;
+}
+
+nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros, bool swap_existing_zeros) {
   // acceptable range for pos is [0, d.nbDims]
   TORCHTRT_ASSERT(pos >= 0 && pos <= d.nbDims, "ERROR: Index to squeeze is out of bounds.");
 
@@ -188,7 +206,24 @@ nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros) {
   int j = 0;
   for (int i = 0; i < d.nbDims; i++) {
     if (i != pos) {
-      dims.d[j++] = (use_zeros && d.d[i] == -1) ? 0 : d.d[i];
+      // If zeros are replacing dynamic/existing dimensions,
+      // replace all instances of -1 (indicating a dynamic dimension)
+      // with 0 (indicating "copy the dimension from the input tensor"),
+      // as is generally done for reshape operations
+      if (use_zeros && d.d[i] == -1) {
+        dims.d[j] = 0;
+        // If zeros already exist in the dimensions (an empty tensor),
+        // replace all instances of 0 (indicating an empty dimension)
+        // with -1 (to be inferred during the reshape),
+        // as is generally done for reshape operations
+      } else if (swap_existing_zeros && d.d[i] == 0) {
+        dims.d[j] = -1;
+        // Otherwise, copy the dimension from the input unchanged
+      } else {
+        dims.d[j] = d.d[i];
+      }
+
+      j++;
    }
  }
  dims.nbDims = j;
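A minimal sketch of how the two helpers above compose, assuming the torch_tensorrt::core::util namespace from this diff; the dimension values are illustrative only, not taken from the PR:

    #include "core/util/trt_util.h"

    namespace util = torch_tensorrt::core::util;

    void squeeze_sketch() {
      // Static-shape empty tensor, e.g. the {2, 20, 0, 768} input used in the new test
      nvinfer1::Dims stat{4, {2, 20, 0, 768}};
      int zeros = util::validateInputDimsForShuffle(stat, /*input_is_dynamic=*/false); // 1
      // Valid, so the converter disables the zero placeholder and squeezes normally
      auto squeezed = util::squeezeDims(stat, /*pos=*/1, /*use_zeros=*/false, /*swap_existing_zeros=*/false);
      // squeezed == {2, 0, 768}; the 0 is a genuine empty dimension

      // Dynamic-shape case, where -1 marks a dynamic dimension
      nvinfer1::Dims dyn{4, {-1, 20, 0, 768}};
      zeros = util::validateInputDimsForShuffle(dyn, /*input_is_dynamic=*/true); // 1, still valid
      // Two or more zero dimensions with a dynamic shape would return -1 and fail the check
      auto swapped = util::squeezeDims(dyn, /*pos=*/1, /*use_zeros=*/true, /*swap_existing_zeros=*/true);
      // swapped == {0, -1, 768}: -1 -> 0 (inherit from input), 0 -> -1 (infer at reshape time)
      (void)zeros;
      (void)squeezed;
      (void)swapped;
    }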
diff --git a/core/util/trt_util.h b/core/util/trt_util.h
index 355b0d13cc..a09407a5cd 100644
--- a/core/util/trt_util.h
+++ b/core/util/trt_util.h
@@ -135,8 +135,9 @@ nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
 nvinfer1::Dims toDimsTailPad(c10::IntArrayRef l, uint64_t pad_to);
 nvinfer1::Dims toDimsTailPad(c10::List<int64_t> l, uint64_t pad_to);
 nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);
+int validateInputDimsForShuffle(const nvinfer1::Dims& d, bool input_is_dynamic);
 nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val = 1, bool use_zeros = true);
-nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros = true);
+nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros = true, bool swap_existing_zeros = false);
 nvinfer1::Dims squeezeAllDims(const nvinfer1::Dims& d, bool use_zeros_for_unknown_dims = true);
 nvinfer1::Dims toDims(c10::IntArrayRef l);
 nvinfer1::Dims toDims(c10::List<int64_t> l);
diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp
index e5576f0109..991a1b792c 100644
--- a/tests/core/conversion/converters/test_select.cpp
+++ b/tests/core/conversion/converters/test_select.cpp
@@ -140,6 +140,31 @@ TEST(Converters, ATenSelectIntTwiceConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenSelectEmptyTensorConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %2 : int = prim::Constant[value=1]()
+      %3 : int = prim::Constant[value=0]()
+      %4 : Tensor = aten::select(%0, %3, %2)
+      return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = torch::ones({2, 20, 0, 768}).to(at::kCUDA);
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::sameShape(jit_results[0], trt_results[0]));
+}
+
 TEST(Converters, ATenNarrowStartScalarConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%x.1 : Tensor):
diff --git a/tests/util/util.cpp b/tests/util/util.cpp
index 4f87a8a0f8..0b3e0058c9 100644
--- a/tests/util/util.cpp
+++ b/tests/util/util.cpp
@@ -28,6 +28,20 @@ bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor,
   return result <= threshold;
 }
 
+bool sameShape(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor) {
+  if (computed_tensor.sizes().size() != gt_tensor.sizes().size()) {
+    return false;
+  }
+
+  for (size_t i = 0; i < computed_tensor.sizes().size(); i++) {
+    if (computed_tensor.sizes()[i] != gt_tensor.sizes()[i]) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
 bool cosineSimEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float threshold) {
   torch::Tensor cosine_sim = torch::nn::functional::cosine_similarity(
       computed_tensor.flatten(), gt_tensor.flatten(), torch::nn::functional::CosineSimilarityFuncOptions().dim(0));
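A short usage sketch for the new sameShape helper, with hypothetical tensors: element-wise checks such as almostEqual are vacuous on zero-element tensors, so shape equality is the meaningful assertion for empty outputs:

    #include <ATen/ATen.h>
    #include "tests/util/util.h"

    void sameshape_sketch() {
      // Two empty tensors with the same shape: zero elements, identical metadata
      auto a = at::ones({2, 0, 768});
      auto b = at::zeros({2, 0, 768});
      bool same = torch_tensorrt::tests::util::sameShape(a, b); // true: ranks and all dims match

      // Both empty, but rank 3 vs. rank 2
      auto c = at::ones({2, 0});
      bool diff = torch_tensorrt::tests::util::sameShape(a, c); // false
      (void)same;
      (void)diff;
    }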
diff --git a/tests/util/util.h b/tests/util/util.h
index 0ad7e0d5bc..7b1e46e083 100644
--- a/tests/util/util.h
+++ b/tests/util/util.h
@@ -21,6 +21,8 @@ bool cosineSimEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tens
 
 bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = ATOL, float rtol = RTOL);
 
+bool sameShape(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor);
+
 bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
 
 void pointwise_test_helper(
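For context, a standalone sketch (not part of this diff) of the TensorRT behavior the converter change relies on: by default a 0 in IShuffleLayer reshape dimensions is a placeholder meaning "copy this dimension from the input", so a real zero-size dimension must be flagged via setZeroIsPlaceholder(false). The network and input here are assumed to exist:

    #include <NvInfer.h>

    void reshape_empty(nvinfer1::INetworkDefinition* network, nvinfer1::ITensor* input) {
      auto* shuffle = network->addShuffle(*input);
      // Treat 0 as a literal empty dimension rather than "inherit from input"
      shuffle->setZeroIsPlaceholder(false);
      nvinfer1::Dims reshape{3, {2, 0, 768}};
      shuffle->setReshapeDimensions(reshape);
    }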