
Adds basic scripting support #81


Merged
18 commits merged on Jun 5, 2020
Changes from all commits (18 commits)
68f0317
feat(//core/lowering): Fuse aten::addmm branches into a single…
narendasan May 31, 2020
17099fa
docs(//core/lowering): Document new aten::addmm fusion pass
narendasan May 31, 2020
60df888
feat(prim::NumToTensor): Implement evaluator for NumToTensor
narendasan Jun 1, 2020
670817c
feat(aten::zeros): Implement aten::zeros evaluator
narendasan Jun 1, 2020
0f63ffa
feat(aten::to): Remove remaining typecast operators (should be a very…
narendasan Jun 2, 2020
2f394fb
docs: Update docs on new lowering passes
narendasan Jun 2, 2020
2cc3226
feat(//core/conversion/evaluators): Adds new applicability filters for…
narendasan Jun 3, 2020
c83447e
fix(aten::size, other aten evaluators): Removes aten::size converter in…
narendasan Jun 3, 2020
ca2b5f9
fix(//core/lowering): Conv2D -> _convolution pass was triggering conv…
narendasan Jun 3, 2020
0014b84
feat(//core/lowering): Adds peephole optimization pass
narendasan Jun 5, 2020
d351717
feat(//core/conversion/evaluators): adding support for common evaluation…
narendasan Jun 5, 2020
dcb1474
feat(//core/conversion): Adds the ability to evaluate loops
narendasan Jun 5, 2020
6bd1a3f
fix(//core): Do not compile hidden methods
narendasan Jun 5, 2020
3abfcaf
tests: Add scripted modules to the testsuite
narendasan Jun 5, 2020
961f9d8
docs: Update docs on new lowering passes
narendasan Jun 5, 2020
5b2a2ec
refactor(//core): Update some typos and small code changes for the PR
narendasan Jun 5, 2020
111857c
refactor(//core/lowering): Uncomment RemoveTo
narendasan Jun 5, 2020
80d5069
docs: One pass left out, links were wrong
narendasan Jun 5, 2020
17 changes: 10 additions & 7 deletions core/compiler.cpp
@@ -150,13 +150,16 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
for (const torch::jit::script::Method& method : mod.get_methods()) {
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
auto new_g = std::make_shared<torch::jit::Graph>();
AddEngineToGraph(new_mod, new_g, engine);
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
new_mod.type()->addMethod(new_method);
new_method->setSchema(schema);
// Don't convert hidden methods
if (method.name().rfind("_", 0)) {
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
auto new_g = std::make_shared<torch::jit::Graph>();
AddEngineToGraph(new_mod, new_g, engine);
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
new_mod.type()->addMethod(new_method);
new_method->setSchema(schema);
}
}

return new_mod;
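
The guard on the new branch relies on a common C++ idiom: std::string::rfind("_", 0) only searches at position 0, so it returns 0 (which is falsy) exactly when the method name begins with an underscore, and npos (truthy) otherwise; only public methods are converted. A minimal standalone sketch of the idiom (illustrative only, not part of the PR; the method names are hypothetical):

#include <iostream>
#include <string>

// Minimal sketch of the prefix check used above: rfind("_", 0) returns 0
// (falsy) when the name starts with "_" and npos (truthy) otherwise.
int main() {
  for (const std::string name : {"forward", "_hidden_helper"}) {
    if (name.rfind("_", 0)) {
      std::cout << name << ": converted to a TensorRT engine" << std::endl;
    } else {
      std::cout << name << ": skipped as a hidden method" << std::endl;
    }
  }
  return 0;
}
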
86 changes: 79 additions & 7 deletions core/conversion/conversion.cpp
@@ -190,6 +190,57 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
}
}

void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_list, c10::ArrayRef<const torch::jit::Value*> out_list, int64_t in_offset, int64_t out_offset) {
std::vector<std::pair<const torch::jit::Value*, const torch::jit::Value*>> input_output_pairs;
std::transform(in_list.begin() + in_offset, in_list.end(), out_list.begin() + out_offset,
std::back_inserter(input_output_pairs),
[](auto in, auto out){
return std::make_pair(in, out);
});

for (auto p : input_output_pairs) {
auto input = ctx->evaluated_value_map[p.first];
ctx->evaluated_value_map[p.second] = torch::jit::IValue(input);
}
}

// TODO: With functionalization pass we may be able to make this into a regular evaluator later
void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
auto start_cond = ctx->evaluated_value_map[n->input(1)];
ctx->evaluated_value_map[n->blocks()[0]->inputs()[0]] = torch::jit::IValue(0);
auto trip_count = ctx->evaluated_value_map[n->blocks()[0]->inputs()[0]];

MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);

LOG_DEBUG("(Loop Evaluation) Evaluating loop " << *n);
LOG_DEBUG("(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
LOG_DEBUG("(Loop Evaluation) Start Condition: " << start_cond.toBool());
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());

while (start_cond.toBool() && trip_count.toInt() < max_trip_count.toInt()) {
MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1);
for (auto bn : n->blocks()[0]->nodes()) {
auto eval = EvaluateNode(ctx, bn);
if (eval) {
if (!eval.value().isTensor()) {
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be: " << eval.value());
} else {
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
}
ctx->AssociateValueAndIValue(bn->output(0), eval.value());
}
}

MapIValues(ctx, n->blocks()[0]->outputs(), n->outputs(), 1, 0);
start_cond = ctx->evaluated_value_map[n->blocks()[0]->outputs()[0]];
auto new_trip_count = torch::jit::IValue(trip_count.toInt() + 1);
trip_count.swap(new_trip_count);
LOG_DEBUG("(Loop Evaluation) Condition: " << start_cond.toBool());
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
}
}

void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
LOG_INFO(ctx->logger, "Converting Block");

@@ -202,7 +253,19 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
for (const auto n : nodes) {
bool to_eval = evaluators::shouldEvalAtConversionTime(n);
bool blacklisted = isNodeConversionBlacklisted(n);
if (!to_eval && !blacklisted) {
if (n->kind() == torch::jit::prim::Loop) {
EvaluateLoopBlock(ctx, n);
} else if (to_eval) {
auto eval = EvaluateNode(ctx, n);
if (eval) {
if (!eval.value().isTensor()) {
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
} else {
LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
}
ctx->AssociateValueAndIValue(n->output(0), eval.value());
}
} else if (!blacklisted) {
// Should error out if something fails
AddLayer(ctx, n);
} else {
@@ -237,22 +300,29 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
return engine;
}

bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
bool supported = true;
std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b ) {
std::set<std::string> unsupported_ops;
for (const auto n : b->nodes()) {
if (!OpSupported(n)) {
if (!OpSupported(n) && n->kind() != torch::jit::prim::Loop) {
auto schema = n->maybeSchema();
TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
<< " (conversion.VerifyCoverterSupportForBlock");
std::stringstream ss;
ss << *schema;
unsupported_ops.insert(ss.str());
supported = false;
}
for (const auto sub_b : n->blocks()) {
auto sub_b_unsupported_ops = GetUnsupportedOpsInBlock(sub_b);
unsupported_ops.insert(sub_b_unsupported_ops.begin(), sub_b_unsupported_ops.end());
}
}
return unsupported_ops;
}

bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
auto unsupported_ops = GetUnsupportedOpsInBlock(b);

if (!supported) {
if (unsupported_ops.size() != 0) {
std::stringstream unsupported_msg;
unsupported_msg << "Method requested cannot be compiled by TRTorch.\nUnsupported operators listed below:" << std::endl;
for (auto s : unsupported_ops) {
@@ -261,8 +331,10 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
unsupported_msg << "You can either implement converters for these ops in your application or request implementation" << std::endl;
unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl;
LOG_ERROR(unsupported_msg.str());
return false;
} else {
return true;
}
return supported;
}

} // namespace conversion
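
For readers new to TorchScript internals, EvaluateLoopBlock above mirrors the semantics of prim::Loop: the node's inputs are (max_trip_count, start_condition, carried values...), the loop block takes (trip_count, carried values...) and yields (continue_condition, carried values...), and MapIValues shuttles the carried values between node and block on each trip. A simplified standalone sketch of those semantics using plain integers (an illustration under those assumptions, not PR code):

#include <cstdint>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// Simplified sketch of the prim::Loop evaluation strategy used by
// EvaluateLoopBlock: keep re-running the loop body while the condition holds
// and the trip count stays under the maximum, carrying values across
// iterations the way MapIValues does.
std::vector<int64_t> EmulatePrimLoop(
    int64_t max_trip_count, bool start_cond, std::vector<int64_t> carried,
    const std::function<std::pair<bool, std::vector<int64_t>>(
        int64_t, const std::vector<int64_t>&)>& body) {
  bool cond = start_cond;
  int64_t trip_count = 0;
  while (cond && trip_count < max_trip_count) {
    auto out = body(trip_count, carried);  // evaluate the loop block once
    cond = out.first;      // block output 0 becomes the new condition
    carried = out.second;  // remaining block outputs feed the next trip
    trip_count += 1;
  }
  return carried;  // final carried values become the node outputs
}

int main() {
  // Sum 0..4, the kind of scripted loop the evaluator can now resolve
  // at conversion time rather than handing control flow to TensorRT.
  auto result = EmulatePrimLoop(
      5, true, {0},
      [](int64_t i, const std::vector<int64_t>& c)
          -> std::pair<bool, std::vector<int64_t>> {
        return {true, {c[0] + i}};
      });
  std::cout << result[0] << std::endl;  // prints 10
  return 0;
}

The idea is that the whole loop is resolved during conversion, so the generated TensorRT engine never sees the control flow, only the values it produces.
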
1 change: 1 addition & 0 deletions core/conversion/conversion_blacklist.cpp
@@ -15,6 +15,7 @@ const std::unordered_set<std::string>& get_non_convertable_nodes() {
"aten::backward",
"aten::save",
"aten::contiguous",
"aten::to",
"prim::RaiseException",
"prim::Print",
"prim::device",
1 change: 0 additions & 1 deletion core/conversion/converters/BUILD
@@ -25,7 +25,6 @@ cc_library(
"impl/matrix_multiply.cpp",
"impl/pooling.cpp",
"impl/reduce.cpp",
"impl/shape.cpp",
"impl/shuffle.cpp",
"impl/softmax.cpp",
"impl/unary.cpp",
32 changes: 0 additions & 32 deletions core/conversion/converters/impl/shape.cpp

This file was deleted.

1 change: 1 addition & 0 deletions core/conversion/evaluators/BUILD
@@ -15,6 +15,7 @@ cc_library(
srcs = [
"NodeEvaluatorRegistry.cpp",
"prim.cpp",
"aten.cpp"
],
deps = [
"//core/util:prelude",
56 changes: 44 additions & 12 deletions core/conversion/evaluators/NodeEvaluatorRegistry.cpp
@@ -15,27 +15,59 @@ namespace core {
namespace conversion {
namespace evaluators {
namespace {
using EvaluatorLUT = std::unordered_map<torch::jit::NodeKind, NodeEvaluator>;
using EvaluatorLUT = std::unordered_map<torch::jit::NodeKind, EvalRegistration>;

bool FindInVec(std::vector<c10::OperatorName>& names, c10::OperatorName target) {
for (auto n : names) {
if (n == target) {
return true;
}
}
return false;
}

class NodeEvaluatorRegistry {
public:
void RegisterEvaluator(torch::jit::NodeKind node_kind, NodeEvaluator& evaluator) {
void RegisterEvaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
LOG_DEBUG("Registering evaluator for " << node_kind.toQualString());
evaluator_lut_[node_kind] = std::move(evaluator);
evaluator_lut_[node_kind] = std::move(eval_reg);
}

NodeEvaluator GetEvaluator(const torch::jit::NodeKind node_kind) {
NodeEvaluator FindEvaluator(const torch::jit::Node* n) {
auto node_kind = n->kind();
auto iter = evaluator_lut_.find(node_kind);
if (iter == evaluator_lut_.end()) {
LOG_ERROR("Requested evaluator for " << node_kind.toQualString() << ", but no such evaluator was found");
return nullptr;
}
return iter->second;
auto eval_reg = iter->second;
if (eval_reg.options.use()) {
for (auto o : n->outputs()) {
if (eval_reg.options.blacklisted_output_types.find(o->type()) != eval_reg.options.blacklisted_output_types.end()) {
return nullptr;
}
}

if (eval_reg.options.valid_schemas.size() != 0) {
auto schema = n->maybeSchema();
TRTORCH_CHECK(schema, "Evaluator for " << node_kind.toQualString() << "only runs on certain schemas, but schema for node is not retrievable");
if (!FindInVec(eval_reg.options.valid_schemas, schema->operator_name())) {
return nullptr;
}
}
}

return eval_reg.evaluator;
}

NodeEvaluator GetEvaluator(const torch::jit::Node* n) {
auto evaluator = FindEvaluator(n);
TRTORCH_CHECK(evaluator, "Requested evaluator for " << n->kind().toQualString() << ", but no such evaluator was found");
return evaluator;
}

bool EvalAtConversionTime(const torch::jit::Node* n) {
auto eval_at_conversion = evaluator_lut_.find(n->kind());
if (eval_at_conversion == evaluator_lut_.end()) {
auto evaluator = FindEvaluator(n);
if (evaluator == nullptr) {
return false;
} else {
return true;
@@ -58,16 +90,16 @@ bool shouldEvalAtConversionTime(const torch::jit::Node* n) {
}

c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
auto evaluator = get_evaluator_registry().GetEvaluator(n->kind());
auto evaluator = get_evaluator_registry().GetEvaluator(n);
return evaluator(n, args);
}

void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator) {
get_evaluator_registry().RegisterEvaluator(node_kind, evaluator);
void register_node_evaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
get_evaluator_registry().RegisterEvaluator(node_kind, std::move(eval_reg));
}

void register_node_evaluator(EvalRegistration r) {
register_node_evaluator(r.kind, r.evaluator);
register_node_evaluator(r.kind, std::move(r));
}

RegisterNodeEvaluators&& RegisterNodeEvaluators::evaluator(EvalRegistration r) && {
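
The new FindEvaluator applies two applicability filters from EvalRegistration::options before returning an evaluator: a blacklist of output types and an optional whitelist of schemas. A simplified sketch of that filtering logic using stand-in string types (the real code keys the blacklist by torch::jit type and the whitelist by c10::OperatorName; this is an illustration only):

#include <iostream>
#include <set>
#include <string>
#include <vector>

// Simplified stand-in for the applicability filters in FindEvaluator.
struct Options {
  std::set<std::string> blacklisted_output_types;  // e.g. {"Tensor"}
  std::vector<std::string> valid_schemas;          // empty means any schema
  bool use() const {
    return !blacklisted_output_types.empty() || !valid_schemas.empty();
  }
};

bool Applicable(const Options& opts,
                const std::vector<std::string>& output_types,
                const std::string& schema) {
  if (!opts.use()) {
    return true;  // no filters registered, evaluator always applies
  }
  for (const auto& t : output_types) {
    if (opts.blacklisted_output_types.count(t)) {
      return false;  // a blacklisted output type disables the evaluator
    }
  }
  if (!opts.valid_schemas.empty()) {
    bool found = false;
    for (const auto& s : opts.valid_schemas) {
      if (s == schema) { found = true; break; }
    }
    if (!found) {
      return false;  // schema whitelist set but node's schema not on it
    }
  }
  return true;
}

int main() {
  Options opts;
  opts.blacklisted_output_types = {"Tensor"};
  // A node producing a Tensor falls back to the converter library.
  std::cout << std::boolalpha
            << Applicable(opts, {"Tensor"}, "aten::size(Tensor self) -> int")
            << std::endl;  // prints false
  return 0;
}

This filtering is presumably what allows aten::size to move from a converter (the deleted shape.cpp, commit c83447e) to an evaluator that only fires when its result can be computed at conversion time.
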