diff --git a/core/conversion/converters/impl/expand.cpp b/core/conversion/converters/impl/expand.cpp
index b88ebc89fa..6b22fea8d4 100644
--- a/core/conversion/converters/impl/expand.cpp
+++ b/core/conversion/converters/impl/expand.cpp
@@ -393,6 +393,89 @@ auto expand_registrations TORCHTRT_UNUSED =
                auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], collapse->getOutput(0));
                LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
 
+               return true;
+             }})
+        .pattern(
+            {"aten::meshgrid(Tensor[] tensors) -> (Tensor[])",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // torch.meshgrid only supports 1D or 0D input tensors
+               auto arg_tensors = args[0].IValue()->toListRef();
+               std::vector<nvinfer1::ITensor*> tensors;
+               for (auto t : arg_tensors) {
+                 if (t.isTensor()) {
+                   auto torch_tensor = t.toTensor();
+                   tensors.push_back(tensor_to_const(ctx, torch_tensor));
+                 } else {
+                   auto cont = t.toCustomClass<TensorContainer>();
+                   tensors.push_back(cont->tensor());
+                 }
+               }
+
+               // build the output shape for all tensors in the output list
+               nvinfer1::Dims output_dims;
+               output_dims.nbDims = tensors.size();
+               for (size_t idx = 0UL; idx < tensors.size(); ++idx) {
+                 auto dims = tensors[idx]->getDimensions();
+                 output_dims.d[idx] = dims.nbDims == 0 ? 1 : dims.d[0];
+               }
+               std::vector<nvinfer1::ITensor*> out_tensors;
+               // Reshape tensors into output shape (reshape, expand)
+               for (size_t idx = 0UL; idx < tensors.size(); ++idx) {
+                 auto t = tensors[idx];
+                 auto dims = t->getDimensions();
+                 nvinfer1::Dims reshape_dims;
+                 reshape_dims.nbDims = tensors.size();
+                 for (size_t reshape_idx = 0UL; reshape_idx < tensors.size(); ++reshape_idx) {
+                   if (reshape_idx == idx) {
+                     reshape_dims.d[reshape_idx] = dims.nbDims == 0 ? 1 : dims.d[0];
+                   } else {
+                     reshape_dims.d[reshape_idx] = 1;
+                   }
+                 }
+                 // Add a reshape layer before expanding dims
+                 auto reshape_layer = ctx->net->addShuffle(*t);
+                 reshape_layer->setReshapeDimensions(reshape_dims);
+                 std::stringstream reshape_layer_name;
+                 reshape_layer_name << util::node_info(n) << "_meshgrid_reshape_" << std::to_string(idx);
+                 reshape_layer->setName(reshape_layer_name.str().c_str());
+                 auto reshaped = reshape_layer->getOutput(0);
+                 LOG_DEBUG("Tensor " << idx << " reshaped to : " << reshaped->getDimensions() << " from " << dims);
+
+                 // Add slice layer for expansion
+                 std::vector<int64_t> start_vec(output_dims.nbDims, 0);
+                 auto start_offset = util::toDims(c10::IntArrayRef(start_vec));
+
+                 std::vector<int64_t> strides_vec(output_dims.nbDims, 0);
+                 for (int64_t i = 0; i < output_dims.nbDims; i++) {
+                   strides_vec[i] = (reshaped->getDimensions().d[i] != 1);
+                 }
+
+                 auto strides = util::toDims(c10::IntArrayRef(strides_vec));
+
+                 auto slice_layer = ctx->net->addSlice(*reshaped, start_offset, output_dims, strides);
+                 std::stringstream slice_layer_name;
+                 slice_layer_name << util::node_info(n) << "_meshgrid_slice_" << std::to_string(idx);
+                 slice_layer->setName(slice_layer_name.str().c_str());
+                 auto slice_output = slice_layer->getOutput(0);
+                 LOG_DEBUG("Tensor " << idx << " expanded to : " << slice_output->getDimensions());
+                 out_tensors.push_back(slice_output);
+               }
+
+               // Pack output tensors into list
+               c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
+               c10::TypePtr elementType = lt->getElementType();
+               auto list = c10::impl::GenericList(elementType);
+               list.reserve(out_tensors.size());
+
+               for (auto t : out_tensors) {
+                 auto tensor_holder = TensorContainer();
+                 tensor_holder.hold_tensor(t);
+                 auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+                 list.emplace_back(ival);
+               }
+
+               auto output_list = std::move(torch::jit::IValue(list));
+               ctx->AssociateValueAndIValue(n->outputs()[0], output_list);
                return true;
              }});
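
The converter above relies on a TensorRT idiom that is easy to miss on review: ISliceLayer accepts a stride of 0 along an axis, which keeps re-reading the same element and thereby materializes a broadcast. Each 1D (or 0D) input is first reshaped so its data occupies its own axis, then sliced out to the full grid shape with stride 0 on every size-1 axis. Below is a minimal standalone sketch of that layer pair against the raw TensorRT C++ API; the function name expand_to_grid and its free-standing signature are illustrative only, not part of this patch.

#include <NvInfer.h>

// Expand a 1D tensor so it occupies axis `axis` of a grid whose full shape
// is `grid` (the meshgrid output shape). Mirrors the reshape + slice pair
// in the converter above; `net` is any live INetworkDefinition.
nvinfer1::ITensor* expand_to_grid(
    nvinfer1::INetworkDefinition* net, nvinfer1::ITensor* in, nvinfer1::Dims grid, int32_t axis) {
  // 1) Reshape [N] -> [1, ..., N, ..., 1] so the data sits on `axis`.
  nvinfer1::Dims reshape;
  reshape.nbDims = grid.nbDims;
  for (int32_t i = 0; i < grid.nbDims; ++i) {
    reshape.d[i] = (i == axis) ? grid.d[axis] : 1;
  }
  auto* shuffle = net->addShuffle(*in);
  shuffle->setReshapeDimensions(reshape);

  // 2) Slice to the full grid shape. Stride 0 on the size-1 axes keeps
  //    re-reading the single element, i.e. an explicit broadcast.
  nvinfer1::Dims start;
  nvinfer1::Dims stride;
  start.nbDims = stride.nbDims = grid.nbDims;
  for (int32_t i = 0; i < grid.nbDims; ++i) {
    start.d[i] = 0;
    stride.d[i] = (i == axis) ? 1 : 0;
  }
  auto* slice = net->addSlice(*shuffle->getOutput(0), start, grid, stride);
  return slice->getOutput(0);
}
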
diff --git a/tests/core/conversion/converters/test_expand.cpp b/tests/core/conversion/converters/test_expand.cpp
index 53630b661a..77b42fb1d9 100644
--- a/tests/core/conversion/converters/test_expand.cpp
+++ b/tests/core/conversion/converters/test_expand.cpp
@@ -669,3 +669,38 @@ TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicIn
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
+
+TEST(Converters, ATenMeshGridConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x : Tensor, %y : Tensor, %z : Tensor):
+        %0 : Tensor[] = prim::ListConstruct(%x, %y, %z)
+        %1 : Tensor[] = aten::meshgrid(%0)
+        %x_0 : Tensor, %y_0 : Tensor, %z_0 : Tensor = prim::ListUnpack(%1)
+        return (%x_0, %y_0, %z_0))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto x = at::randint(1, 10, {2}, {at::kCUDA}).to(torch::kInt);
+  auto jit_x = at::clone(x);
+
+  auto y = at::randint(1, 10, {5}, {at::kCUDA}).to(torch::kInt);
+  auto jit_y = at::clone(y);
+
+  auto z = torch::tensor(22, {at::kCUDA}).to(torch::kInt); // 0D
+  auto jit_z = at::clone(z);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_x, jit_y, jit_z});
+
+  auto trt_x = at::clone(jit_x);
+  auto trt_y = at::clone(jit_y);
+  auto trt_z = at::clone(jit_z);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_x, trt_y, trt_z});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1], 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[2], trt_results[2], 2e-6));
+}
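
For reference on what the test asserts: with inputs of sizes [2], [5], and a 0D scalar, each meshgrid output covers the full [2, 5, 1] grid. The expected shapes can be checked in eager mode with plain libtorch; the snippet below is not part of the patch and only illustrates those semantics (run on CPU for simplicity).

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randint(1, 10, {2}, torch::kInt);
  auto y = torch::randint(1, 10, {5}, torch::kInt);
  auto z = torch::tensor(22, torch::kInt); // 0D input, broadcast as size 1
  // Same op the converter implements; each output spans the whole grid.
  auto grids = torch::meshgrid({x, y, z});
  for (const auto& g : grids) {
    std::cout << g.sizes() << "\n"; // prints [2, 5, 1] three times
  }
  return 0;
}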