Skip to content

Commit 809f9b3

Browse files
authored
Merge pull request #166 from NVIDIA/torch_1.6.0_update
Torch 1.6.0 update
2 parents 07e739e + 253f55a commit 809f9b3

30 files changed

+352
-90
lines changed

.bazelversion

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.3.1
1+
3.4.1

WORKSPACE

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,16 @@ http_archive(
5353
name = "libtorch",
5454
build_file = "@//third_party/libtorch:BUILD",
5555
strip_prefix = "libtorch",
56-
urls = ["https://download.pytorch.org/libtorch/cu102/libtorch-cxx11-abi-shared-with-deps-1.5.1.zip"],
57-
sha256 = "cf0691493d05062fe3239cf76773bae4c5124f4b039050dbdd291c652af3ab2a"
56+
urls = ["https://download.pytorch.org/libtorch/cu102/libtorch-cxx11-abi-shared-with-deps-1.6.0.zip"],
57+
sha256 = "fded948bd2dbee625cee33ebbd4843a69496729389e0200a90fbb667cdaeeb69"
5858
)
5959

6060
http_archive(
6161
name = "libtorch_pre_cxx11_abi",
6262
build_file = "@//third_party/libtorch:BUILD",
6363
strip_prefix = "libtorch",
64-
sha256 = "818977576572eadaf62c80434a25afe44dbaa32ebda3a0919e389dcbe74f8656",
65-
urls = ["https://download.pytorch.org/libtorch/cu102/libtorch-shared-with-deps-1.5.1.zip"],
64+
sha256 = "141bb229f4bbf905541096cf8705785e7b0c79e37ca1e5db9d372730b1b9abd7",
65+
urls = ["https://download.pytorch.org/libtorch/cu102/libtorch-shared-with-deps-1.6.0.zip"],
6666
)
6767

6868
# Download these tarballs manually from the NVIDIA website

core/conversion/InterfaceTypes.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace core {
88
namespace conversion {
99

1010
GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs,
11-
std::vector<at::Tensor> params) {
11+
std::vector<torch::jit::IValue> params) {
1212
GraphParams named_params;
1313
auto param_it = params.begin();
1414
for (auto in : inputs) {
@@ -18,10 +18,8 @@ GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs,
1818
++param_it;
1919
}
2020
}
21-
//ASSERT(named_params.size() == params.size);
22-
if (named_params.size() != params.size()) {
23-
LOG_ERROR("Graph parameter parsing failed");
24-
}
21+
22+
TRTORCH_CHECK(named_params.size() == params.size(), "Graph parameter parsing failed, mismatched number of static parameters and IValues")
2523
return std::move(named_params);
2624
}
2725

core/conversion/conversion.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,11 @@ void AddInputs(ConversionCtx* ctx,
165165
TRTORCH_CHECK(profile->isValid(), "Optimization profile is invalid, please check the input range provided (conversion.AddInputs)");
166166

167167
ctx->cfg->addOptimizationProfile(profile);
168-
// TODO: Enable in TRT 7.1
169-
// if (ctx->op_precision == nvinfer1::DataType::kINT8) {
170-
// ctx->cfg->setCalibrationProfile(profile);
171-
// }
168+
#if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR >= 1)
169+
if (ctx->op_precision == nvinfer1::DataType::kINT8) {
170+
ctx->cfg->setCalibrationProfile(profile);
171+
}
172+
#endif
172173
}
173174

174175
void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outputs) {
@@ -186,7 +187,7 @@ void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outp
186187

187188
void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
188189
for (auto p : params) {
189-
ctx->evaluated_value_map[p.first] = torch::jit::IValue(p.second.clone());
190+
ctx->evaluated_value_map[p.first] = std::move(p.second);
190191
}
191192
}
192193

core/conversion/conversion.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ struct ConversionInfo {
3030
: input_ranges(std::move(input_ranges)), engine_settings(BuilderSettings()) {}
3131
};
3232

33-
using GraphParams = std::map<torch::jit::Value*, at::Tensor>;
33+
//TODO: REMOVE GRAPH AND PARAMS AND MOVE FULLY TO INLINED CONSTANTS
3434

35-
GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs, std::vector<at::Tensor> params);
35+
using GraphParams = std::map<torch::jit::Value*, torch::jit::IValue>;
36+
37+
GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs, std::vector<torch::jit::IValue> params);
3638

3739
// Converts a already lowered block (blocks with no sub blocks) to
3840
// a serialized TensorRT engine that can be deserialized and run

