Skip to content

Commit 57f953e

Browse files
authored
Merge pull request #15 from bddppq/misc-fixes
Fix indeterministic params order in `RunGraph`
2 parents 60ee9c6 + 86600da commit 57f953e

File tree

4 files changed

+15
-17
lines changed

4 files changed

+15
-17
lines changed

tests/core/converters/test_linear.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
TEST(Converters, ATenLinearNoBiasConvertsCorrectly) {
99
const auto graph = R"IR(
1010
graph(%0 : Tensor,
11-
%1 : Float(2, 2)):
11+
%1 : Float(3, 2)):
1212
%2 : None = prim::Constant()
1313
%3 : Tensor = aten::linear(%0, %1, %2)
1414
return (%3))IR";
@@ -19,7 +19,7 @@ TEST(Converters, ATenLinearNoBiasConvertsCorrectly) {
1919
//Input Tensor needs to be 4D for TensorRT linear
2020
auto in = at::randint(1, 10, {1, 2}, {at::kCUDA});
2121
auto w = at::randint(1, 10, {3, 2}, {at::kCUDA});
22-
22+
2323
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {w});
2424
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
2525

@@ -32,8 +32,6 @@ TEST(Converters, ATenLinearNoBiasConvertsCorrectly) {
3232
}
3333

3434

35-
//TODO: Track down the cause of why the JIT linear function fails
36-
//TODO: Sort out what the exepected output dim should be?
3735
TEST(Converters, ATenLinearBiasConvertsCorrectly) {
3836
const auto graph = R"IR(
3937
graph(%0 : Tensor,
@@ -53,7 +51,7 @@ TEST(Converters, ATenLinearBiasConvertsCorrectly) {
5351
auto jit_in = at::clone(in);
5452
auto jit_w = at::clone(w);
5553
auto jit_b = at::clone(b);
56-
54+
5755
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
5856
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
5957

tests/models/BUILD

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@ cc_test(
1515
name = "test_multiple_registered_engines",
1616
srcs = ["test_multiple_registered_engines.cpp"],
1717
deps = [
18-
"//cpp/api:trtorch",
19-
"//tests/util",
20-
"@libtorch//:libtorch",
21-
"@googletest//:gtest_main",
18+
":module_test",
2219
],
2320
data = [
2421
":jit_models"

tests/models/test_multiple_registered_engines.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include <string>
22
#include "gtest/gtest.h"
3-
#include "torch/csrc/jit/irparser.h"
3+
#include "torch/script.h"
44
#include "tests/util/util.h"
5-
#include "cpp/trtorch.h"
5+
#include "trtorch/trtorch.h"
66

77
TEST(ModuleTests, CanRunMultipleEngines) {
88
torch::jit::script::Module mod1;
@@ -16,7 +16,7 @@ TEST(ModuleTests, CanRunMultipleEngines) {
1616
return;
1717
}
1818

19-
const std::vector<int64_t> input_shape = {1,3,224,224};
19+
const std::vector<std::vector<int64_t>> input_shapes = {{1,3,224,224}};
2020

2121
std::vector<torch::jit::IValue> jit1_inputs_ivalues;
2222
std::vector<torch::jit::IValue> trt1_inputs_ivalues;
@@ -38,18 +38,18 @@ TEST(ModuleTests, CanRunMultipleEngines) {
3838
std::vector<at::Tensor> jit1_results;
3939
jit1_results.push_back(jit1_results_ivalues.toTensor());
4040

41-
torch::jit::IValue jit2_results_ivalues = trtorch::tests::util::RunModuleForward(mod2, jit2_inputs_ivalues);
41+
torch::jit::IValue jit2_results_ivalues = trtorch::tests::util::RunModuleForward(mod2, jit2_inputs_ivalues);
4242
std::vector<at::Tensor> jit2_results;
4343
jit2_results.push_back(jit2_results_ivalues.toTensor());
4444

4545

4646
auto trt_mod1 = trtorch::CompileGraph(mod1, input_shapes);
47-
torch::jit::IValue trt1_results_ivalues = trtorch::tests::util::RunModuleForward(trt1_mod, trt1_inputs_ivalues);
47+
torch::jit::IValue trt1_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod1, trt1_inputs_ivalues);
4848
std::vector<at::Tensor> trt1_results;
4949
trt1_results.push_back(trt1_results_ivalues.toTensor());
5050

5151
auto trt_mod2 = trtorch::CompileGraph(mod2, input_shapes);
52-
torch::jit::IValue trt2_results_ivalues = trtorch::tests::util::RunModuleForward(trt2_mod, trt2_inputs_ivalues);
52+
torch::jit::IValue trt2_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod2, trt2_inputs_ivalues);
5353
std::vector<at::Tensor> trt2_results;
5454
trt2_results.push_back(trt2_results_ivalues.toTensor());
5555

tests/util/run_graph.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@ std::vector<at::Tensor> RunGraph(std::shared_ptr<torch::jit::Graph>& g,
2121
inputs_.push_back(in.clone());
2222
}
2323

24-
for (auto p : params) {
25-
inputs_.push_back(p.second.clone());
24+
for (auto* in : g->inputs()) {
25+
const auto iter = params.find(in);
26+
if (iter != params.end()) {
27+
inputs_.push_back(iter->second.clone());
28+
}
2629
}
2730

2831
torch::jit::GraphExecutor executor(g);

0 commit comments

Comments
 (0)