Skip to content

Commit 7e91ee0

Browse files
committed
Standardize domain type to be int64 vector
- Resolve C++ failing test case - Add test case for C++ API
1 parent f645343 commit 7e91ee0

File tree

3 files changed

+56
-12
lines changed

3 files changed

+56
-12
lines changed

cpp/include/torch_tensorrt/torch_tensorrt.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ struct Input : torch::CustomClassHolder {
464464
*/
465465
TORCHTRT_API Input(
466466
c10::ArrayRef<int64_t> shape,
467-
c10::ArrayRef<int64_t> tensor_domain,
467+
std::vector<int64_t> tensor_domain,
468468
TensorFormat format = TensorFormat::kContiguous);
469469

470470
/**
@@ -493,7 +493,7 @@ struct Input : torch::CustomClassHolder {
493493
TORCHTRT_API Input(
494494
c10::ArrayRef<int64_t> shape,
495495
DataType dtype,
496-
c10::ArrayRef<int64_t> tensor_domain,
496+
std::vector<int64_t> tensor_domain,
497497
TensorFormat format = TensorFormat::kContiguous);
498498

499499
/**
@@ -604,7 +604,7 @@ struct Input : torch::CustomClassHolder {
604604
c10::ArrayRef<int64_t> min_shape,
605605
c10::ArrayRef<int64_t> opt_shape,
606606
c10::ArrayRef<int64_t> max_shape,
607-
c10::ArrayRef<int64_t> tensor_domain,
607+
std::vector<int64_t> tensor_domain,
608608
TensorFormat format = TensorFormat::kContiguous);
609609

610610
/**
@@ -644,7 +644,7 @@ struct Input : torch::CustomClassHolder {
644644
c10::ArrayRef<int64_t> opt_shape,
645645
c10::ArrayRef<int64_t> max_shape,
646646
DataType dtype,
647-
c10::ArrayRef<int64_t> tensor_domain,
647+
std::vector<int64_t> tensor_domain,
648648
TensorFormat format = TensorFormat::kContiguous);
649649

650650
/**

cpp/src/types.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,15 @@ Input::Input(c10::IntArrayRef shape, TensorFormat format) {
198198
this->tensor_domain = std::vector<int64_t>{0, 2};
199199
}
200200

201-
Input::Input(c10::IntArrayRef shape, c10::IntArrayRef tensor_domain, TensorFormat format) {
201+
Input::Input(c10::IntArrayRef shape, std::vector<int64_t> tensor_domain, TensorFormat format) {
202202
this->opt_shape = torch_tensorrt::core::util::toVec(shape);
203203
this->min_shape = torch_tensorrt::core::util::toVec(shape);
204204
this->max_shape = torch_tensorrt::core::util::toVec(shape);
205205
this->shape = torch_tensorrt::core::util::toVec(shape);
206206
this->dtype = DataType::kUnknown;
207207
this->format = format;
208208
this->input_is_dynamic = false;
209-
this->tensor_domain = torch_tensorrt::core::util::toVec(tensor_domain);
209+
this->tensor_domain = tensor_domain;
210210
}
211211

212212
Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat format) {
@@ -220,15 +220,15 @@ Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat format) {
220220
this->tensor_domain = std::vector<int64_t>{0, 2};
221221
}
222222

223-
Input::Input(c10::IntArrayRef shape, DataType dtype, c10::IntArrayRef tensor_domain, TensorFormat format) {
223+
Input::Input(c10::IntArrayRef shape, DataType dtype, std::vector<int64_t> tensor_domain, TensorFormat format) {
224224
this->opt_shape = torch_tensorrt::core::util::toVec(shape);
225225
this->min_shape = torch_tensorrt::core::util::toVec(shape);
226226
this->max_shape = torch_tensorrt::core::util::toVec(shape);
227227
this->shape = torch_tensorrt::core::util::toVec(shape);
228228
this->dtype = dtype;
229229
this->format = format;
230230
this->input_is_dynamic = false;
231-
this->tensor_domain = torch_tensorrt::core::util::toVec(tensor_domain);
231+
this->tensor_domain = tensor_domain;
232232
}
233233

234234
Input::Input(
@@ -315,7 +315,7 @@ Input::Input(
315315
c10::IntArrayRef min_shape,
316316
c10::IntArrayRef opt_shape,
317317
c10::IntArrayRef max_shape,
318-
c10::IntArrayRef tensor_domain,
318+
std::vector<int64_t> tensor_domain,
319319
TensorFormat format) {
320320
this->opt_shape = torch_tensorrt::core::util::toVec(opt_shape);
321321
this->min_shape = torch_tensorrt::core::util::toVec(min_shape);
@@ -325,7 +325,7 @@ Input::Input(
325325
this->dtype = DataType::kUnknown;
326326
this->format = format;
327327
this->input_is_dynamic = true;
328-
this->tensor_domain = torch_tensorrt::core::util::toVec(tensor_domain);
328+
this->tensor_domain = tensor_domain;
329329
}
330330

331331
Input::Input(
@@ -350,7 +350,7 @@ Input::Input(
350350
c10::IntArrayRef opt_shape,
351351
c10::IntArrayRef max_shape,
352352
DataType dtype,
353-
c10::IntArrayRef tensor_domain,
353+
std::vector<int64_t> tensor_domain,
354354
TensorFormat format) {
355355
this->opt_shape = torch_tensorrt::core::util::toVec(opt_shape);
356356
this->min_shape = torch_tensorrt::core::util::toVec(min_shape);
@@ -360,7 +360,7 @@ Input::Input(
360360
this->dtype = dtype;
361361
this->format = format;
362362
this->input_is_dynamic = true;
363-
this->tensor_domain = torch_tensorrt::core::util::toVec(tensor_domain);
363+
this->tensor_domain = tensor_domain;
364364
}
365365

366366
Input::Input(at::Tensor tensor) {

tests/cpp/test_collections.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,50 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {
4545
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor()));
4646
}
4747

48+
TEST(CppAPITests, TestSpecifyDomainStandardTensorInput) {
49+
std::string path = "tests/modules/tuple_input_scripted.jit.pt";
50+
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
51+
52+
torch::jit::Module mod;
53+
try {
54+
// Deserialize the ScriptModule from a file using torch::jit::load().
55+
mod = torch::jit::load(path);
56+
} catch (const c10::Error& e) {
57+
std::cerr << "error loading the model\n";
58+
}
59+
mod.eval();
60+
mod.to(torch::kCUDA);
61+
62+
std::vector<torch::jit::IValue> complex_inputs, complex_inputs_list;
63+
std::tuple<torch::jit::IValue, torch::jit::IValue> input_tuple(in0, in0);
64+
65+
complex_inputs.push_back(input_tuple);
66+
67+
auto out = mod.forward(complex_inputs);
68+
69+
// Specify input tensor domain argument
70+
auto input_shape = torch_tensorrt::Input(in0.sizes(), std::vector<int64_t>{35, 377}, torch_tensorrt::DataType::kHalf);
71+
72+
auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
73+
74+
std::tuple<torch::jit::IValue, torch::jit::IValue> input_shape_tuple(input_shape_ivalue, input_shape_ivalue);
75+
76+
torch::jit::IValue complex_input_shape(input_shape_tuple);
77+
std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
78+
torch::jit::IValue complex_input_shape2(input_tuple2);
79+
80+
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
81+
compile_settings.min_block_size = 1;
82+
83+
// // FP16 execution
84+
compile_settings.enabled_precisions = {torch::kHalf};
85+
// // Compile module
86+
auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
87+
auto trt_out = trt_mod.forward(complex_inputs);
88+
89+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor()));
90+
}
91+
4892
TEST(CppAPITests, TestCollectionTupleInput) {
4993
std::string path = "tests/modules/tuple_input_scripted.jit.pt";
5094
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);

0 commit comments

Comments
 (0)