core/conversion/evaluators/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ cc_library(
1616
"NodeEvaluatorRegistry.cpp",
1717
"prim.cpp",
1818
"aten.cpp",
19-
"eval_macros.h"
19+
"eval_macros.h",
20+
"eval_util.h",
21+
"eval_util.cpp"
2022
],
2123
deps = [
2224
"//core/util:prelude",
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#include "ATen/core/ivalue.h"
2+
#include "ATen/core/List.h"
3+
#include "core/util/prelude.h"
4+
#include "ATen/core/functional.h"
5+
6+
namespace trtorch {
7+
namespace core {
8+
namespace conversion {
9+
namespace evaluators {
10+
11+
//TODO: Switch back to PyTorch canonical implementation
12+
c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
13+
if (v->node()->kind() != torch::jit::prim::Constant || v->type()->cast<c10::FunctionType>()) {
14+
return c10::nullopt;
15+
}
16+
const torch::jit::Node* node = v->node();
17+
const c10::TypePtr& type = v->type();
18+
if (type->isSubtypeOf(c10::TensorType::get())) {
19+
return node->t(c10::attr::value);
20+
} else if (type->isSubtypeOf(c10::BoolType::get())) {
21+
return (bool)node->i(c10::attr::value);
22+
} else if (
23+
type->isSubtypeOf(c10::NumberType::get()) &&
24+
node->kindOf(c10::attr::value) == torch::jit::AttributeKind::i) {
25+
return node->i(c10::attr::value);
26+
} else if (
27+
type->isSubtypeOf(c10::NumberType::get()) &&
28+
node->kindOf(c10::attr::value) == torch::jit::AttributeKind::f) {
29+
return node->f(c10::attr::value);
30+
} else if (type->isSubtypeOf(c10::ListType::ofInts())) {
31+
try {
32+
const auto& is = node->is(c10::attr::value);
33+
return is;
34+
} catch (const std::exception& ex) {
35+
const auto& ival = node->ival(c10::attr::value);
36+
return ival;
37+
}
38+
} else if (type->isSubtypeOf(c10::ListType::ofFloats())) {
39+
try {
40+
const auto& fs = node->fs(c10::attr::value);
41+
return fs;
42+
} catch (const std::exception& ex) {
43+
const auto& ival = node->ival(c10::attr::value);
44+
return ival;
45+
}
46+
} else if (type->isSubtypeOf(c10::ListType::ofBools())) {
47+
const auto bs = c10::fmap<bool>(node->is(c10::attr::value));
48+
return bs;
49+
} else if (type->isSubtypeOf(c10::ListType::ofTensors())) {
50+
try {
51+
const auto& ts = node->ts(c10::attr::value);
52+
return ts;
53+
} catch (const std::exception& ex) {
54+
const auto& ival = node->ival(c10::attr::value);
55+
return ival;
56+
}
57+
} else if (type->isSubtypeOf(c10::ListType::ofStrings())) {
58+
try {
59+
const auto& ss = node->ss(c10::attr::value);
60+
auto vals = c10::impl::GenericList(c10::StringType::get());
61+
for (const auto& str : ss) {
62+
vals.push_back(str);
63+
}
64+
return vals;
65+
} catch (const std::exception& ex) {
66+
const auto& ival = node->ival(c10::attr::value);
67+
return ival;
68+
}
69+
} else if (
70+
type->cast<c10::ListType>() &&
71+
node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
72+
const auto& list = node->ival(c10::attr::value);
73+
TRTORCH_ASSERT(list.isList(), "Is not a list");
74+
return list;
75+
} else if (
76+
type->cast<c10::DictType>() &&
77+
node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
78+
const auto& dict = node->ival(c10::attr::value);
79+
TRTORCH_ASSERT(dict.isGenericDict(), "Is not a dict");
80+
return dict;
81+
} else if (
82+
type->cast<c10::TupleType>() &&
83+
node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
84+
const auto& tup = node->ival(c10::attr::value);
85+
TRTORCH_ASSERT(tup.isTuple(), "Is not a tuple");
86+
return tup;
87+
} else if (type == c10::StringType::get()) {
88+
const auto& s = node->s(c10::attr::value);
89+
return s;
90+
} else if (type == c10::DeviceObjType::get()) {
91+
auto d = c10::Device(node->s(c10::attr::value));
92+
return d;
93+
} else if (node->mustBeNone()) {
94+
return torch::jit::IValue();
95+
} else {
96+
std::stringstream ss;
97+
ss << "constant literal not supported for: " << type->str();
98+
throw std::runtime_error(ss.str());
99+
}
100+
}
101+
102+
} // namespace evaluators
103+
} // namespace conversion
104+
} // namespace core
105+
} // namespace trtorch
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#include "torch/csrc/jit/ir/ir.h"
4+
5+
namespace trtorch {
6+
namespace core {
7+
namespace conversion {
8+
namespace evaluators {
9+
10+
c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);
11+
12+
} // namespace evaluators
13+
} // namespace conversion
14+
} // namespace core
15+
} // namespace trtorch

core/conversion/evaluators/prim.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include <limits>
22

33
#include "torch/csrc/jit/ir/ir.h"
4-
#include "torch/csrc/jit/ir/constants.h"
4+
//#include "torch/csrc/jit/ir/constants.h"
55
#include "ATen/core/functional.h"
66
#include "ATen/core/ivalue.h"
77
#include "ATen/core/List.h"
@@ -11,6 +11,7 @@
1111

1212
#include "core/conversion/evaluators/evaluators.h"
1313
#include "core/conversion/evaluators/eval_macros.h"
14+
#include "core/conversion/evaluators/eval_util.h"
1415

1516
namespace trtorch {
1617
namespace core {
@@ -25,7 +26,7 @@ auto prim_registrations = RegisterNodeEvaluators()
2526
if (n->output()->type()->kind() == at::FunctionType::Kind) {
2627
return {};
2728
}
28-
return torch::jit::toIValue(n->output());
29+
return evaluators::toIValue(n->output());
2930
}
3031
}).evaluator({
3132
torch::jit::prim::NumToTensor,

core/lowering/lowering.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include "torch/csrc/jit/passes/lower_graph.h"
88
#include "torch/csrc/jit/passes/lower_tuples.h"
99
#include "torch/csrc/jit/passes/peephole.h"
10-
#include "torch/csrc/jit/passes/quantization.h"
1110

1211
#include "core/util/prelude.h"
1312
#include "core/lowering/lowering.h"
@@ -50,8 +49,7 @@ torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
5049
return mod_;
5150
}
5251

53-
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
54-
std::string method_name) {
52+
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(const torch::jit::script::Module& mod, std::string method_name) {
5553
auto lowered_mod = LowerModule(mod);
5654
auto g = lowered_mod.get_method(method_name).graph();
5755
LOG_GRAPH(*g);
@@ -62,10 +60,11 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(con
6260
lowering::LowerGraph(g);
6361
//=[torch::jit::FoldConvBatchNorm2d(lowered_mod);
6462
LOG_GRAPH("LibTorch Lowering");
65-
auto graph_and_parameters = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
63+
auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
6664
// Is this necessary?
6765
lowering::LowerBlock(g->block());
68-
return graph_and_parameters;
66+
67+
return graph_and_ivalues;
6968
}
7069

7170

core/lowering/lowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ namespace lowering {
99
void LowerBlock(torch::jit::Block* b);
1010
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g);
1111
torch::jit::Module LowerModule(const torch::jit::script::Module& mod);
12-
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
13-
std::string method_name);
12+
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(const torch::jit::script::Module& mod,
13+
std::string method_name);
1414

1515
} // namespace lowering
1616
} // namespace core

cpp/api/include/trtorch/macros.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
#define STR(x) XSTR(x)
2121

2222
#define TRTORCH_MAJOR_VERSION 0
23-
#define TRTORCH_MINOR_VERSION 0
24-
#define TRTORCH_PATCH_VERSION 3
23+
#define TRTORCH_MINOR_VERSION 1
24+
#define TRTORCH_PATCH_VERSION 0
2525
#define TRTORCH_VERSION STR(TRTORCH_MAJOR_VERSION) \
2626
"." STR(TRTORCH_MINOR_VERSION) \
2727
"." STR(TRTORCH_PATCH_VERSION)

docker/Dockerfile.20.07

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
FROM nvcr.io/nvidia/pytorch:20.07-py3
2+
3+
RUN apt-get update && apt-get install curl gnupg && rm -rf /var/lib/apt/lists/*
4+
5+
RUN curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \
6+
echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list
7+
8+
RUN apt-get update && apt-get install bazel-3.4.1 && rm -rf /var/lib/apt/lists/*
9+
RUN ln -s /usr/bin/bazel-3.4.1 /usr/bin/bazel
10+
11+
RUN pip install notebook
12+
13+
COPY . /opt/trtorch
14+
RUN rm /opt/trtorch/WORKSPACE
15+
COPY ./docker/WORKSPACE.cu11.docker /opt/trtorch/WORKSPACE
16+
17+
# Workaround for bazel expecting both static and shared versions, we only use shared libraries inside container
18+
RUN cp /usr/lib/x86_64-linux-gnu/libnvinfer.so /usr/lib/x86_64-linux-gnu/libnvinfer_static.a
19+
20+
WORKDIR /opt/trtorch
21+
RUN bazel build //:libtrtorch --compilation_mode opt
22+
23+
WORKDIR /opt/trtorch/py
24+
25+
# Locale is not set by default
26+
RUN apt-get update && apt-get install -y locales ninja-build && rm -rf /var/lib/apt/lists/* && locale-gen en_US.UTF-8
27+
ENV LANG en_US.UTF-8
28+
ENV LANGUAGE en_US:en
29+
ENV LC_ALL en_US.UTF-8
30+
RUN python3 setup.py install --use-cxx11-abi
31+
32+
RUN conda init bash
33+
34+
ENV LD_LIBRARY_PATH /opt/conda/lib/python3.6/site-packages/torch/lib:$LD_LIBRARY_PATH
35+
36+
WORKDIR /opt/trtorch/notebooks

docker/Dockerfile.docs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ FROM nvcr.io/nvidia/tensorrt:20.03-py3
33
RUN curl https://bazel.build/bazel-release.pub.gpg | apt-key add -
44
RUN echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list
55

6-
RUN apt update && apt install bazel-3.3.1
7-
RUN ln -s /usr/bin/bazel-3.3.1 /usr/bin/bazel
6+
RUN apt update && apt install bazel-3.4.1
7+
RUN ln -s /usr/bin/bazel-3.4.1 /usr/bin/bazel
88

99

1010
COPY ./py/requirements.txt requirements.txt

0 commit comments

Comments
 (0)