diff --git a/.gitignore b/.gitignore index 2fc1b55b80..c757ba7d1e 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ cpp/ptq/datasets/data/ tests/accuracy/datasets/data/* ._.DS_Store *.tar.gz +*.tgz \ No newline at end of file diff --git a/BUILD b/BUILD index 28eaf3862c..48f2ae5338 100644 --- a/BUILD +++ b/BUILD @@ -11,7 +11,7 @@ pkg_tar( "//core/conversion/evaluators:include", "//core/execution:include", "//core/lowering:include", - "//core/lowering/irfusers:include", + "//core/lowering/passes:include", "//core/util:include", "//core/util/logging:include" ], diff --git a/README.md b/README.md index fbb7fff688..c06c02c5f4 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ > Ahead of Time (AOT) compiling for PyTorch JIT -TRTorch is a compiler for PyTorch/TorchScript, targeting NVIDIA GPUs via NVIDIA's TensorRT Deep Learning Optimizer and Runtime. Unlike PyTorch's Just-In-Time (JIT) compiler, TRTorch is an Ahead-of-Time (AOT) compiler, meaning that before you deploy your TorchScript code, you go through an explicit compile step to convert a standard TorchScript program into an module targeting a TensorRT engine. TRTorch operates as a PyTorch extention and compiles modules that integrate into the JIT runtime seamlessly. After compilation using the optimized graph should feel no different than running a TorchScript module. You also have access to TensorRT's suite of configurations at compile time, so you are able to specify operating precision (FP32/F16) and other settings for your module. +TRTorch is a compiler for PyTorch/TorchScript, targeting NVIDIA GPUs via NVIDIA's TensorRT Deep Learning Optimizer and Runtime. Unlike PyTorch's Just-In-Time (JIT) compiler, TRTorch is an Ahead-of-Time (AOT) compiler, meaning that before you deploy your TorchScript code, you go through an explicit compile step to convert a standard TorchScript program into a module targeting a TensorRT engine. TRTorch operates as a PyTorch extension and compiles modules that integrate into the JIT runtime seamlessly. After compilation, using the optimized graph should feel no different from running a TorchScript module. You also have access to TensorRT's suite of configurations at compile time, so you are able to specify operating precision (FP32/F16/INT8) and other settings for your module. More Information / System Architecture: @@ -35,17 +35,17 @@ auto results = trt_mod.forward({in_tensor}); | Platform | Support | | -------- | ------- | | Linux AMD64 / GPU | **Supported** | -| Linux aarch64 / GPU | **Planned/Possible with Native Compiation and small modifications to the build system** | +| Linux aarch64 / GPU | **Planned/Possible with Native Compilation but untested** | | Linux aarch64 / DLA | **Planned/Possible with Native Compilation but untested** | | Windows / GPU | - | | Linux ppc64le / GPU | - | ### Dependencies -- Libtorch 1.4.0 -- CUDA 10.1 -- cuDNN 7.6 -- TensorRT 6.0.1 +- Libtorch 1.5.0 +- CUDA 10.2 +- cuDNN 7.6.5 +- TensorRT 7.0.0 ## Prebuilt Binaries @@ -53,10 +53,71 @@ Releases: https://github.com/NVIDIA/TRTorch/releases ## Compiling TRTorch -Install TensorRT, CUDA and cuDNN on the system before starting to compile. +### Installing Dependencies +You need to start by having CUDA installed on the system. Libtorch will automatically be pulled for you by bazel, +and then you have two options. + +#### 1. 
Building using cuDNN & TensorRT tarball distributions + +> This is recommended so as to build TRTorch hermetically and ensure any bugs are not caused by version issues + +> Make sure when running TRTorch that these versions of the libraries are prioritized in your `$LD_LIBRARY_PATH` + +1. You need to download the tarball distributions of TensorRT and cuDNN from the NVIDIA website. + - https://developer.nvidia.com/cudnn + - https://developer.nvidia.com/tensorrt +2. Place these files in a directory (the directories `third_party/distdir/[x86_64-linux-gnu | aarch64-linux-gnu]` exist for this purpose) +3. Compile using: +``` shell +bazel build //:libtrtorch --compilation_mode opt --distdir third_party/distdir/[x86_64-linux-gnu | aarch64-linux-gnu] +``` + +#### 2. Building using locally installed cuDNN & TensorRT + +> If you find bugs and you compiled using this method, please disclose it in the issue +> (an `ldd` dump would be nice too) + +1. Install TensorRT, CUDA and cuDNN on the system before starting to compile. +2. In `WORKSPACE` comment out +```py +# Downloaded distributions to use with --distdir +http_archive( + name = "cudnn", + urls = ["",], + + build_file = "@//third_party/cudnn/archive:BUILD", + sha256 = "", + strip_prefix = "cuda" +) + +http_archive( + name = "tensorrt", + urls = ["",], + + build_file = "@//third_party/tensorrt/archive:BUILD", + sha256 = "", + strip_prefix = "TensorRT-" +) +``` +and uncomment +```py +# Locally installed dependencies +new_local_repository( + name = "cudnn", + path = "/usr/", + build_file = "@//third_party/cudnn/local:BUILD" +) + +new_local_repository( + name = "tensorrt", + path = "/usr/", + build_file = "@//third_party/tensorrt/local:BUILD" +) +``` +3. Compile using: ``` shell -bazel build //:libtrtorch --compilation_mode=opt +bazel build //:libtrtorch --compilation_mode opt ``` ### Debug build @@ -84,9 +145,13 @@ Thanks for wanting to contribute! There are two main ways to handle supporting a ### In my application? -> The Node Converter Registry is not exposed in the top level API but you can try using the internal headers shipped with the tarball. +> The Node Converter Registry is not exposed in the top level API but is available in the internal headers shipped with the tarball. + +You can register a converter for your op using the `NodeConverterRegistry` inside your application, for example:
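+
+A minimal sketch of such a registration, modeled on the converters under `core/conversion/converters/impl` in this diff (the `aten::relu` op and activation layer here are illustrative, not part of this PR):
+
+``` c++
+#include "core/util/prelude.h"
+#include "core/conversion/converters/converters.h"
+
+namespace {
+using namespace trtorch::core::conversion::converters;
+
+// Hypothetical converter: map aten::relu onto a TensorRT activation layer.
+// Registration runs at static initialization time, like the converters in this PR.
+auto relu_registrations = RegisterNodeConversionPatterns()
+    .pattern({
+        "aten::relu(Tensor self) -> (Tensor)",
+        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+            auto in = args[0].ITensor();
+            auto relu = ctx->net->addActivation(*in, nvinfer1::ActivationType::kRELU);
+            TRTORCH_CHECK(relu, "Unable to create ReLU layer from node: " << *n);
+            relu->setName(util::node_info(n).c_str());
+            ctx->AssociateValueAndTensor(n->outputs()[0], relu->getOutput(0));
+            return true;
+        }
+    });
+} // namespace
+```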
+ +## Known Limitations -You can register a converter for your op using the NodeConverterRegistry inside your application. +- You cannot use Adaptive Pooling in PyTorch together with TRTorch's dynamic input shape support ## Structure of the repo diff --git a/WORKSPACE b/WORKSPACE index 9714cccae4..f93c2f7e1c 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -16,15 +16,6 @@ py_repositories() load("@rules_python//python:pip.bzl", "pip_repositories", "pip_import") pip_repositories() -http_archive( - name = "libtorch", - build_file = "@//third_party/libtorch:BUILD", - strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/cu101/libtorch-cxx11-abi-shared-with-deps-1.4.0.zip"], - sha256 = "f214bfde532877aa5d4e0803e51a28fa8edd97b6a44b6615f75a70352b6b542e" -) - -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") http_archive( name = "rules_pkg", url = "https://github.com/bazelbuild/rules_pkg/releases/download/0.2.4/rules_pkg-0.2.4.tar.gz", @@ -34,24 +25,53 @@ http_archive( load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") rules_pkg_dependencies() +# CUDA should be installed on the system locally new_local_repository( name = "cuda", - path = "/usr/local/cuda-10.1/targets/x86_64-linux/", + path = "/usr/local/cuda-10.2/targets/x86_64-linux/", build_file = "@//third_party/cuda:BUILD", ) -new_local_repository( +http_archive( + name = "libtorch", + build_file = "@//third_party/libtorch:BUILD", + strip_prefix = "libtorch", + urls = ["https://download.pytorch.org/libtorch/cu102/libtorch-cxx11-abi-shared-with-deps-1.5.0.zip"], + sha256 = "0efdd4e709ab11088fa75f0501c19b0e294404231442bab1d1fb953924feb6b5" +) + +# Downloaded distributions to use with --distdir +http_archive( name = "cudnn", - path = "/usr/", - build_file = "@//third_party/cudnn:BUILD" + urls = ["https://developer.nvidia.com/compute/machine-learning/cudnn/secure/7.6.5.32/Production/10.2_20191118/cudnn-10.2-linux-x64-v7.6.5.32.tgz",], + + build_file = "@//third_party/cudnn/archive:BUILD", + sha256 = "600267f2caaed2fd58eb214ba669d8ea35f396a7d19b94822e6b36f9f7088c20", + strip_prefix = "cuda" ) -new_local_repository( - name = "tensorrt", - path = "/usr/", - build_file = "@//third_party/tensorrt:BUILD" +http_archive( + name = "tensorrt", + urls = ["https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.0/7.0.0.11/tars/TensorRT-7.0.0.11.Ubuntu-18.04.x86_64-gnu.cuda-10.2.cudnn7.6.tar.gz",], + + build_file = "@//third_party/tensorrt/archive:BUILD", + sha256 = "c7d73b2585b18aae68b740249efa8c8ba5ae852abe9a023720595432a8eb4efd", + strip_prefix = "TensorRT-7.0.0.11" ) +## Locally installed dependencies +# new_local_repository( +# name = "cudnn", +# path = "/usr/", +# build_file = "@//third_party/cudnn/local:BUILD" +#) + +# new_local_repository( +# name = "tensorrt", +# path = "/usr/", +# build_file = "@//third_party/tensorrt/local:BUILD" +#) + git_repository( name = "googletest", remote = "https://github.com/google/googletest", diff --git a/core/compiler.cpp b/core/compiler.cpp index 9e3a69033b..f8d2861e8e 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -7,12 +7,11 @@ #include "ATen/core/function_schema.h" -#include "torch/csrc/jit/ir.h" -#include "torch/csrc/jit/pass_manager.h" +#include "torch/csrc/jit/frontend/function_schema_parser.h" +#include "torch/csrc/jit/ir/ir.h" +#include "torch/csrc/jit/passes/pass_manager.h" #include "torch/csrc/jit/passes/lower_graph.h" #include "torch/csrc/jit/passes/graph_fuser.h" -#include "torch/csrc/jit/script/module.h" -#include "torch/csrc/jit/script/function_schema_parser.h" #include "core/util/prelude.h" #include "core/compiler.h" @@ -42,71 +41,59 @@ c10::FunctionSchema 
GenerateGraphSchema(torch::jit::script::Module mod, std::str void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr& g, std::string& serialized_engine) { execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine); - auto schema = execution::GetEngineFunctionSchema(uid); auto num_io = execution::GetEngineIO(uid); auto self = g->addInput("self.1"); self->setType(mod.type()); - std::vector graph_inputs; + + auto id_val = g->insertConstant(uid); + + std::vector engine_inputs; + engine_inputs.push_back(id_val); + for (uint64_t i = 0; i < num_io.first; i++) { auto in_val = g->addInput(""); in_val->setType(c10::TensorType::get()); - graph_inputs.push_back(in_val); + engine_inputs.push_back(in_val); } - auto engine_node = g->create(c10::Symbol::fromQualString(schema.name()), torch::jit::ArrayRef(graph_inputs), num_io.second); + auto engine_node = g->create(c10::Symbol::fromQualString("trt::execute_engine"), torch::jit::ArrayRef(engine_inputs), num_io.second); g->block()->appendNode(engine_node); for (auto o : engine_node->outputs()) { g->registerOutput(o); } + LOG_DEBUG(*g << "(AddEngineToGraph)\n"); + return; } bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name) { - auto g = mod.get_method(method_name).graph(); - // Go through PyTorch Lowering to simplify graph and extract weight parameters - auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue()); + // Go through Lowering to simplify graph and extract weight parameters + auto graph_and_parameters = lowering::Lower(mod, method_name); - g = graph_and_parameters.first; - - // Go through TRTorch Lowering to reformat graph to be conversion friendly - // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT) - lowering::LowerGraph(g); - - auto params = graph_and_parameters.second; - auto named_params = conversion::get_named_params(g->inputs(), params); + auto g = graph_and_parameters.first; LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n"); - // Is this necessary? - lowering::LowerBlock(g->block()); - return conversion::VerifyConverterSupportForBlock(g->block()); } std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, ExtraInfo cfg) { - auto convert_cfg = std::move(cfg.convert_info); - - auto g = mod.get_method(method_name).graph(); - // Go through PyTorch Lowering to simplify graph and extract weight parameters - auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue()); - g = graph_and_parameters.first; - - // Go through TRTorch Lowering to reformat graph to be conversion friendly - // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT) - lowering::LowerGraph(g); + // Go through Lowering to simplify graph and extract weight parameters + auto graph_and_parameters = lowering::Lower(mod, method_name); + auto convert_cfg = std::move(cfg.convert_info); + auto g = graph_and_parameters.first; auto params = graph_and_parameters.second; auto named_params = conversion::get_named_params(g->inputs(), params); + LOG_INFO(*g << "(CompileGraph)\n"); - // Is this necessary? 
- lowering::LowerBlock(g->block()); auto engine = ConvertBlockToEngine(g->block(), convert_cfg, named_params); return std::move(engine); } @@ -115,7 +102,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, ExtraInfo cfg) { // TODO: Should be doing a functional transform but need PR #31978 // [jit] More robust mangling - // torch::jit::script::Module new_mod = mod.clone(); + //torch::jit::script::Module new_mod = mod.clone(); torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt"); std::vector> graphs; for (const torch::jit::script::Method& method : mod.get_methods()) { diff --git a/core/compiler.h b/core/compiler.h index a8b248d3a2..f9ff400159 100644 --- a/core/compiler.h +++ b/core/compiler.h @@ -1,7 +1,7 @@ #pragma once #include -#include "torch/csrc/jit/script/module.h" +#include "torch/csrc/jit/api/module.h" #include "core/conversion/conversion.h" namespace trtorch { diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index 3d0dca2f9f..2d6b704f40 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -14,9 +14,7 @@ namespace conversion { bool isNodeConversionBlacklisted(const torch::jit::Node* n); bool OpSupported(const torch::jit::Node* n) { - bool evalable = evaluators::shouldEvalAtConversionTime(n); - bool convertable = converters::node_is_convertable(n); - return evalable || convertable; + return evaluators::shouldEvalAtConversionTime(n) || converters::node_is_convertable(n); } c10::optional EvaluateNode(ConversionCtx* ctx, const torch::jit::Node* n, int level=0, int limit=10) { @@ -75,8 +73,12 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) { LOG_DEBUG(ctx->logger, "Node input is a value that needs to be evaluated"); auto eval = EvaluateNode(ctx, input_node); if (eval) { - LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value()); - ctx->evaluated_value_map[input] = std::move(eval.value()); + if (!eval.value().isTensor()) { + LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value()); + } else { + LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')'); + } + ctx->AssociateValueAndIValue(input, eval.value()); node_args.push_back(&(ctx->evaluated_value_map[input])); } else { LOG_DEBUG(ctx->logger, "Found the value is None");; @@ -158,6 +160,10 @@ void AddInputs(ConversionCtx* ctx, TRTORCH_CHECK(profile->isValid(), "Optimization profile is invalid, please check the input range provided (conversion.AddInputs)"); ctx->cfg->addOptimizationProfile(profile); + // TODO: Enable in TRT 7.1 + // if (ctx->op_precision == nvinfer1::DataType::kINT8) { + // ctx->cfg->setCalibrationProfile(profile); + // } } void MarkOutputs(ConversionCtx* ctx, at::ArrayRef outputs) { @@ -208,9 +214,7 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver } for (const auto n : nodes) { - if (converters::node_is_convertable(n)) { - ctx->CheckLayerAddition(n); - } + ctx->CheckLayerAddition(n); } auto outputs = b->outputs(); diff --git a/core/conversion/conversion.h b/core/conversion/conversion.h index c7a50a6319..f60e203f55 100644 --- a/core/conversion/conversion.h +++ b/core/conversion/conversion.h @@ -3,7 +3,7 @@ #include #include "NvInfer.h" -#include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/ir/ir.h" #include "core/conversion/conversionctx/ConversionCtx.h" namespace torch { diff --git a/core/conversion/conversion_blacklist.cpp b/core/conversion/conversion_blacklist.cpp index c20ccc7db7..9cdf5bd016 
100644 --- a/core/conversion/conversion_blacklist.cpp +++ b/core/conversion/conversion_blacklist.cpp @@ -1,12 +1,12 @@ #include #include -#include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/ir/ir.h" namespace trtorch { namespace core { namespace conversion { - + const std::unordered_set<std::string>& get_non_convertable_nodes() { // Set of nodes that should not invoke a converter or evaluator static std::unordered_set<std::string> nonconvertable_nodes = { diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index a3a5ddfc01..7037e9a512 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -60,6 +60,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) input_type = nvinfer1::DataType::kFLOAT; break; } + op_precision = settings.op_precision; if (settings.refit) { cfg->setFlag(nvinfer1::BuilderFlag::kREFIT); @@ -103,18 +104,27 @@ nvinfer1::ITensor* ConversionCtx::AssociateValueAndTensor(const torch::jit::Valu return tensor; } +torch::jit::IValue* ConversionCtx::AssociateValueAndIValue(const torch::jit::Value* value, torch::jit::IValue ivalue) { + this->evaluated_value_map[value] = std::move(ivalue); + return &this->evaluated_value_map[value]; +} + std::string ConversionCtx::SerializeEngine() { auto engine = builder->buildEngineWithConfig(*net, *cfg); auto serialized_engine = engine->serialize(); + engine->destroy(); return std::string((const char*)serialized_engine->data(), serialized_engine->size()); } bool ConversionCtx::CheckLayerAddition(const torch::jit::Node* n) { for (auto out : n->outputs()) { - auto iter = this->value_tensor_map.find(out); - if (iter == this->value_tensor_map.end()) { - LOG_WARNING("Node " << util::node_info(n) << " output: " << out->debugName() << " does not have a coresponding output, may potentially indicate a defective converter"); - return false; + auto iter_t = this->value_tensor_map.find(out); + if (iter_t == this->value_tensor_map.end()) { + auto iter_iv = this->evaluated_value_map.find(out); + if (iter_iv == this->evaluated_value_map.end()) { + LOG_WARNING("Node " << util::node_info(n) << " output: " << out->debugName() << " does not have a corresponding value or tensor, which may indicate a defective evaluator or converter"); + return false; + } } } return true; diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h index b7922a319d..1d2581fdc9 100644 --- a/core/conversion/conversionctx/ConversionCtx.h +++ b/core/conversion/conversionctx/ConversionCtx.h @@ -5,12 +5,11 @@ #include //#include "ATen/ATen.h" -#include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/ir/ir.h" #include "NvInfer.h" #include "core/util/prelude.h" - namespace trtorch { namespace core { namespace conversion { @@ -39,6 +38,7 @@ struct ConversionCtx { ConversionCtx(BuilderSettings settings); std::string SerializeEngine(); nvinfer1::ITensor* AssociateValueAndTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor); + torch::jit::IValue* AssociateValueAndIValue(const torch::jit::Value* value, torch::jit::IValue ivalue); bool CheckLayerAddition(const torch::jit::Node* n); ~ConversionCtx(); @@ -47,6 +47,7 @@ struct ConversionCtx { nvinfer1::INetworkDefinition* net; nvinfer1::IBuilderConfig* cfg; nvinfer1::DataType input_type; + nvinfer1::DataType op_precision; BuilderSettings settings; util::logging::TRTorchLogger logger; // Pointers to data that needs to remain alive until conversion is done diff --git 
a/core/conversion/converters/Arg.cpp b/core/conversion/converters/Arg.cpp index af23713ea1..90e4a423dc 100644 --- a/core/conversion/converters/Arg.cpp +++ b/core/conversion/converters/Arg.cpp @@ -89,6 +89,7 @@ std::string Arg::type_name() const { } const torch::jit::IValue* Arg::IValue() const { + TRTORCH_CHECK(isIValue(), "Requested IValue from Arg, however arg type is " << type_name()); if (type_ == Type::kIValue) { return ptr_.ivalue; } else { @@ -97,6 +98,7 @@ const torch::jit::IValue* Arg::IValue() const { } nvinfer1::ITensor* Arg::ITensor() const { + TRTORCH_CHECK(isITensor(), "Requested ITensor from Arg, however arg type is " << type_name()); if (type_ == Type::kITensor) { return ptr_.tensor; } else { diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD index 43b30857ee..464ad44550 100644 --- a/core/conversion/converters/BUILD +++ b/core/conversion/converters/BUILD @@ -15,8 +15,11 @@ cc_library( "impl/conv_deconv.cpp", "impl/element_wise.cpp", "impl/linear.cpp", + "impl/matrix_multiply.cpp", "impl/pooling.cpp", "impl/reduce.cpp", + "impl/shape.cpp", + "impl/shuffle.cpp", "impl/softmax.cpp", "impl/unary.cpp", ], diff --git a/core/conversion/converters/NodeConverterRegistry.cpp b/core/conversion/converters/NodeConverterRegistry.cpp index ff175c6023..04bddad002 100644 --- a/core/conversion/converters/NodeConverterRegistry.cpp +++ b/core/conversion/converters/NodeConverterRegistry.cpp @@ -1,6 +1,6 @@ #include "core/util/prelude.h" #include "core/conversion/converters/converters.h" -#include "torch/csrc/jit/script/function_schema_parser.h" +#include "torch/csrc/jit/frontend/function_schema_parser.h" namespace trtorch { namespace core { diff --git a/core/conversion/converters/converters.h b/core/conversion/converters/converters.h index 1b6f3b916f..6c151b061a 100644 --- a/core/conversion/converters/converters.h +++ b/core/conversion/converters/converters.h @@ -3,9 +3,10 @@ #include #include -#include "torch/csrc/jit/custom_operator.h" +#include "torch/csrc/jit/runtime/custom_operator.h" #include "ATen/core/function_schema.h" +#include "core/util/prelude.h" #include "core/conversion/conversionctx/ConversionCtx.h" namespace trtorch { diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp index ee0e60a9d7..375e7a2d8f 100644 --- a/core/conversion/converters/impl/element_wise.cpp +++ b/core/conversion/converters/impl/element_wise.cpp @@ -8,11 +8,24 @@ namespace converters { namespace impl { namespace { -nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, nvinfer1::ITensor* self, nvinfer1::ITensor* other, float scalar=1) { +nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, nvinfer1::ITensor* self, nvinfer1::ITensor* other, const std::string& name, float scalar=1) { auto self_dims = self->getDimensions(); + auto self_dims_vec = util::toVec(self_dims); auto other_dims = other->getDimensions(); + auto other_dims_vec = util::toVec(other_dims); + auto other_batch = other_dims_vec[0]; + + // TODO: Proper broadcast check + TRTORCH_CHECK(util::volume(self_dims) == util::volume(other_dims) || util::volume(self_dims) == util::volume(other_dims) / other_batch, "Found inputs to elementwise operation that do not have the same number of elements and are not broadcastable:\n Found: self " << self_dims << " other " << other_dims); + + if (self_dims != other_dims) { + LOG_DEBUG("Input shapes don't match; inserting shuffle layer to reshape to " << self_dims);
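+ // util::toDimsPad below is assumed to left-pad self's shape with 1s to other's rank so the elementwise op can broadcast
+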
auto self_shuffle = ctx->net->addShuffle(*self); + self_shuffle->setReshapeDimensions(util::toDimsPad(self_dims_vec, other_dims_vec.size())); + self_shuffle->setName(std::string("[Reshape self to " + util::toStr(self_dims) + " for broadcasting (" + name + ")]").c_str()); + self = self_shuffle->getOutput(0); + } - TRTORCH_CHECK(util::volume(self_dims) == util::volume(other_dims), "Found inputs to elementwise operation do not have the same number of elements:\n Found: self " << self_dims << " other " << other_dims); nvinfer1::ILayer* ele; if (scalar != 1) { @@ -63,7 +76,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() auto self = args[0].ITensor(); auto other = args[1].ITensor(); auto scalar = args[2].unwrapToScalar().to(); - auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, scalar); + auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, util::node_info(n), scalar); TRTORCH_CHECK(add, "Unable to create add layer from node: " << *n); @@ -80,7 +93,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() auto self = args[0].ITensor(); auto other = args[1].ITensor(); auto scalar = args[2].unwrapToScalar().to(); - auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, scalar); + auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, util::node_info(n), scalar); TRTORCH_CHECK(add, "Unable to create add layer from node: " << *n); @@ -97,7 +110,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() auto self = args[0].ITensor(); auto other = args[1].ITensor(); auto scalar = args[2].unwrapToScalar().to(); - auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, other, scalar); + auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, other, util::node_info(n), scalar); TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n); @@ -113,7 +126,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() // Should implement self / other auto self = args[0].ITensor(); auto other = args[1].ITensor(); - auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other); + auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n)); TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n); @@ -129,7 +142,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() // TODO: Remove with functionalization auto self = args[0].ITensor(); auto other = args[1].ITensor(); - auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other); + auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n)); TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n); @@ -145,7 +158,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() // Should implement self * other auto self = args[0].ITensor(); auto other = args[1].ITensor(); - auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other); + auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n)); TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n); @@ -161,7 +174,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() // TODO: Remove with functionalization auto self = args[0].ITensor(); auto other = args[1].ITensor(); - auto mul = add_elementwise(ctx, 
nvinfer1::ElementWiseOperation::kPROD, self, other); + auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n)); TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n); diff --git a/core/conversion/converters/impl/linear.cpp b/core/conversion/converters/impl/linear.cpp index 0b5b4e957a..f4c49ec020 100644 --- a/core/conversion/converters/impl/linear.cpp +++ b/core/conversion/converters/impl/linear.cpp @@ -64,4 +64,4 @@ auto linear_registrations = RegisterNodeConversionPatterns() } // namespace converters } // namespace conversion } // namespace core -} // trtorch +} // namespace trtorch diff --git a/core/conversion/converters/impl/matrix_multiply.cpp b/core/conversion/converters/impl/matrix_multiply.cpp new file mode 100644 index 0000000000..c6d2d99f1e --- /dev/null +++ b/core/conversion/converters/impl/matrix_multiply.cpp @@ -0,0 +1,55 @@ +#include "core/util/prelude.h" +#include "core/conversion/converters/converters.h" + +namespace trtorch { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace { + +auto mm_registrations = RegisterNodeConversionPatterns() + .pattern({ + "aten::matmul(Tensor self, Tensor other) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + nvinfer1::ITensor* self; + if (args[0].isIValue()) { + auto t = args[0].unwrapToTensor(); + auto t_weights = Weights(ctx, t); + auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data); + TRTORCH_CHECK(const_layer, "Unable to freeze tensor self for node: " << *n); + const_layer->setName((util::node_info(n) + " [Freeze Tensor(self)]").c_str()); + self = const_layer->getOutput(0); + } else { + self = args[0].ITensor(); + } + LOG_DEBUG("self tensor shape: " << self->getDimensions()); + + nvinfer1::ITensor* other; + if (args[1].isIValue()) { + auto t = args[1].unwrapToTensor(); + auto t_weights = Weights(ctx, t); + auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data); + TRTORCH_CHECK(const_layer, "Unable to freeze tensor other for node: " << *n); + const_layer->setName((util::node_info(n) + " [Freeze Tensor(other)]").c_str()); + other = const_layer->getOutput(0); + } else { + other = args[1].ITensor(); + } + LOG_DEBUG("other tensor shape: " << other->getDimensions()); + + auto mm_layer = ctx->net->addMatrixMultiply(*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE); + TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n); + mm_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); + return true; + } + }); +} // namespace +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace trtorch \ No newline at end of file diff --git a/core/conversion/converters/impl/shape.cpp b/core/conversion/converters/impl/shape.cpp new file mode 100644 index 0000000000..d5b3577a34 --- /dev/null +++ b/core/conversion/converters/impl/shape.cpp @@ -0,0 +1,32 @@ +#include "core/conversion/converters/converters.h" + +#include "torch/torch.h" + +namespace trtorch { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace { + +static auto shape_registrations = RegisterNodeConversionPatterns() + .pattern({ + // To use in static input size cases (explicit batch) + "aten::size.int(Tensor self, int 
dim) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + + auto size = in_shape[args[1].unwrapToInt()]; + + ctx->AssociateValueAndIValue(n->outputs()[0], size); + LOG_DEBUG("Output Value: " << size); + return true; + } + }); +} // namespace +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace trtorch diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp new file mode 100644 index 0000000000..8a8853fe2d --- /dev/null +++ b/core/conversion/converters/impl/shuffle.cpp @@ -0,0 +1,54 @@ +#include "core/conversion/converters/converters.h" + +#include "torch/torch.h" + +namespace trtorch { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace { + +static auto shuffle_registrations = RegisterNodeConversionPatterns() + .pattern({ + "aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto start_dim = args[1].unwrapToInt(); + auto end_dim = args[2].unwrapToInt(); + auto in_shape = util::toVec(in->getDimensions()); + auto out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes(); + + auto shuffle = ctx->net->addShuffle(*in); + TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n); + shuffle->setReshapeDimensions(util::toDims(out_shape)); + shuffle->setName(util::node_info(n).c_str()); + + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); + return true; + } + }).pattern({ + "aten::reshape(Tensor self, int[] shape) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + auto new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes(); + + auto shuffle = ctx->net->addShuffle(*in); + TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n); + shuffle->setReshapeDimensions(util::toDims(new_shape)); + shuffle->setName(util::node_info(n).c_str()); + + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); + + return true; + } + }); +} // namespace +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace trtorch diff --git a/core/conversion/converters/impl/softmax.cpp b/core/conversion/converters/impl/softmax.cpp index c138c759f0..35f6f04ef1 100644 --- a/core/conversion/converters/impl/softmax.cpp +++ b/core/conversion/converters/impl/softmax.cpp @@ -1,4 +1,3 @@ -#include "core/util/prelude.h" #include "core/conversion/converters/converters.h" namespace trtorch { @@ -29,12 +28,7 @@ static auto softmax_registrations = RegisterNodeConversionPatterns() auto softmax = ctx->net->addSoftMax(*in); TRTORCH_CHECK(softmax, "Unable to create softmax layer from node: " << *n); - - if (!softmax) { - LOG_ERROR("Unable to create softmax layer from node: " << *n); - return false; - } - LOG_WARNING("Disregarding dtype argument, please verify"); + LOG_DEBUG("Disregarding dtype argument"); if (shape.size() > 3) { softmax->setAxes(1 << (dim)); @@ -69,4 +63,4 @@ 
static auto softmax_registrations = RegisterNodeConversionPatterns() } // namespace converters } // namespace conversion } // namespace core -} // trtorch +} // namespace trtorch diff --git a/core/conversion/evaluators/NodeEvaluatorRegistry.cpp b/core/conversion/evaluators/NodeEvaluatorRegistry.cpp index a810c44584..6f387326fd 100644 --- a/core/conversion/evaluators/NodeEvaluatorRegistry.cpp +++ b/core/conversion/evaluators/NodeEvaluatorRegistry.cpp @@ -1,7 +1,7 @@ #include -#include "torch/csrc/jit/ir.h" -#include "torch/csrc/jit/constants.h" +#include "torch/csrc/jit/ir/ir.h" +#include "torch/csrc/jit/ir/constants.h" #include "ATen/core/functional.h" #include "ATen/core/ivalue.h" #include "ATen/core/List.h" @@ -41,7 +41,7 @@ class NodeEvaluatorRegistry { return true; } } - + private: EvaluatorLUT evaluator_lut_; }; diff --git a/core/conversion/evaluators/evaluators.h b/core/conversion/evaluators/evaluators.h index 021bc84e62..b2b89c28a0 100644 --- a/core/conversion/evaluators/evaluators.h +++ b/core/conversion/evaluators/evaluators.h @@ -3,7 +3,7 @@ #include #include -#include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/ir/ir.h" namespace trtorch { namespace core { @@ -19,7 +19,7 @@ typedef std::map kwargs; // when writing evaluators typedef std::function(const torch::jit::Node*, const kwargs&)> NodeEvaluator; -struct EvalRegistration { +struct EvalRegistration { torch::jit::NodeKind kind; NodeEvaluator evaluator; }; diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp index ee3c76c4cf..64662005b3 100644 --- a/core/conversion/evaluators/prim.cpp +++ b/core/conversion/evaluators/prim.cpp @@ -1,5 +1,5 @@ -#include "torch/csrc/jit/ir.h" -#include "torch/csrc/jit/constants.h" +#include "torch/csrc/jit/ir/ir.h" +#include "torch/csrc/jit/ir/constants.h" #include "ATen/core/functional.h" #include "ATen/core/ivalue.h" #include "ATen/core/List.h" diff --git a/core/execution/BUILD b/core/execution/BUILD index 452bef5332..ed53cdefcf 100644 --- a/core/execution/BUILD +++ b/core/execution/BUILD @@ -14,7 +14,8 @@ cc_library( "@tensorrt//:nvinfer", "@libtorch//:libtorch", "//core/util:prelude" - ] + ], + alwayslink = True, ) load("@rules_pkg//:pkg.bzl", "pkg_tar") diff --git a/core/execution/TRTEngine.cpp b/core/execution/TRTEngine.cpp index 40fbf75151..3370ea6f5b 100644 --- a/core/execution/TRTEngine.cpp +++ b/core/execution/TRTEngine.cpp @@ -1,7 +1,7 @@ #include #include "NvInfer.h" -#include "torch/csrc/jit/script/function_schema_parser.h" +#include "torch/csrc/jit/frontend/function_schema_parser.h" #include "core/util/prelude.h" #include "core/execution/execution.h" @@ -10,51 +10,7 @@ namespace trtorch { namespace core { namespace execution { -std::string slugify(std::string s) { - std::replace(s.begin(), s.end(), '.', '_'); - return s; -} - -c10::FunctionSchema GenerateEngineFunctionSchema(EngineID id, nvinfer1::ICudaEngine* engine, uint64_t num_inputs, uint64_t num_outputs) { - std::stringstream ss; - ss << "trt::execute_engine_" << std::hex << id << "("; - - std::stringstream in_ss; - std::stringstream out_ss; - - uint64_t inputs_parsed = 0; - uint64_t outputs_parsed = 0; - for (int i = 0; i < engine->getNbBindings(); i++) { - if (engine->bindingIsInput(i)) { - in_ss << "Tensor in_"; - in_ss << slugify(engine->getBindingName(i)); - if (inputs_parsed + 1 < num_inputs) { - in_ss << ", "; - inputs_parsed++; - } - } else { - out_ss << "Tensor"; - if (outputs_parsed + 1 < num_outputs) { - out_ss << ", "; - outputs_parsed++; - } - } - } - - ss << in_ss.str(); - ss 
<< ") -> ("; ss << out_ss.str(); ss << ')'; return torch::jit::parseSchema(ss.str()); } - -TRTEngine::TRTEngine() - : schema(torch::jit::parseSchema("trt::noop() -> ()")) { } - -TRTEngine::TRTEngine(nvinfer1::ILogger& logger, std::string& serialized_engine) - : schema(torch::jit::parseSchema("trt::noop() -> ()")) { // Need a better default - +TRTEngine::TRTEngine(nvinfer1::ILogger& logger, std::string& serialized_engine) { rt = nvinfer1::createInferRuntime(logger); cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()); @@ -73,7 +29,6 @@ TRTEngine::TRTEngine(nvinfer1::ILogger& logger, std::string& serialized_engine) } } num_io = std::make_pair(inputs, outputs); - schema = GenerateEngineFunctionSchema(id, cuda_engine, inputs, outputs); } TRTEngine& TRTEngine::operator=(const TRTEngine& other) { @@ -82,7 +37,6 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) { cuda_engine = other.cuda_engine; exec_ctx = other.exec_ctx; num_io = other.num_io; - schema = other.schema; return (*this); } diff --git a/core/execution/TRTEngineManager.cpp b/core/execution/TRTEngineManager.cpp index 1705582ec8..27a6aeff28 100644 --- a/core/execution/TRTEngineManager.cpp +++ b/core/execution/TRTEngineManager.cpp @@ -16,33 +16,29 @@ class TRTEngineManager { TRTEngine* get_engine(EngineID uid) { auto iter = engine_registry_.find(uid); - TRTORCH_ASSERT(iter != engine_registry_.end(), "Unabled to find requested engine (ID: " << uid << ") in TensorRT Execution Manager"); + TRTORCH_ASSERT(iter != engine_registry_.end(), "Unable to find requested engine (ID: " << uid << ") in TensorRT Execution Manager"); return &(iter->second); } - + // TODO: Should we have standing engines ready to run or should we be creating execution contexts JIT? EngineID register_engine(std::string& serialized_engine) { auto engine = TRTEngine(logger_, serialized_engine); EngineID uid = engine.id; engine_registry_[uid] = std::move(engine); LOG_DEBUG(logger_, "Registering new engine (ID: " << std::hex << uid << ") in TensorRT Execution Manager"); - RegisterEngineOp(engine); return uid; } void deregister_engine(EngineID uid) { auto iter = engine_registry_.find(uid); - if (iter == engine_registry_.end()) { - //TODO: Exception - LOG_ERROR(logger_, "Unabled to find requested engine (ID: " << uid << ") in TensorRT Execution Manager"); - } else { - auto engine = iter->second; - // Doing this here since for some reason the destructor causes segfaults - engine.exec_ctx->destroy(); - engine.cuda_engine->destroy(); - engine_registry_.erase(uid); - } + TRTORCH_ASSERT(iter != engine_registry_.end(), "Unable to find requested engine (ID: " << uid << ") in TensorRT Execution Manager"); + + auto engine = iter->second; + // Doing this here since for some reason the destructor causes segfaults + engine.exec_ctx->destroy(); + engine.cuda_engine->destroy(); + engine_registry_.erase(uid); } private: @@ -61,25 +57,20 @@ uint64_t RegisterEngineFromSerializedEngine(std::string& serialized_engine) { } nvinfer1::ICudaEngine* GetCudaEngine(EngineID id) { - // Assuming exception will be thrown inside the manager if there is no corresponding engine + // Assuming exception will be thrown inside the manager if there is no corresponding engine return get_engine_manager().get_engine(id)->cuda_engine; } nvinfer1::IExecutionContext* GetExecCtx(EngineID id) { - // Assuming exception will be thrown inside the manager if there is no corresponding engine + // Assuming exception will be thrown inside the manager if there is no corresponding engine return get_engine_manager().get_engine(id)->exec_ctx; } std::pair<uint64_t, uint64_t> GetEngineIO(EngineID id) { - // Assuming exception will be thrown inside the manager if there is no corresponding engine + // Assuming exception will be thrown inside the manager if there is no corresponding engine return get_engine_manager().get_engine(id)->num_io; } -c10::FunctionSchema GetEngineFunctionSchema(EngineID id) { - // Assuming exception will be thrown inside the manager if there is no corresponding engine - return get_engine_manager().get_engine(id)->schema; -} - void DeregisterEngine(EngineID id) { get_engine_manager().deregister_engine(id); } diff --git a/core/execution/execution.h b/core/execution/execution.h index 7f484071eb..8c50dd4207 100644 --- a/core/execution/execution.h +++ b/core/execution/execution.h @@ -15,10 +15,9 @@ struct TRTEngine { nvinfer1::ICudaEngine* cuda_engine; nvinfer1::IExecutionContext* exec_ctx; std::pair<uint64_t, uint64_t> num_io; - c10::FunctionSchema schema; EngineID id; - TRTEngine(); + TRTEngine() = default; TRTEngine(nvinfer1::ILogger& logger, std::string& serialized_engine); TRTEngine& operator=(const TRTEngine& other); }; @@ -28,7 +27,6 @@ uint64_t RegisterEngineFromSerializedEngine(std::string& serialized_engine); nvinfer1::ICudaEngine* GetCudaEngine(EngineID id); nvinfer1::IExecutionContext* GetExecCtx(EngineID id); std::pair<uint64_t, uint64_t> GetEngineIO(EngineID id); -c10::FunctionSchema GetEngineFunctionSchema(EngineID id); void DeregisterEngine(EngineID id); } // namespace execution diff --git a/core/execution/register_trt_op.cpp b/core/execution/register_trt_op.cpp index 495c01c0e6..6085a59974 100644 --- a/core/execution/register_trt_op.cpp +++ b/core/execution/register_trt_op.cpp @@ -1,7 +1,7 @@ #include "c10/cuda/CUDAStream.h" #include "torch/torch.h" 
-#include "torch/csrc/jit/custom_operator.h" +#include "torch/csrc/jit/runtime/custom_operator.h" #include "core/util/prelude.h" #include "core/execution/execution.h" @@ -9,7 +9,7 @@ namespace trtorch { namespace core { namespace execution { - +namespace { std::vector RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pair io, std::vector& inputs) { std::vector gpu_handles; @@ -47,45 +47,40 @@ std::vector RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai return outputs; } -c10::OperatorOptions aliasAnalysisFromSchema() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - return result; +c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; } +// Switched to a global operator because op implementations need to be non-capturing lambdas in PYT 1.5.0+ +torch::jit::RegisterOperators jit_registry({ + torch::jit::Operator( + "trt::execute_engine(int id, ...) -> ...", + [](torch::jit::Stack& stack) -> int { + size_t num_inputs = torch::jit::pop(stack).toInt(); + // Verify calling convention (right to left or left to right) + std::vector inputs; + for (uint64_t i = 0; i < num_inputs - 1; i++) { + at::Tensor in; + torch::jit::pop(stack, in); + inputs.insert(inputs.begin(), std::move(in)); + } -// The other way to do this is to register a generic op something liek -// trt::execute_engine(int id, Tensor input, ...) -> (Tensor...) but not sure -// how well that would work -void RegisterEngineOp(TRTEngine& engine) { - EngineID id = engine.id; - torch::jit::RegisterOperators jit_registry({ - torch::jit::Operator( - engine.schema, - [id](torch::jit::Stack& stack) { - LOG_DEBUG("Attempting to run engine (ID: " << std::hex << id << ")"); - auto io = GetEngineIO(id); - auto num_in = io.first; - auto num_out = io.second; - // Verify calling convention (right to left or left to right) - std::vector inputs; - for (uint64_t i = 0; i < num_in; i++) { - at::Tensor in; - torch::jit::pop(stack, in); - inputs.insert(inputs.begin(), std::move(in)); - } + int64_t id = torch::jit::pop(stack).toInt(); + LOG_DEBUG("Attempting to run engine (ID: " << std::hex << id << ")"); + auto io = GetEngineIO(id); + auto num_out = io.second; - auto ctx = GetExecCtx(id); - auto outputs = RunCudaEngine(ctx, io, inputs); - for (uint64_t o = 0; o < num_out; o++) { - torch::jit::push(stack, std::move(outputs[o])); - } - return 0; - }, - aliasAnalysisFromSchema()) - }); -} + auto ctx = GetExecCtx(id); + auto outputs = RunCudaEngine(ctx, io, inputs); + for (uint64_t o = 0; o < num_out; o++) { + torch::jit::push(stack, std::move(outputs[o])); + } + return 0; + }, + aliasAnalysisFromSchema()) + }); +} // namespace } // namespace execution } // namespace core } // namespace trtorch diff --git a/core/lowering/BUILD b/core/lowering/BUILD index 9a0c53c13b..4d4822a2a9 100644 --- a/core/lowering/BUILD +++ b/core/lowering/BUILD @@ -8,11 +8,14 @@ cc_library( srcs = [ "lowering.cpp", "drop_unused_nodes.cpp", + "register_const_op.cpp" ], deps = [ "@libtorch//:libtorch", - "//core/lowering/irfusers" - ] + "//core/lowering/passes", + "//core/util:prelude" + ], + alwayslink = True ) load("@rules_pkg//:pkg.bzl", "pkg_tar") diff --git a/core/lowering/drop_unused_nodes.cpp b/core/lowering/drop_unused_nodes.cpp index baa631cfa7..a99a6d210b 100644 --- a/core/lowering/drop_unused_nodes.cpp +++ b/core/lowering/drop_unused_nodes.cpp @@ -1,4 +1,4 @@ -#include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/ir/ir.h" namespace torch { namespace jit { @@ -14,7 +14,7 @@ 
namespace trtorch { namespace core { namespace lowering { -// From torch/csrc/jit/interpreter.cpp +// From torch/csrc/jit/interpreter.cpp void DropUnusedNodes(torch::jit::Block* b) { auto create_drop_if_unused = [&](at::ArrayRef values) -> torch::jit::Node* { std::vector to_drop; @@ -26,7 +26,7 @@ void DropUnusedNodes(torch::jit::Block* b) { return nullptr; return b->owningGraph()->create(torch::jit::prim::Drop, to_drop, 0); }; - + if (auto d = create_drop_if_unused(b->inputs())) { b->prependNode(d); } diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index f7f1254e00..1fbbb107e4 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -1,8 +1,13 @@ -#include "torch/csrc/jit/passes/fuse_linear.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" +#include "torch/csrc/jit/passes/fuse_linear.h" +#include "torch/csrc/jit/passes/freeze_module.h" +#include "torch/csrc/jit/passes/lower_graph.h" +#include "torch/csrc/jit/passes/quantization.h" +#include "torch/csrc/jit/passes/guard_elimination.h" +#include "core/util/prelude.h" #include "core/lowering/lowering.h" -#include "core/lowering/irfusers/irfusers.h" +#include "core/lowering/passes/passes.h" namespace trtorch { namespace core { @@ -15,14 +20,43 @@ void LowerBlock(torch::jit::Block* b) { } void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) { + torch::jit::EliminateRedundantGuards(g); + passes::EliminateExceptionOrPassPattern(g); torch::jit::FuseLinear(g); - irfusers::RemoveDropout(g); - irfusers::FuseFlattenLinear(g); - irfusers::ExpandLogSoftmax(g); + passes::RemoveDropout(g); + passes::FuseFlattenLinear(g); + passes::UnpackAddMM(g); + passes::UnpackLogSoftmax(g); + //passes::RemoveDimExeception(g); //irfusers::UnpackBatchNorm(g); - //torch::jit::EliminateDeadCode(g); + torch::jit::EliminateDeadCode(g); + LOG_GRAPH(*g); } - +torch::jit::Module LowerModule(const torch::jit::script::Module& mod) { + auto mod_ = torch::jit::freeze_module(mod); + return mod_; +} + +std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(const torch::jit::script::Module& mod, + std::string method_name) { + auto lowered_mod = LowerModule(mod); + auto g = lowered_mod.get_method(method_name).graph(); + LOG_GRAPH(*g); + + // Go through TRTorch Lowering to reformat graph to be conversion friendly + // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT) + LOG_GRAPH("TRTorch Graph Lowering"); + lowering::LowerGraph(g); + //torch::jit::FoldConvBatchNorm2d(lowered_mod); + LOG_GRAPH("LibTorch Lowering"); + auto graph_and_parameters = torch::jit::LowerGraph(*g, lowered_mod._ivalue()); + // Is this necessary? 
+ lowering::LowerBlock(g->block()); + return graph_and_parameters; +} + + } // namespace lowering } // namespace core } // namespace trtorch diff --git a/core/lowering/lowering.h b/core/lowering/lowering.h index ed34ee30a0..79f07cb5ec 100644 --- a/core/lowering/lowering.h +++ b/core/lowering/lowering.h @@ -1,13 +1,16 @@ #pragma once #include -#include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/ir/ir.h" namespace trtorch { namespace core { namespace lowering { - + void LowerBlock(torch::jit::Block* b); void LowerGraph(std::shared_ptr<torch::jit::Graph>& g); +torch::jit::Module LowerModule(const torch::jit::script::Module& mod); +std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(const torch::jit::script::Module& mod, + std::string method_name); } // namespace lowering } // namespace core diff --git a/core/lowering/irfusers/BUILD b/core/lowering/passes/BUILD similarity index 52% rename from core/lowering/irfusers/BUILD rename to core/lowering/passes/BUILD index 71899dfee6..99e92c1829 100644 --- a/core/lowering/irfusers/BUILD +++ b/core/lowering/passes/BUILD @@ -1,17 +1,20 @@ package(default_visibility = ["//visibility:public"]) cc_library( - name = "irfusers", + name = "passes", hdrs = [ - "irfusers.h", + "passes.h", ], srcs = [ + "exception_elimination.cpp", "fuse_flatten_linear.cpp", - "expand_log_softmax.cpp", "remove_dropout.cpp", - "unpack_batch_norm.cpp" + "unpack_addmm.cpp", + "unpack_batch_norm.cpp", + "unpack_log_softmax.cpp", ], deps = [ + "//core/util:prelude", "@libtorch//:libtorch", ] ) @@ -20,7 +23,7 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar") pkg_tar( name = "include", - package_dir = "core/lowering/irfusers/", - srcs = ["irfusers.h"], + package_dir = "core/lowering/passes/", + srcs = ["passes.h"], ) diff --git a/core/lowering/passes/exception_elimination.cpp b/core/lowering/passes/exception_elimination.cpp new file mode 100644 index 0000000000..1f9fb35a65 --- /dev/null +++ b/core/lowering/passes/exception_elimination.cpp @@ -0,0 +1,85 @@ +#include "torch/csrc/jit/passes/guard_elimination.h" +#include "torch/csrc/jit/ir/alias_analysis.h" +#include "torch/csrc/jit/jit_log.h" +#include "torch/csrc/jit/passes/constant_propagation.h" +#include "torch/csrc/jit/passes/peephole.h" +#include "torch/csrc/jit/runtime/graph_executor.h" +#include "torch/csrc/jit/passes/dead_code_elimination.h" + +#include "core/util/prelude.h" + +#include + +namespace trtorch { +namespace core { +namespace lowering { +namespace passes { +namespace { +using namespace torch::jit; +struct ExceptionOrPassPatternElimination { + ExceptionOrPassPatternElimination(std::shared_ptr graph) + : graph_(std::move(graph)) {} + + void run() { + findExceptionOrPassNodes(graph_->block()); + torch::jit::EliminateDeadCode(graph_); + LOG_GRAPH("Post exception or pass elimination: " << *graph_); + } + +private: + bool isExceptionOrPassNode(Node* n) { + /// Check if this Node hosts a pattern like so: + /// = prim::If(%5958) + /// block0(): + /// = prim::RaiseException(%45) + /// -> () + /// block1(): + /// -> () + if (n->blocks().size() != 2) { + return false; + } + auto arm1 = n->blocks()[0]; + auto arm2 = n->blocks()[1]; + if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) { + // Make sure that the node doesn't actually produce any Values that are used by other nodes + return false; + } + + auto arm1_start = arm1->nodes().begin(); + + if ((*arm1_start)->kind() != prim::RaiseException && (*(++arm1_start))->kind() != prim::Return) { + // Make sure that block0 is solely just the exception and the return + return false; + } + + if 
((*(arm2->nodes().begin()))->kind() != prim::Return) { + // Make sure that block1 is solely the return + return false; + } + + return true; + } + + void findExceptionOrPassNodes(Block* b) { + for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { + auto n = *it; + if (n->kind() == prim::If && isExceptionOrPassNode(n)) { + LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)"); + it.destroyCurrent(); + } + } + } + + std::shared_ptr graph_; +}; +} // namespace + +void EliminateExceptionOrPassPattern(std::shared_ptr graph) { + ExceptionOrPassPatternElimination eppe(std::move(graph)); + eppe.run(); +} + +} // namespace passes +} // namespace lowering +} // namespace core +} // namespace trtorch diff --git a/core/lowering/irfusers/fuse_flatten_linear.cpp b/core/lowering/passes/fuse_flatten_linear.cpp similarity index 93% rename from core/lowering/irfusers/fuse_flatten_linear.cpp rename to core/lowering/passes/fuse_flatten_linear.cpp index 5b8c3899ec..4a8a7112d9 100644 --- a/core/lowering/irfusers/fuse_flatten_linear.cpp +++ b/core/lowering/passes/fuse_flatten_linear.cpp @@ -1,10 +1,12 @@ #include "torch/csrc/jit/passes/fuse_linear.h" #include "torch/csrc/jit/passes/subgraph_rewrite.h" +#include "core/util/prelude.h" + namespace trtorch { namespace core { namespace lowering { -namespace irfusers { +namespace passes { void FuseFlattenLinear(std::shared_ptr& graph) { //TensorRT implicitly adds a flatten layer infront of FC layers if necessary @@ -33,13 +35,15 @@ void FuseFlattenLinear(std::shared_ptr& graph) { torch::jit::SubgraphRewriter flatten_linear_to_linear; flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear); flatten_linear_to_linear.runOnGraph(graph); - + torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear; flatten_linear_bias_none_to_linear.RegisterRewritePattern( flatten_linear_bias_none_pattern, fused_linear_bias_none); flatten_linear_bias_none_to_linear.runOnGraph(graph); + LOG_GRAPH("Post flatten linear: " << *graph); } -} // namespace irfusers + +} // namespace passes } // namespace lowering } // namespace core } // namespace trtorch diff --git a/core/lowering/irfusers/irfusers.h b/core/lowering/passes/passes.h similarity index 57% rename from core/lowering/irfusers/irfusers.h rename to core/lowering/passes/passes.h index d793e3176a..d4efc198db 100644 --- a/core/lowering/irfusers/irfusers.h +++ b/core/lowering/passes/passes.h @@ -1,17 +1,19 @@ #pragma once -#include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/ir/ir.h" namespace trtorch { namespace core { namespace lowering { -namespace irfusers { +namespace passes { void FuseFlattenLinear(std::shared_ptr& graph); -void ExpandLogSoftmax(std::shared_ptr& graph); void RemoveDropout(std::shared_ptr& graph); +void UnpackAddMM(std::shared_ptr& graph); void UnpackBatchNorm(std::shared_ptr& graph); - +void UnpackLogSoftmax(std::shared_ptr& graph); +void EliminateExceptionOrPassPattern(std::shared_ptr graph); + } // namespace irfusers } // namespace lowering } // namespace core diff --git a/core/lowering/irfusers/remove_dropout.cpp b/core/lowering/passes/remove_dropout.cpp similarity index 86% rename from core/lowering/irfusers/remove_dropout.cpp rename to core/lowering/passes/remove_dropout.cpp index fcd1d07aa0..ccac4c3d8a 100644 --- a/core/lowering/irfusers/remove_dropout.cpp +++ b/core/lowering/passes/remove_dropout.cpp @@ -1,10 +1,12 @@ #include #include +#include "core/util/prelude.h" + namespace trtorch { namespace core { namespace lowering { 
-namespace irfusers { +namespace passes { void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) { std::string dropout_pattern = R"IR( @@ -14,15 +16,16 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) { std::string no_dropout_pattern = R"IR( graph(%input, %4, %5): return (%input))IR"; - + // rewrite aten::dropout to pass through its input torch::jit::SubgraphRewriter remove_dropout; remove_dropout.RegisterRewritePattern( dropout_pattern, no_dropout_pattern); remove_dropout.runOnGraph(graph); + LOG_GRAPH("Post remove dropout: " << *graph); } -} // namespace irfusers +} // namespace passes } // namespace lowering } // namespace core } // namespace trtorch diff --git a/core/lowering/passes/unpack_addmm.cpp b/core/lowering/passes/unpack_addmm.cpp new file mode 100644 index 0000000000..ce2cc046e8 --- /dev/null +++ b/core/lowering/passes/unpack_addmm.cpp @@ -0,0 +1,35 @@ +#include "torch/csrc/jit/passes/fuse_linear.h" +#include "torch/csrc/jit/passes/subgraph_rewrite.h" + +#include "core/util/prelude.h" + +namespace trtorch { +namespace core { +namespace lowering { +namespace passes { + +void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph) { + // Unpack aten::addmm into aten::matmul and aten::add_ so the bias can be frozen as a TRT constant + std::string addmm_pattern = R"IR( + graph(%b, %x, %w, %1): + %out: Tensor = aten::addmm(%b, %x, %w, %1, %1) + return (%out))IR"; + std::string mm_add_pattern = R"IR( + graph(%b, %x, %w, %1): + %mm: Tensor = aten::matmul(%x, %w) + %bias: Tensor = trt::const(%b) + %out: Tensor = aten::add_(%bias, %mm, %1) + return (%out))IR"; + + + torch::jit::SubgraphRewriter unpack_addmm; + unpack_addmm.RegisterRewritePattern(addmm_pattern, mm_add_pattern); + unpack_addmm.runOnGraph(graph); + LOG_GRAPH("Post unpack addmm: " << *graph); +} + + +} // namespace passes } // namespace lowering } // namespace core } // namespace trtorch diff --git a/core/lowering/irfusers/unpack_batch_norm.cpp b/core/lowering/passes/unpack_batch_norm.cpp similarity index 69% rename from core/lowering/irfusers/unpack_batch_norm.cpp rename to core/lowering/passes/unpack_batch_norm.cpp index 76a8ff0739..19a9ad29bd 100644 --- a/core/lowering/irfusers/unpack_batch_norm.cpp +++ b/core/lowering/passes/unpack_batch_norm.cpp @@ -1,47 +1,27 @@ -#include "torch/csrc/jit/custom_operator.h" -#include "torch/csrc/jit/passes/fuse_linear.h" #include "torch/csrc/jit/passes/subgraph_rewrite.h" -namespace torch { -namespace jit { - -c10::OperatorOptions aliasAnalysisFromSchema() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - return result; -} - -RegisterOperators trt_const_op_reg({ - Operator( - "trt::const(Tensor val) -> Tensor", - [](Stack& stack) { - return 0; //nop - }, - aliasAnalysisFromSchema())}); - -} // namespace jit -} // namespace torch +#include "core/util/prelude.h" namespace trtorch { namespace core { namespace lowering { -namespace irfusers { +namespace passes { // // May be abusing aten::_tensor_to_list(Tensor self) -> int[] // // Treating it as an emit_constant by the converters // // We could register a custom op (trt::emit_constant) which we can use to convert -// // constant tensors to TRT ITensors void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph) { // Convert BatchNorm into individual operators // batch_norm = gamma * (in - mu) / sqrt(var + epsilon) + beta std::string batch_norm_pattern = R"IR( - graph(%input, %gamma, %beta, %mean, 
diff --git a/core/lowering/irfusers/unpack_batch_norm.cpp b/core/lowering/passes/unpack_batch_norm.cpp similarity index 69% rename from core/lowering/irfusers/unpack_batch_norm.cpp rename to core/lowering/passes/unpack_batch_norm.cpp index 76a8ff0739..19a9ad29bd 100644 --- a/core/lowering/irfusers/unpack_batch_norm.cpp +++ b/core/lowering/passes/unpack_batch_norm.cpp @@ -1,47 +1,27 @@ -#include "torch/csrc/jit/custom_operator.h" -#include "torch/csrc/jit/passes/fuse_linear.h" #include "torch/csrc/jit/passes/subgraph_rewrite.h" -namespace torch { -namespace jit { - -c10::OperatorOptions aliasAnalysisFromSchema() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - return result; -} - -RegisterOperators trt_const_op_reg({ - Operator( - "trt::const(Tensor val) -> Tensor", - [](Stack& stack) { - return 0; //nop - }, - aliasAnalysisFromSchema())}); - -} // namespace jit -} // namespace torch +#include "core/util/prelude.h" namespace trtorch { namespace core { namespace lowering { -namespace irfusers { +namespace passes { // // May be abusing aten::_tensor_to_list(Tensor self) -> int[] // // Treating it as an emit_constant by the converters // // We could register a custom op (trt::emit_constant) which we can use to convert -// // constant tensors to TRT ITensors +// // constant tensors to TRT ITensors void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph) { // Convert BatchNorm into individual operators // batch_norm = gamma * (in - mu) / sqrt(var + epsilon) + beta std::string batch_norm_pattern = R"IR( - graph(%input, %gamma, %beta, %mean, + graph(%input, %gamma, %beta, %mean, %var, %training, %momentum, %eps, %cudnn): %1 = aten::batch_norm(%input, %gamma, %beta, %mean, %var, %training, %momentum, %eps, %cudnn) return (%1))IR"; std::string expanded_batch_norm_pattern = R"IR( - graph(%input, %gamma, %beta, %mean, + graph(%input, %gamma, %beta, %mean, %var, %training, %momentum, %eps, %cudnn): %gamma_trt = trt::const(%gamma) %beta_trt = trt::const(%beta) @@ -57,12 +37,13 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph) { %7: Scalar = prim::Constant[value=1]() %8 = aten::add(%6, %beta_trt, %7) return(%8))IR"; - + torch::jit::SubgraphRewriter unpack_batch_norm; unpack_batch_norm.RegisterRewritePattern(batch_norm_pattern, expanded_batch_norm_pattern); unpack_batch_norm.runOnGraph(graph); + LOG_GRAPH("Post unpack batchnorm: " << *graph); } -} // Namespace Irfusers +} // namespace passes } // namespace lowering } // namespace core } // namespace trtorch
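The comment in `UnpackBatchNorm` records the identity the rewrite relies on. A standalone numeric check of that decomposition (not part of the patch, plain libtorch):

```c++
#include <iostream>
#include <torch/torch.h>

// Standalone check: in inference mode,
// batch_norm(x) == gamma * (x - mu) / sqrt(var + eps) + beta, per channel.
int main() {
  auto x     = torch::randn({2, 4, 8, 8});
  auto gamma = torch::randn({4});
  auto beta  = torch::randn({4});
  auto mu    = torch::randn({4});
  auto var   = torch::rand({4}) + 0.5;  // keep the variance positive
  double eps = 1e-5;

  auto ref = torch::batch_norm(x, gamma, beta, mu, var,
                               /*training=*/false, /*momentum=*/0.1, eps,
                               /*cudnn_enabled=*/false);

  // Broadcast the per-channel parameters over an NCHW input
  auto c = [](torch::Tensor t) { return t.view({1, 4, 1, 1}); };
  auto expanded = c(gamma) * (x - c(mu)) / torch::sqrt(c(var) + eps) + c(beta);

  std::cout << std::boolalpha << torch::allclose(ref, expanded, 1e-5, 1e-5) << std::endl;
}
```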
diff --git a/core/lowering/irfusers/expand_log_softmax.cpp b/core/lowering/passes/unpack_log_softmax.cpp similarity index 89% rename from core/lowering/irfusers/expand_log_softmax.cpp rename to core/lowering/passes/unpack_log_softmax.cpp index c9c3cc27ce..7a53e64e9e 100644 --- a/core/lowering/irfusers/expand_log_softmax.cpp +++ b/core/lowering/passes/unpack_log_softmax.cpp @@ -1,17 +1,19 @@ #include "torch/csrc/jit/passes/fuse_linear.h" #include "torch/csrc/jit/passes/subgraph_rewrite.h" +#include "core/util/prelude.h" + namespace trtorch { namespace core { namespace lowering { -namespace irfusers { +namespace passes { -void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) { +void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) { // It's easier for TensorRT if we separate softmax and log // There might need to be a reshape inserted see: // https://github.com/onnx/onnx-tensorrt/blob/5dca8737851118f6ab8a33ea1f7bcb7c9f06caf5/builtin_op_importers.cpp#L1593 // Should the reshapes be added here or in the converter? - + // TODO: In the future this should be removed for a dedicated log_softmax converter (more efficient) // But it's easier to stand up a working system if the number of op converters is lower std::string logsoftmax_pattern = R"IR( @@ -33,19 +35,20 @@ void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) { %dtype : int? = prim::Constant() %softmax = aten::softmax(%input, %dim, %dtype) %log_softmax = aten::log(%softmax) - return (%log_softmax))IR"; + return (%log_softmax))IR"; torch::jit::SubgraphRewriter logsoftmax_to_softmax_log; logsoftmax_to_softmax_log.RegisterRewritePattern(logsoftmax_pattern, softmax_log_pattern); logsoftmax_to_softmax_log.runOnGraph(graph); - + torch::jit::SubgraphRewriter logsoftmax_none_to_softmax_log_none; logsoftmax_none_to_softmax_log_none.RegisterRewritePattern( logsoftmax_none_pattern, softmax_log_none_pattern); logsoftmax_none_to_softmax_log_none.runOnGraph(graph); + LOG_GRAPH("Post unpack logsoftmax: " << *graph); } -} // namespace irfusers +} // namespace passes } // namespace lowering } // namespace core } // namespace trtorch
diff --git a/core/lowering/register_const_op.cpp b/core/lowering/register_const_op.cpp new file mode 100644 index 0000000000..3952e978d5 --- /dev/null +++ b/core/lowering/register_const_op.cpp @@ -0,0 +1,21 @@ +#include "torch/csrc/jit/runtime/custom_operator.h" + +namespace torch { +namespace jit { + +c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; +} + +/// Op marks a Tensor to be converted from a Torch Tensor +/// to a TRT constant Tensor +RegisterOperators trt_const_op_reg({ + Operator( + "trt::const(Tensor val) -> Tensor", + [](Stack& stack) { + return 0; //noop + }, + aliasAnalysisFromSchema())}); + +} // namespace jit +} // namespace torch \ No newline at end of file
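register_const_op.cpp is the general recipe for introducing a marker op that the lowering passes can emit. The same recipe with a hypothetical schema name, to make the moving parts explicit (assumes the same PyTorch version as the patch):

```c++
#include "torch/csrc/jit/runtime/custom_operator.h"

namespace torch {
namespace jit {

// Same recipe as register_const_op.cpp, with a hypothetical schema:
// declare the signature, give it a no-op body, pick an alias analysis kind.
// The op then parses in IR patterns and survives the JIT passes untouched.
RegisterOperators example_marker_op_reg({
    Operator(
        "example::marker(Tensor val) -> Tensor",
        [](Stack& stack) {
          return 0; // no-op: the input IValue stays on the stack as the result
        },
        c10::AliasAnalysisKind::FROM_SCHEMA)});

} // namespace jit
} // namespace torch
```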
diff --git a/core/util/jit_util.h b/core/util/jit_util.h index d77d3229a6..538be211a6 100644 --- a/core/util/jit_util.h +++ b/core/util/jit_util.h @@ -3,7 +3,7 @@ #include #include -#include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/ir/ir.h" namespace trtorch { namespace core {
diff --git a/core/util/logging/TRTorchLogger.cpp b/core/util/logging/TRTorchLogger.cpp index d3968c9ee3..678506c09f 100644 --- a/core/util/logging/TRTorchLogger.cpp +++ b/core/util/logging/TRTorchLogger.cpp @@ -17,39 +17,36 @@ namespace util { namespace logging { TRTorchLogger::TRTorchLogger(std::string prefix, Severity severity, bool color) - : prefix_(prefix), reportable_severity_(severity), color_(color) {} + : prefix_(prefix), reportable_severity_((LogLevel) severity), color_(color) {} TRTorchLogger::TRTorchLogger(std::string prefix, LogLevel lvl, bool color) - : prefix_(prefix), reportable_severity_((Severity) lvl), color_(color) {} + : prefix_(prefix), reportable_severity_(lvl), color_(color) {} void TRTorchLogger::log(LogLevel lvl, std::string msg) { - Severity severity = (Severity) lvl; - log(severity, msg.c_str()); -} - -void TRTorchLogger::log(Severity severity, const char* msg) { // suppress messages with severity enum value greater than the reportable - if (severity > reportable_severity_) { + if (lvl > reportable_severity_) { return; } if (color_) { - switch (severity) { - case Severity::kINTERNAL_ERROR: std::cerr << TERM_RED; break; - case Severity::kERROR: std::cerr << TERM_RED; break; - case Severity::kWARNING: std::cerr << TERM_YELLOW; break; - case Severity::kINFO: std::cerr << TERM_GREEN; break; - case Severity::kVERBOSE: std::cerr << TERM_MAGENTA; break; + switch (lvl) { + case LogLevel::kINTERNAL_ERROR: std::cerr << TERM_RED; break; + case LogLevel::kERROR: std::cerr << TERM_RED; break; + case LogLevel::kWARNING: std::cerr << TERM_YELLOW; break; + case LogLevel::kINFO: std::cerr << TERM_GREEN; break; + case LogLevel::kDEBUG: std::cerr << TERM_MAGENTA; break; + case LogLevel::kGRAPH: std::cerr << TERM_NORMAL; break; default: break; } } - switch (severity) { - case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break; - case Severity::kERROR: std::cerr << "ERROR: "; break; - case Severity::kWARNING: std::cerr << "WARNING: "; break; - case Severity::kINFO: std::cerr << "INFO: "; break; - case Severity::kVERBOSE: std::cerr << "DEBUG: "; break; + switch (lvl) { + case LogLevel::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break; + case LogLevel::kERROR: std::cerr << "ERROR: "; break; + case LogLevel::kWARNING: std::cerr << "WARNING: "; break; + case LogLevel::kINFO: std::cerr << "INFO: "; break; + case LogLevel::kDEBUG: std::cerr << "DEBUG: "; break; + case LogLevel::kGRAPH: std::cerr << "GRAPH: "; break; default: std::cerr << "UNKNOWN: "; break; } @@ -60,16 +57,22 @@ void TRTorchLogger::log(Severity severity, const char* msg) { std::cerr << prefix_ << msg << std::endl; } + +void TRTorchLogger::log(Severity severity, const char* msg) { + LogLevel lvl = (LogLevel) severity; + log(lvl, std::string(msg)); +} + void TRTorchLogger::set_logging_prefix(std::string prefix) { prefix_ = prefix; } void TRTorchLogger::set_reportable_severity(Severity severity) { - reportable_severity_ = severity; + reportable_severity_ = (LogLevel) severity; } void TRTorchLogger::set_reportable_log_level(LogLevel lvl) { - reportable_severity_ = (Severity) lvl; + reportable_severity_ = lvl; } void TRTorchLogger::set_is_colored_output_on(bool colored_output_on) { @@ -81,11 +84,11 @@ std::string TRTorchLogger::get_logging_prefix() { } nvinfer1::ILogger::Severity TRTorchLogger::get_reportable_severity() { - return reportable_severity_; + return (Severity) reportable_severity_; } LogLevel TRTorchLogger::get_reportable_log_level() { - return (LogLevel) reportable_severity_; + return reportable_severity_; } bool TRTorchLogger::get_is_colored_output_on() {
diff --git a/core/util/logging/TRTorchLogger.h b/core/util/logging/TRTorchLogger.h index 0a3edc3203..55582a5277 100644 --- a/core/util/logging/TRTorchLogger.h +++ b/core/util/logging/TRTorchLogger.h @@ -8,12 +8,13 @@ namespace core { namespace util { namespace logging { -enum LogLevel { +enum class LogLevel : uint8_t { kINTERNAL_ERROR = (int) nvinfer1::ILogger::Severity::kINTERNAL_ERROR, kERROR = (int) nvinfer1::ILogger::Severity::kERROR, kWARNING = (int) nvinfer1::ILogger::Severity::kWARNING, kINFO = (int) nvinfer1::ILogger::Severity::kINFO, kDEBUG = (int) nvinfer1::ILogger::Severity::kVERBOSE, + kGRAPH }; // Logger for TensorRT info/warning/errors @@ -38,10 +39,10 @@ class TRTorchLogger : public nvinfer1::ILogger { private: std::string prefix_; - Severity reportable_severity_; + LogLevel reportable_severity_; bool color_; }; - + TRTorchLogger& get_logger(); } // namespace logging
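Because `kGRAPH` is appended after `kDEBUG`, the single `lvl > reportable_severity_` comparison in `log()` is enough to make it the most verbose level: graph dumps are suppressed unless `kGRAPH` itself is the reportable level. A self-contained sketch of that filtering rule:

```c++
#include <cstdint>
#include <iostream>

// kGRAPH enumerates after kDEBUG, so one comparison orders all the levels.
enum class LogLevel : uint8_t { kINTERNAL_ERROR, kERROR, kWARNING, kINFO, kDEBUG, kGRAPH };

bool should_log(LogLevel msg, LogLevel reportable) {
  return msg <= reportable; // mirrors: if (lvl > reportable_severity_) return;
}

int main() {
  std::cout << should_log(LogLevel::kGRAPH, LogLevel::kDEBUG) << "\n"; // 0: suppressed
  std::cout << should_log(LogLevel::kGRAPH, LogLevel::kGRAPH) << "\n"; // 1: emitted
}
```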
diff --git a/core/util/macros.h b/core/util/macros.h index 2c295988d1..ccea57e513 100644 --- a/core/util/macros.h +++ b/core/util/macros.h @@ -11,18 +11,21 @@ l.log(sev, ss.str()); \ } while (0) +#define LOG_GRAPH_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kGRAPH, s) #define LOG_DEBUG_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kDEBUG, s) #define LOG_INFO_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kINFO, s) #define LOG_WARNING_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kWARNING, s) #define LOG_ERROR_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kERROR, s) -#define LOG_INTERNAL_ERROR_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kINTERNAL_ERROR, s) +#define LOG_INTERNAL_ERROR_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kINTERNAL_ERROR, s) +#define LOG_GRAPH_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kGRAPH, s) #define LOG_DEBUG_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kDEBUG, s) #define LOG_INFO_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kINFO, s) #define LOG_WARNING_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kWARNING, s) #define LOG_ERROR_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kERROR, s) #define LOG_INTERNAL_ERROR_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kINTERNAL_ERROR, s) +#define LOG_GRAPH(...) GET_MACRO(__VA_ARGS__, LOG_GRAPH_OWN, LOG_GRAPH_GLOBAL)(__VA_ARGS__) #define LOG_DEBUG(...) GET_MACRO(__VA_ARGS__, LOG_DEBUG_OWN, LOG_DEBUG_GLOBAL)(__VA_ARGS__) #define LOG_INFO(...) GET_MACRO(__VA_ARGS__, LOG_INFO_OWN, LOG_INFO_GLOBAL)(__VA_ARGS__) #define LOG_WARNING(...) GET_MACRO(__VA_ARGS__, LOG_WARNING_OWN, LOG_WARNING_GLOBAL)(__VA_ARGS__) @@ -47,7 +50,7 @@ << ", consider filing a bug: https://www.github.com/NVIDIA/TRTorch/issues\n" \ << __VA_ARGS__); \ } - + #define TRTORCH_CHECK(cond, ...) \ if (!(cond)) { \ TRTORCH_THROW_ERROR("Expected " << #cond \
diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp index 89214e5efd..2f6706c51a 100644 --- a/core/util/trt_util.cpp +++ b/core/util/trt_util.cpp @@ -59,6 +59,29 @@ nvinfer1::Dims toDims(c10::List<int64_t> l) { return dims; } +nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to) { + if (l.size() > pad_to) { + LOG_DEBUG("Requested padding of dimensions to " << pad_to << " but found " << l.size() << " dimensions, not going to pad"); + return toDims(l); + } + + if (pad_to > nvinfer1::Dims::MAX_DIMS) { + // TODO: Handle this case with an exception + LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT"); + } + + nvinfer1::Dims dims; + dims.nbDims = pad_to; + for (size_t i = 0; i < pad_to - l.size(); i++) { + dims.d[i] = 1; + } + + for (size_t i = pad_to - l.size(); i < pad_to; i++) { + dims.d[i] = l[i - (pad_to - l.size())]; + } + return dims; +} + std::vector<int64_t> toVec(nvinfer1::Dims d) { std::vector<int64_t> dims; for (int i = 0; i < d.nbDims; i++) {
diff --git a/core/util/trt_util.h b/core/util/trt_util.h index bf8ea5b224..09cf5ff418 100644 --- a/core/util/trt_util.h +++ b/core/util/trt_util.h @@ -78,6 +78,7 @@ namespace util { int64_t volume(const nvinfer1::Dims& d); nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to); +nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to); nvinfer1::Dims toDims(c10::IntArrayRef l); nvinfer1::Dims toDims(c10::List<int64_t> l); nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l);
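`toDimsPad` left-pads a shape with ones up to a target rank, which is how low-rank inputs are made to satisfy TensorRT's rank expectations. The same loop over a plain `std::vector`, as a standalone sketch:

```c++
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Same padding logic as toDimsPad, over std::vector instead of nvinfer1::Dims:
// e.g. {3, 4} padded to rank 4 becomes {1, 1, 3, 4}.
std::vector<int64_t> to_dims_pad(const std::vector<int64_t>& l, uint64_t pad_to) {
  if (l.size() >= pad_to) {
    return l; // already at least the requested rank, nothing to pad
  }
  std::vector<int64_t> dims(pad_to, 1);                  // leading dims become 1
  std::copy(l.begin(), l.end(), dims.end() - l.size());  // tail keeps the real shape
  return dims;
}

int main() {
  for (int64_t d : to_dims_pad({3, 4}, 4)) {
    std::cout << d << " "; // 1 1 3 4
  }
}
```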
diff --git a/cpp/api/include/trtorch/logging.h b/cpp/api/include/trtorch/logging.h index 8841c0d47c..d841be33c1 100644 --- a/cpp/api/include/trtorch/logging.h +++ b/cpp/api/include/trtorch/logging.h @@ -9,11 +9,12 @@ namespace logging { * Enum for setting message severity */ enum Level { - kINTERNAL_ERROR, - kERROR, - kWARNING, - kINFO, - kDEBUG, + kINTERNAL_ERROR, // Only print messages for internal errors + kERROR, // Print all internal errors and errors (default) + kWARNING, // Print warnings and errors + kINFO, // Print all info, warnings and errors + kDEBUG, // Print all debug info, info, warnings and errors + kGRAPH, // Print everything including the intermediate graphs of the lowering phase }; // Are these ones necessary for the user? @@ -35,7 +36,7 @@ TRTORCH_API void set_reportable_log_level(Level lvl); TRTORCH_API void set_is_colored_output_on(bool colored_output_on); /** - * @brief Get the current reportable log level + * @brief Get the current reportable log level */ TRTORCH_API Level get_reportable_log_level(); @@ -45,10 +46,10 @@ TRTORCH_API Level get_reportable_log_level(); TRTORCH_API bool get_is_colored_output_on(); /** - * @brief Adds a message to the global log + * @brief Adds a message to the global log * - * @param lvl: trtorch::logging::Level - Severity of the message - * @param msg: std::string - Message to be logged + * @param lvl: trtorch::logging::Level - Severity of the message + * @param msg: std::string - Message to be logged */ // Don't know if we want this? TRTORCH_API void log(Level lvl, std::string msg);
diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h index b2c99a136e..865fe021c5 100644 --- a/cpp/api/include/trtorch/trtorch.h +++ b/cpp/api/include/trtorch/trtorch.h @@ -16,9 +16,7 @@ namespace torch { namespace jit { struct Graph; -namespace script { struct Module; -} // namespace script } // namespace jit } // namespace torch @@ -240,7 +238,7 @@ TRTORCH_API void dump_build_info(); * * Will print out a list of unsupported operators if the graph is unsupported */ -TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name); +TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name); /** * @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT @@ -253,7 +251,7 @@ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name); * * Converts specifically the forward method of a TorchScript Module */ -TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info); +TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::Module& module, ExtraInfo info); /** * @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT @@ -266,7 +264,7 @@ TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info); * and will convert selected method to a serialized TensorRT engine which can be run with * TensorRT */ -TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info); +TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, ExtraInfo info); namespace ptq { /**
diff --git a/cpp/api/src/extra_info.cpp b/cpp/api/src/extra_info.cpp index f7fc5709e9..5bc12fa204 100644 --- a/cpp/api/src/extra_info.cpp +++ b/cpp/api/src/extra_info.cpp @@ -1,4 +1,4 @@ -#include "torch/csrc/jit/script/module.h" +#include "torch/csrc/jit/api/module.h" #include "core/util/prelude.h" #include "core/compiler.h"
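With the new level wired through the public API, enabling the graph dumps added by the LOG_GRAPH calls looks like this (a usage sketch based on the declarations above):

```c++
#include "trtorch/logging.h"
#include "trtorch/trtorch.h"

// Usage sketch: raise the reportable level to kGRAPH before compiling so the
// intermediate graphs printed by the lowering passes actually reach stderr.
int main() {
  trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kGRAPH);
  trtorch::logging::log(trtorch::logging::Level::kINFO, "graph logging enabled");
  // ... load a module and call trtorch::CompileGraph(...) as usual ...
}
```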
diff --git a/cpp/api/src/logging.cpp b/cpp/api/src/logging.cpp index 25fbfd0bfe..2a82c19ce7 100644 --- a/cpp/api/src/logging.cpp +++ b/cpp/api/src/logging.cpp @@ -7,7 +7,7 @@ namespace logging { std::string get_logging_prefix() { return core::util::logging::get_logger().get_logging_prefix(); } - + void set_logging_prefix(std::string prefix) { core::util::logging::get_logger().set_logging_prefix(prefix); } @@ -27,6 +27,9 @@ void set_reportable_log_level(Level lvl) { case Level::kINFO: log_lvl = core::util::logging::LogLevel::kINFO; break; + case Level::kGRAPH: + log_lvl = core::util::logging::LogLevel::kGRAPH; + break; case Level::kDEBUG: default: log_lvl = core::util::logging::LogLevel::kDEBUG; @@ -50,12 +53,14 @@ Level get_reportable_log_level() { return Level::kWARNING; case core::util::logging::LogLevel::kINFO: return Level::kINFO; + case core::util::logging::LogLevel::kGRAPH: + return Level::kGRAPH; case core::util::logging::LogLevel::kDEBUG: default: return Level::kDEBUG; } } - + bool get_is_colored_output_on() { return core::util::logging::get_logger().get_is_colored_output_on(); }
diff --git a/cpp/api/src/trtorch.cpp b/cpp/api/src/trtorch.cpp index bb8e5a7845..e6a1940db1 100644 --- a/cpp/api/src/trtorch.cpp +++ b/cpp/api/src/trtorch.cpp @@ -1,4 +1,4 @@ -#include "torch/csrc/jit/script/module.h" +#include "torch/csrc/jit/api/module.h" #include "core/util/prelude.h" #include "core/compiler.h"
diff --git a/cpp/trtorchexec/main.cpp b/cpp/trtorchexec/main.cpp index f3909fef27..e6f18233e9 100644 --- a/cpp/trtorchexec/main.cpp +++ b/cpp/trtorchexec/main.cpp @@ -55,11 +55,13 @@ int main(int argc, const char* argv[]) { dims.push_back(v); } + std::cout << "Checking operator support" << std::endl; if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) { std::cerr << "Method is not currently supported by TRTorch" << std::endl; return -1; } + std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl; auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", dims); std::ofstream out("/tmp/engine_converted_from_jit.trt"); out << engine; @@ -75,7 +77,9 @@ int main(int argc, const char* argv[]) { std::vector<at::Tensor> jit_results; jit_results.push_back(jit_results_ivalues.toTensor()); + std::cout << "Compiling graph as module" << std::endl; auto trt_mod = trtorch::CompileGraph(mod, dims); + std::cout << "Running TRT module" << std::endl; torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues); std::vector<at::Tensor> trt_results; trt_results.push_back(trt_results_ivalues.toTensor());
diff --git a/tests/accuracy/accuracy_test.h b/tests/accuracy/accuracy_test.h index 229608de6d..28cfe9cdca 100644 --- a/tests/accuracy/accuracy_test.h +++ b/tests/accuracy/accuracy_test.h @@ -20,6 +20,7 @@ class AccuracyTests std::cerr << "error loading the model\n"; return; } + mod.eval(); } void TearDown() {
diff --git a/tests/accuracy/test_fp16_accuracy.cpp b/tests/accuracy/test_fp16_accuracy.cpp index 7ebcc8b0fb..6de40a6c31 100644 --- a/tests/accuracy/test_fp16_accuracy.cpp +++ b/tests/accuracy/test_fp16_accuracy.cpp @@ -24,7 +24,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { jit_total += targets.sizes()[0]; jit_correct += torch::sum(torch::eq(predictions, targets)); } - torch::Tensor jit_accuracy = jit_correct / jit_total; + torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100; std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}}; auto extra_info = trtorch::ExtraInfo({input_shape}); @@ -45,7 +45,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { trt_correct += torch::sum(torch::eq(predictions, targets)); } - torch::Tensor trt_accuracy = trt_correct / trt_total; + torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100; ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3)); }
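The `* 100` changes are not cosmetic: assuming `almostEqual` compares an absolute difference against its threshold, accuracies expressed as fractions in [0, 1] could never differ by more than 1, so a threshold of 3 was vacuous. On a 0-100 scale the same check means "within 3 percentage points". A toy illustration:

```c++
#include <iostream>
#include <torch/torch.h>

// Fractional accuracies differ by at most 1.0, so a threshold of 3 could
// never fail; on a 0-100 scale the same threshold means "within 3 points".
int main() {
  torch::Tensor jit_accuracy = torch::tensor(0.912) * 100; // 91.2 %
  torch::Tensor trt_accuracy = torch::tensor(0.895) * 100; // 89.5 %
  float diff = (jit_accuracy - trt_accuracy).abs().item<float>();
  std::cout << std::boolalpha << (diff <= 3.0f) << std::endl; // true: close enough
}
```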
diff --git a/tests/accuracy/test_fp32_accuracy.cpp b/tests/accuracy/test_fp32_accuracy.cpp index b014340e82..d3d8bddb96 100644 --- a/tests/accuracy/test_fp32_accuracy.cpp +++ b/tests/accuracy/test_fp32_accuracy.cpp @@ -24,7 +24,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { jit_total += targets.sizes()[0]; jit_correct += torch::sum(torch::eq(predictions, targets)); } - torch::Tensor jit_accuracy = jit_correct / jit_total; + torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100; std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}}; auto extra_info = trtorch::ExtraInfo({input_shape}); @@ -45,7 +45,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { trt_correct += torch::sum(torch::eq(predictions, targets)); } - torch::Tensor trt_accuracy = trt_correct / trt_total; + torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100; ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3)); }
diff --git a/tests/accuracy/test_int8_accuracy.cpp b/tests/accuracy/test_int8_accuracy.cpp index 07d399c96d..aa4824948a 100644 --- a/tests/accuracy/test_int8_accuracy.cpp +++ b/tests/accuracy/test_int8_accuracy.cpp @@ -54,7 +54,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { jit_total += targets.sizes()[0]; jit_correct += torch::sum(torch::eq(predictions, targets)); } - torch::Tensor jit_accuracy = jit_correct / jit_total; + torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100; // Compile Graph auto trt_mod = trtorch::CompileGraph(mod, extra_info); @@ -72,7 +72,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { trt_total += targets.sizes()[0]; trt_correct += torch::sum(torch::eq(predictions, targets)).item().toFloat(); } - torch::Tensor trt_accuracy = trt_correct / trt_total; + torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100; ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3)); }
diff --git a/tests/core/converters/BUILD b/tests/core/converters/BUILD index 8cbe9c4e68..b47692a886 100644 --- a/tests/core/converters/BUILD +++ b/tests/core/converters/BUILD @@ -1,49 +1,59 @@ load("//tests/core/converters:converter_test.bzl", "converter_test") converter_test( - name = "test_softmax" + name = "test_activation" ) converter_test( - name = "test_activation" + name = "test_conv" ) converter_test( - name = "test_pooling" + name = "test_element_wise" ) converter_test( - name = "test_unary" + name = "test_linear" ) converter_test( - name = "test_linear" + name = "test_matrix_multiply" ) converter_test( - name = "test_element_wise" + name = "test_pooling" ) converter_test( - name = "test_conv" + name = "test_reduce" ) converter_test( - name = "test_reduce" + name = "test_shuffle" +) + +converter_test( + name = "test_softmax" +) + +converter_test( + name = "test_unary" ) test_suite( - name = "test_converters", - tests = [ - ":test_softmax", - ":test_activation", - ":test_pooling", - ":test_unary", - ":test_linear", - ":test_element_wise", - ":test_conv", - ":test_reduce" - ] + name = "test_converters", + tests = [ + ":test_activation", + ":test_conv", + ":test_element_wise", + ":test_linear", + ":test_matrix_multiply", + ":test_pooling", + ":test_reduce", + ":test_shuffle", + ":test_softmax", + ":test_unary", + ] )
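The converter test files that follow all share one harness shape: parse a single-op graph, run it through the JIT interpreter and through a TensorRT engine, and compare the results. A condensed sketch of that shape, using the helpers exactly as the tests do:

```c++
#include <memory>
#include <string>
#include "torch/csrc/jit/ir/irparser.h"
#include "tests/util/util.h"
#include "core/compiler.h"

// Parse a single-op graph, run it on the JIT interpreter and on a TRT engine,
// then compare; every converter test below is a variation on this.
void run_single_op_test(const std::string& graph_ir, at::Tensor in) {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph_ir, &*g);

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

  // e.g. ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
```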
diff --git a/tests/core/converters/test_activation.cpp b/tests/core/converters/test_activation.cpp index bae82a4fe3..299420878c 100644 --- a/tests/core/converters/test_activation.cpp +++ b/tests/core/converters/test_activation.cpp @@ -1,6 +1,6 @@ #include <string> #include "gtest/gtest.h" -#include "torch/csrc/jit/irparser.h" +#include "torch/csrc/jit/ir/irparser.h" #include "tests/util/util.h" #include "core/compiler.h" @@ -11,7 +11,7 @@ TEST(Converters, ATenReLUConvertsCorrectly) { return (%3))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); auto in = at::randint(-5, 5, {5}, {at::kCUDA}); auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); @@ -31,7 +31,7 @@ TEST(Converters, ATenSigmoidConvertsCorrectly) { return (%3))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); auto in = at::randint(-5, 5, {5}, {at::kCUDA}); auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); @@ -51,7 +51,7 @@ TEST(Converters, ATenTanhConvertsCorrectly) { return (%3))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); auto in = at::randint(-5, 5, {5}, {at::kCUDA}); auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); @@ -96,7 +96,7 @@ TEST(Converters, ATenHardTanhCustomRangeConvertsCorrectly) { return (%3))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); auto in = at::randint(-5, 5, {5}, {at::kCUDA}); auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
diff --git a/tests/core/converters/test_conv.cpp b/tests/core/converters/test_conv.cpp index 69f040bb43..154a5b06ca 100644 --- a/tests/core/converters/test_conv.cpp +++ b/tests/core/converters/test_conv.cpp @@ -1,6 +1,6 @@ #include <string> #include "gtest/gtest.h" -#include "torch/csrc/jit/irparser.h" +#include "torch/csrc/jit/ir/irparser.h" #include "tests/util/util.h" #include "core/compiler.h" @@ -12,7 +12,7 @@ void conv_test_helper(std::string graph_ir) { auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph_ir, &*g); + torch::jit::parseIR(graph_ir, &*g); auto in = at::randint(1, 10, {1, 3, 10, 10}, {at::kCUDA}); auto w = at::randint(1, 10, {8, 3, 5, 5}, {at::kCUDA}); @@ -54,7 +54,7 @@ TEST(Converters, ATenConvolutionConvertsCorrectly) { return (%12))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); auto in = at::randint(1, 10, {1, 3, 10, 10}, {at::kCUDA}); auto w = at::randint(1, 10, {8, 3, 5, 5}, {at::kCUDA}); @@ -96,7 +96,7 @@ TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) { return (%12))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); auto in = at::randint(1, 2, {1, 1, 3, 3}, {at::kCUDA}); auto w = at::randint(1, 2, {4, 1, 2, 2}, {at::kCUDA}); @@ -136,7 +136,7 @@ TEST(Converters, ATenConvolutionWithStrideConvertsCorrectly) { return (%13))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); auto in = at::randint(1, 10, {1, 3, 9, 9}, {at::kCUDA}); auto w = at::randint(1, 10, {4, 3, 3, 3}, {at::kCUDA}); @@ -179,7 +179,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) { return (%13))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA}); auto w = at::randint(1, 10, {4, 3, 2, 2}, {at::kCUDA});
diff --git a/tests/core/converters/test_element_wise.cpp b/tests/core/converters/test_element_wise.cpp index db8ea15870..85cb90f6b7 100644 --- a/tests/core/converters/test_element_wise.cpp +++ b/tests/core/converters/test_element_wise.cpp @@ -1,12 +1,12 @@ #include <string> #include "gtest/gtest.h" -#include "torch/csrc/jit/irparser.h" +#include "torch/csrc/jit/ir/irparser.h" #include "tests/util/util.h" #include "core/compiler.h" void pointwise_test_helper(std::string graph_ir) { auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph_ir, &*g); + torch::jit::parseIR(graph_ir, &*g); auto in0 = at::randint(1, 5, {5}, {at::kCUDA}); auto in1 =
at::randint(1, 5, {5}, {at::kCUDA});
diff --git a/tests/core/converters/test_linear.cpp b/tests/core/converters/test_linear.cpp index f6d543ab04..c668c03e63 100644 --- a/tests/core/converters/test_linear.cpp +++ b/tests/core/converters/test_linear.cpp @@ -1,6 +1,6 @@ #include <string> #include "gtest/gtest.h" -#include "torch/csrc/jit/irparser.h" +#include "torch/csrc/jit/ir/irparser.h" #include "tests/util/util.h" #include "core/compiler.h" @@ -13,7 +13,7 @@ TEST(Converters, ATenLinearNoBiasConvertsCorrectly) { return (%3))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); //Input Tensor needs to be 4D for TensorRT linear auto in = at::randint(1, 10, {1, 2}, {at::kCUDA}); @@ -40,7 +40,7 @@ TEST(Converters, ATenLinearBiasConvertsCorrectly) { return (%3))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); // WARN: TRT expects a 4D input eventually, but pytorch does not require a channel dim auto in = at::randint(1, 5, {1, 3}, {at::kCUDA});
diff --git a/tests/core/converters/test_matrix_multiply.cpp b/tests/core/converters/test_matrix_multiply.cpp new file mode 100644 index 0000000000..8ec8035d8d --- /dev/null +++ b/tests/core/converters/test_matrix_multiply.cpp @@ -0,0 +1,27 @@ +#include <string> +#include "gtest/gtest.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "tests/util/util.h" +#include "core/compiler.h" + +// TODO: IR Parser doesn't work well with negative numbers +TEST(Converters, ATenMMConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, %1 : Tensor): + %2 : Tensor = aten::matmul(%0, %1) + return (%2))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, &*g); + + auto in1 = at::randint(0, 5, {2, 3}, {at::kCUDA}); + auto in2 = at::randint(0, 5, {3, 3}, {at::kCUDA}); + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2}); + + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2}); + auto trt = trt_results[0].reshape_as(jit_results[0]); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +}
diff --git a/tests/core/converters/test_pooling.cpp b/tests/core/converters/test_pooling.cpp index f278d8b65f..6e7b612612 100644 --- a/tests/core/converters/test_pooling.cpp +++ b/tests/core/converters/test_pooling.cpp @@ -1,6 +1,6 @@ #include <string> #include "gtest/gtest.h" -#include
"torch/csrc/jit/irparser.h" +#include "torch/csrc/jit/ir/irparser.h" #include "tests/util/util.h" #include "core/compiler.h" @@ -56,7 +56,7 @@ std::string gen_keepdim_graph(const std::string& op) { void test_body(const std::string& graph, at::Tensor& in) { auto g = std::make_shared(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); diff --git a/tests/core/converters/test_shuffle.cpp b/tests/core/converters/test_shuffle.cpp new file mode 100644 index 0000000000..7539dfcf39 --- /dev/null +++ b/tests/core/converters/test_shuffle.cpp @@ -0,0 +1,77 @@ +#include +#include "gtest/gtest.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "tests/util/util.h" +#include "core/compiler.h" + +// TODO: IR Parser doesnt work well with neg numbers +TEST(Converters, ATenFlattenConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=0]() + %2 : int = prim::Constant[value=1]() + %3 : Tensor = aten::flatten(%0, %1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto in = at::randint(0, 5, {2, 3}, {at::kCUDA}); + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + auto trt = trt_results[0].reshape_as(jit_results[0]); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +// TODO: IR Parser doesnt work well with neg numbers +TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=1]() + %2 : int = prim::Constant[value=2]() + %3 : Tensor = aten::flatten(%0, %1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto in = at::randint(0, 5, {2, 3, 3}, {at::kCUDA}); + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + auto trt = trt_results[0].reshape_as(jit_results[0]); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenReshapeConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=3]() + %2 : int = prim::Constant[value=2]() + %3 : int[] = prim::ListConstruct(%1, %2) + %4 : Tensor = aten::reshape(%0, %3) + return (%4))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto in = at::randint(0, 5, {2, 3}, {at::kCUDA}); + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + auto trt = trt_results[0].reshape_as(jit_results[0]); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} \ No newline at end of file diff --git 
a/tests/core/converters/test_softmax.cpp b/tests/core/converters/test_softmax.cpp index 5f8186021c..87d5030cbf 100644 --- a/tests/core/converters/test_softmax.cpp +++ b/tests/core/converters/test_softmax.cpp @@ -1,6 +1,6 @@ #include <string> #include "gtest/gtest.h" -#include "torch/csrc/jit/irparser.h" +#include "torch/csrc/jit/ir/irparser.h" #include "tests/util/util.h" #include "core/compiler.h" @@ -13,7 +13,7 @@ TEST(Converters, ATenSoftmax1DConvertsCorrectly) { return (%3))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); auto in = at::randint(0, 5, {5}, {at::kCUDA}); auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); @@ -36,7 +36,7 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlySub3DIndex) { return (%3))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); auto in = at::randint(0, 5, {1,2,2,2,2}, {at::kCUDA}); @@ -60,7 +60,7 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyAbove3DIndex) { return (%3))IR"; auto g = std::make_shared<torch::jit::Graph>(); - torch::jit::script::parseIR(graph, &*g); + torch::jit::parseIR(graph, &*g); auto in = at::randint(0, 5, {1,2,2,2,2}, {at::kCUDA});
diff --git a/tests/core/converters/test_unary.cpp b/tests/core/converters/test_unary.cpp index a82a4ade9f..78c4a68da3 100644 --- a/tests/core/converters/test_unary.cpp +++ b/tests/core/converters/test_unary.cpp @@ -1,6 +1,6 @@ #include <string> #include "gtest/gtest.h" -#include "torch/csrc/jit/irparser.h" +#include "torch/csrc/jit/ir/irparser.h" #include "tests/util/util.h" #include "core/compiler.h" @@ -19,7 +19,7 @@ std::string gen_test_graph(const std::string &unary) { const auto graph = gen_test_graph(#unary); \ \ auto g = std::make_shared<torch::jit::Graph>(); \ - torch::jit::script::parseIR(graph, &*g); \ + torch::jit::parseIR(graph, &*g); \ \ auto in = at::empty({10}, {at::kCUDA}).uniform_(0, 0.5); \ auto params = \
diff --git a/tests/modules/hub.py b/tests/modules/hub.py index 8be1047837..b3dc394478 100644 --- a/tests/modules/hub.py +++ b/tests/modules/hub.py @@ -2,13 +2,12 @@ import torchvision.models as models models = { - #"resnet18": models.resnet18(pretrained=True), "alexnet": models.alexnet(pretrained=True), "vgg16": models.vgg16(pretrained=True), "squeezenet": models.squeezenet1_0(pretrained=True), "densenet": models.densenet161(pretrained=True), - "inception": models.inception_v3(pretrained=True), - "googlenet": models.googlenet(pretrained=True), + "inception_v3": models.inception_v3(pretrained=True), + #"googlenet": models.googlenet(pretrained=True), "shufflenet": models.shufflenet_v2_x1_0(pretrained=True), "mobilenet_v2": models.mobilenet_v2(pretrained=True), "resnext50_32x4d": models.resnext50_32x4d(pretrained=True), @@ -18,7 +17,10 @@ "resnet50": torch.hub.load('pytorch/vision:v0.5.0', 'resnet50', pretrained=True)} for n, m in models.items(): + print("Downloading {}".format(n)) m = m.eval().cuda() x = torch.ones((1, 3, 224, 224)).cuda() - jit_model = torch.jit.trace(m, x) - torch.jit.save(jit_model, n + '.jit.pt') + trace_model = torch.jit.trace(m, x) + torch.jit.save(trace_model, n + '_traced.jit.pt') + script_model = torch.jit.script(m) + torch.jit.save(script_model, n + '_scripted.jit.pt') \ No newline at end of file
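hub.py now saves both a traced and a scripted artifact per model, and the module tests below load the `*_traced.jit.pt` files. Loading one from C++ follows the usual libtorch pattern (a sketch mirroring the test fixtures):

```c++
#include <torch/script.h>
#include <vector>

// Load one of the traced artifacts produced by hub.py and run it on the GPU,
// the same way the module tests below do.
int main() {
  torch::jit::Module mod = torch::jit::load("tests/modules/resnet18_traced.jit.pt");
  mod.to(at::kCUDA);
  mod.eval();

  std::vector<torch::jit::IValue> inputs{torch::ones({1, 3, 224, 224}).to(at::kCUDA)};
  at::Tensor out = mod.forward(inputs).toTensor();
}
```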
"c10/cuda/CUDACachingAllocator.h" +#include "cuda_runtime_api.h" + using PathAndInSize = std::pair>>; @@ -25,6 +27,7 @@ class ModuleTests } void TearDown() { + cudaDeviceSynchronize(); c10::cuda::CUDACachingAllocator::emptyCache(); } protected: diff --git a/tests/modules/test_compiled_modules.cpp b/tests/modules/test_compiled_modules.cpp index 199e1d81b5..9a5c9daf1d 100644 --- a/tests/modules/test_compiled_modules.cpp +++ b/tests/modules/test_compiled_modules.cpp @@ -28,9 +28,9 @@ TEST_P(ModuleTests, CompiledModuleIsClose) { INSTANTIATE_TEST_SUITE_P(CompiledModuleForwardIsCloseSuite, ModuleTests, testing::Values( - PathAndInSize({"tests/modules/resnet18.jit.pt", + PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1,3,224,224}}}), - PathAndInSize({"tests/modules/resnet50.jit.pt", + PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1,3,224,224}}}), - PathAndInSize({"tests/modules/mobilenet_v2.jit.pt", + PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1,3,224,224}}}))); diff --git a/tests/modules/test_modules_as_engines.cpp b/tests/modules/test_modules_as_engines.cpp index 759ed136c9..d190251bb3 100644 --- a/tests/modules/test_modules_as_engines.cpp +++ b/tests/modules/test_modules_as_engines.cpp @@ -19,9 +19,9 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) { INSTANTIATE_TEST_SUITE_P(ModuleAsEngineForwardIsCloseSuite, ModuleTests, testing::Values( - PathAndInSize({"tests/modules/resnet18.jit.pt", + PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1,3,224,224}}}), - PathAndInSize({"tests/modules/resnet50.jit.pt", + PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1,3,224,224}}}), - PathAndInSize({"tests/modules/mobilenet_v2.jit.pt", + PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1,3,224,224}}}))); \ No newline at end of file diff --git a/tests/modules/test_multiple_registered_engines.cpp b/tests/modules/test_multiple_registered_engines.cpp index 7ce3dbf61f..c03e68c0b4 100644 --- a/tests/modules/test_multiple_registered_engines.cpp +++ b/tests/modules/test_multiple_registered_engines.cpp @@ -8,8 +8,8 @@ TEST(ModuleTests, CanRunMultipleEngines) { torch::jit::script::Module mod1; torch::jit::script::Module mod2; try { - mod1 = torch::jit::load("tests/modules/resnet50.jit.pt"); - mod2 = torch::jit::load("tests/modules/resnet18.jit.pt"); + mod1 = torch::jit::load("tests/modules/resnet50_traced.jit.pt"); + mod2 = torch::jit::load("tests/modules/resnet18_traced.jit.pt"); } catch (const c10::Error& e) { std::cerr << "error loading the model\n"; diff --git a/tests/util/run_forward.cpp b/tests/util/run_forward.cpp index f8f22b900d..4806678a1f 100644 --- a/tests/util/run_forward.cpp +++ b/tests/util/run_forward.cpp @@ -6,13 +6,13 @@ namespace trtorch { namespace tests { namespace util { -torch::jit::IValue RunModuleForward(torch::jit::script::Module& mod, std::vector inputs) { +torch::jit::IValue RunModuleForward(torch::jit::Module& mod, std::vector inputs) { mod.to(at::kCUDA); return mod.forward(inputs); } -std::vector RunModuleForwardAsEngine(torch::jit::script::Module& mod, std::vector inputs) { +std::vector RunModuleForwardAsEngine(torch::jit::Module& mod, std::vector inputs) { auto forward_graph = mod.get_method("forward"); std::vector> input_ranges; for (auto in : inputs) { diff --git a/tests/util/run_graph.cpp b/tests/util/run_graph.cpp index 70f839c936..e749eb1e87 100644 --- a/tests/util/run_graph.cpp +++ b/tests/util/run_graph.cpp @@ -1,4 +1,4 @@ -#include "torch/csrc/jit/graph_executor.h" +#include 
"torch/csrc/jit/runtime/graph_executor.h" #include "tests/util/util.h" namespace trtorch { @@ -28,15 +28,15 @@ std::vector RunGraph(std::shared_ptr& g, } } - torch::jit::GraphExecutor executor(g); + torch::jit::GraphExecutor executor(g, ""); auto stack = CreateStack(std::move(inputs_)); - + executor.run(stack); std::vector outputs; for (size_t i = 0; i < g->outputs().size(); i++) { outputs.push_back(stack[i].toTensor()); } - + return outputs; } diff --git a/tests/util/run_graph_engine.cpp b/tests/util/run_graph_engine.cpp index a6a6070596..f9a3e3f388 100644 --- a/tests/util/run_graph_engine.cpp +++ b/tests/util/run_graph_engine.cpp @@ -1,8 +1,8 @@ #include "core/util/prelude.h" #include "NvInfer.h" #include "c10/cuda/CUDAStream.h" -#include "torch/csrc/jit/ir.h" -#include "torch/csrc/jit/irparser.h" +#include "torch/csrc/jit/ir/ir.h" +#include "torch/csrc/jit/ir/irparser.h" #include "core/conversion/conversion.h" #include "cuda_runtime_api.h" @@ -28,34 +28,35 @@ std::vector RunEngine(std::string& eng, std::vector inpu std::vector contig_inputs{}; contig_inputs.reserve(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { + TRTORCH_CHECK(inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device()); + auto expected_type = core::util::toATenDType(ctx->getEngine().getBindingDataType(i)); + TRTORCH_CHECK(inputs[i].dtype() == expected_type, "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype()); auto dims = core::util::toDimsPad(inputs[i].sizes(), 1); auto shape = core::util::toVec(dims); - contig_inputs.push_back(inputs[i].to(at::kCUDA).view(shape).contiguous()); + contig_inputs.push_back(inputs[i].view(shape).contiguous()); LOG_DEBUG("In shape:" << shape); ctx->setBindingDimensions(i, dims); gpu_handles.push_back(contig_inputs.back().data_ptr()); } - if (!ctx->allInputDimensionsSpecified()) { - LOG_ERROR("Not enough inputs provided (tests.runEngine)"); - return {}; - } + TRTORCH_CHECK(ctx->allInputDimensionsSpecified(), "Not enough inputs provided (execution.RunCudaEngine)"); std::vector outputs; - for (int o = inputs.size(); o < engine->getNbBindings(); o++) { + for (int64_t o = inputs.size(); o < engine->getNbBindings(); o++) { auto out_shape = ctx->getBindingDimensions(o); LOG_DEBUG("Output: " << engine->getBindingName(o) << " out shape: " << out_shape); auto dims = core::util::toVec(out_shape); - outputs.push_back(at::empty(dims, {at::kCUDA}).contiguous()); + auto type = core::util::toATenDType(ctx->getEngine().getBindingDataType(o)); + outputs.push_back(at::empty(dims, {at::kCUDA}).to(type).contiguous()); gpu_handles.push_back(outputs[outputs.size() - 1].data_ptr()); } - - c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool(true, 0); + // Is this the right stream? 
diff --git a/tests/util/run_graph_engine.cpp b/tests/util/run_graph_engine.cpp index a6a6070596..f9a3e3f388 100644 --- a/tests/util/run_graph_engine.cpp +++ b/tests/util/run_graph_engine.cpp @@ -1,8 +1,8 @@ #include "core/util/prelude.h" #include "NvInfer.h" #include "c10/cuda/CUDAStream.h" -#include "torch/csrc/jit/ir.h" -#include "torch/csrc/jit/irparser.h" +#include "torch/csrc/jit/ir/ir.h" +#include "torch/csrc/jit/ir/irparser.h" #include "core/conversion/conversion.h" #include "cuda_runtime_api.h" @@ -28,34 +28,35 @@ std::vector<at::Tensor> RunEngine(std::string& eng, std::vector<at::Tensor> inputs) { std::vector<at::Tensor> contig_inputs{}; contig_inputs.reserve(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { + TRTORCH_CHECK(inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device()); + auto expected_type = core::util::toATenDType(ctx->getEngine().getBindingDataType(i)); + TRTORCH_CHECK(inputs[i].dtype() == expected_type, "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype()); auto dims = core::util::toDimsPad(inputs[i].sizes(), 1); auto shape = core::util::toVec(dims); - contig_inputs.push_back(inputs[i].to(at::kCUDA).view(shape).contiguous()); + contig_inputs.push_back(inputs[i].view(shape).contiguous()); LOG_DEBUG("In shape:" << shape); ctx->setBindingDimensions(i, dims); gpu_handles.push_back(contig_inputs.back().data_ptr()); } - if (!ctx->allInputDimensionsSpecified()) { - LOG_ERROR("Not enough inputs provided (tests.runEngine)"); - return {}; - } + TRTORCH_CHECK(ctx->allInputDimensionsSpecified(), "Not enough inputs provided (execution.RunCudaEngine)"); std::vector<at::Tensor> outputs; - for (int o = inputs.size(); o < engine->getNbBindings(); o++) { + for (int64_t o = inputs.size(); o < engine->getNbBindings(); o++) { auto out_shape = ctx->getBindingDimensions(o); LOG_DEBUG("Output: " << engine->getBindingName(o) << " out shape: " << out_shape); auto dims = core::util::toVec(out_shape); - outputs.push_back(at::empty(dims, {at::kCUDA}).contiguous()); + auto type = core::util::toATenDType(ctx->getEngine().getBindingDataType(o)); + outputs.push_back(at::empty(dims, {at::kCUDA}).to(type).contiguous()); gpu_handles.push_back(outputs[outputs.size() - 1].data_ptr()); } - - c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool(true, 0); + // Is this the right stream? + c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index()); ctx->enqueueV2(gpu_handles.data(), stream, nullptr); - stream.synchronize(); + stream.synchronize(); return outputs; }
diff --git a/tests/util/util.h b/tests/util/util.h index fa9008bd39..6ae72b3370 100644 --- a/tests/util/util.h +++ b/tests/util/util.h @@ -29,11 +29,11 @@ std::vector<at::Tensor> RunGraphEngine(std::shared_ptr<torch::jit::Graph>& g, std::vector<at::Tensor> inputs); // Run the forward method of a module and return results -torch::jit::IValue RunModuleForward(torch::jit::script::Module& mod, +torch::jit::IValue RunModuleForward(torch::jit::Module& mod, std::vector<torch::jit::IValue> inputs); // Convert the forward module to a TRT engine and return results -std::vector<at::Tensor> RunModuleForwardAsEngine(torch::jit::script::Module& mod, +std::vector<at::Tensor> RunModuleForwardAsEngine(torch::jit::Module& mod, std::vector<at::Tensor> inputs);
diff --git a/third_party/BUILD b/third_party/BUILD new file mode 100644 index 0000000000..e69de29bb2
diff --git a/third_party/cuda/BUILD b/third_party/cuda/BUILD index 651964393c..758a81d0db 100644 --- a/third_party/cuda/BUILD +++ b/third_party/cuda/BUILD @@ -3,7 +3,7 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "cudart", srcs = glob([ - "lib/**/libcudart.so.10.1", + "lib/**/libcudart.so.10.2", ]), hdrs = glob([ "include/**/*.h",
diff --git a/third_party/cudnn/BUILD b/third_party/cudnn/archive/BUILD similarity index 82% rename from third_party/cudnn/BUILD rename to third_party/cudnn/archive/BUILD index e6423b6f20..027a1a6bcc 100644 --- a/third_party/cudnn/BUILD +++ b/third_party/cudnn/archive/BUILD @@ -6,17 +6,17 @@ cc_library( includes = ["include/"], visibility = ["//visibility:private"], ) - + cc_import( name = "cudnn_lib", - shared_library = "lib/x86_64-linux-gnu/libcudnn.so", + shared_library = "lib64/libcudnn.so.7.6.5", visibility = ["//visibility:private"], ) cc_library( name = "cudnn", deps = [ - "cudnn_headers", + "cudnn_headers", "cudnn_lib" ], visibility = ["//visibility:public"],
diff --git a/third_party/cudnn/local/BUILD b/third_party/cudnn/local/BUILD new file mode 100644 index 0000000000..3af56fa61f --- /dev/null +++ b/third_party/cudnn/local/BUILD @@ -0,0 +1,25 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "cudnn_headers", + hdrs = ["include/cudnn.h"] + glob(["include/cudnn+.h"]), + includes = ["include/"], + visibility = ["//visibility:private"], +) + +cc_import( + name = "cudnn_lib", + shared_library = "lib/x86_64-linux-gnu/libcudnn.so.7.6.5", + visibility = ["//visibility:private"], +) + +cc_library( + name = "cudnn", + deps = [ + "cudnn_headers", + "cudnn_lib" + ], + visibility = ["//visibility:public"], +) + +
diff --git a/third_party/libtorch/BUILD b/third_party/libtorch/BUILD index 0c5bff0d2f..9da5da47c0 100644 --- a/third_party/libtorch/BUILD +++ b/third_party/libtorch/BUILD @@ -18,11 +18,17 @@ cc_library( ) + glob([ 'include/torch/csrc/api/include/**/*.h' ]), - srcs = ['lib/libtorch.so'], + srcs = [ + 'lib/libtorch.so', + 'lib/libtorch_cuda.so', + 'lib/libtorch_cpu.so', + 'lib/libtorch_global_deps.so', + ], deps = [ ":ATen", ":torch_deps", ":c10_cuda", + ":caffe2" ], includes = [ "include", @@ -34,7 +40,7 @@ cc_library( name = 'torch_deps', srcs = [ 'lib/libnvToolsExt-3965bdd0.so.1', - "lib/libcudart-1b201d85.so.10.1", + "lib/libcudart-80664282.so.10.2", ] ) @@ -66,8 +72,7 @@ cc_library( cc_library( name = "c10_deps", srcs = [ - "lib/libgomp-753e6e92.so.1", - "lib/libcudart-1b201d85.so.10.1", + "lib/libcudart-80664282.so.10.2", ] ) @@ -84,7 +89,12 @@
cc_library( hdrs = glob([ 'include/caffe2/**/*.h' ]), - srcs = ['lib/libcaffe2_nvrtc.so'], + srcs = [ + 'lib/libcaffe2_nvrtc.so', + 'lib/libcaffe2_detectron_ops_gpu.so', + 'lib/libcaffe2_observers.so', + 'lib/libcaffe2_module_test_dynamic.so' + ], strip_include_prefix = "include", deps = [ ":caffe2_deps" @@ -95,6 +105,5 @@ cc_library( name = 'caffe2_deps', srcs = [ 'lib/libnvToolsExt-3965bdd0.so.1', - 'lib/libgomp-753e6e92.so.1' ] ) diff --git a/third_party/tensorrt/archive/BUILD b/third_party/tensorrt/archive/BUILD new file mode 100644 index 0000000000..2dc2959f79 --- /dev/null +++ b/third_party/tensorrt/archive/BUILD @@ -0,0 +1,167 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "nvinfer_headers", + hdrs = glob([ + "include/NvInfer*.h", + ], exclude=[ + "include/NvInferPlugin.h", + "include/NvInferPluginUtils.h" + ]), + includes = ["include/"], + visibility = ["//visibility:private"], +) + +cc_import( + name = "nvinfer_lib", + shared_library = "lib/libnvinfer.so.7.0.0", + visibility = ["//visibility:private"], +) + +cc_library( + name = "nvinfer", + deps = [ + "nvinfer_headers", + "nvinfer_lib", + "@cuda//:cudart", + "@cuda//:cublas", + "@cudnn//:cudnn" + ], + visibility = ["//visibility:public"], +) + +#################################################################################### + +cc_import( + name = "nvparsers_lib", + shared_library = "lib/libnvparsers.so.7.0.0", + visibility = ["//visibility:private"], +) + + +cc_library( + name = "nvparsers_headers", + hdrs = [ + "include/NvCaffeParser.h", + "include/NvOnnxParser.h", + "include/NvOnnxParserRuntime.h", + "include/NvOnnxConfig.h", + "include/NvUffParser.h" + ], + includes = ["include/"], + visibility = ["//visibility:private"], +) + +cc_library( + name = "nvparsers", + deps = [ + "nvparsers_headers", + "nvparsers_lib", + "nvinfer" + ], + visibility = ["//visibility:public"], +) + +#################################################################################### + +cc_import( + name = "nvonnxparser_lib", + shared_library = "lib/libnvonnxparser.so.7.0.0", + visibility = ["//visibility:private"], +) + +cc_library( + name = "nvonnxparser_headers", + hdrs = [ + "include/NvOnnxParser.h", + "include/NvOnnxParserRuntime.h", + "include/NvOnnxConfig.h" + ], + includes = ["include/"], + visibility = ["//visibility:private"], +) + +cc_library( + name = "nvonnxparser", + deps = [ + "nvonnxparser_headers", + "nvonnxparser_lib", + "nvinfer" + ], + visibility = ["//visibility:public"], +) + +#################################################################################### + +cc_import( + name = "nvonnxparser_runtime_lib", + shared_library = "lib/libnvonnxparser_runtime.so.7.0.0", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nvonnxparser_runtime_header", + hdrs = ["include/NvOnnxParserRuntime.h"], + includes = ["include/"], + visibility = ["//visibility:private"], +) + +cc_library( + name = "nvonnxparser_runtime", + deps = [ + "nvparsers_headers", + "nvparsers_lib", + "nvinfer" + ], + visibility = ["//visibility:public"], +) + +#################################################################################### + +cc_import( + name = "nvcaffeparser_lib", + shared_library = "lib/libnvcaffe_parsers.so.7.0.0", + visibility = ["//visibility:private"], +) + +cc_library( + name = "nvcaffeparser_headers", + hdrs = ["include/NvCaffeParser.h"], + includes = ["include/"], + visibility = ["//visibility:private"], +) + +cc_library( + name = "nvcaffeparser", + deps = [ + "nvcaffeparser_headers", 
+ "nvcaffeparser_lib", + "nvinfer" + ], + visibility = ["//visibility:public"], +) + +#################################################################################### + +cc_import( + name = "nvinferplugin_lib", + shared_library = "lib/libnvinfer_plugin.so.7.0.0", + visibility = ["//visibility:private"], +) + +cc_library( + name = "nvinferplugin_headers", + hdrs = glob(["include/NvInferPlugin*.h"]), + includes = ["include/"], + visibility = ["//visibility:private"], +) + +cc_library( + name = "nvinferplugin", + deps = [ + "nvinferplugin_headers", + "nvinferplugin_lib", + "nvinfer" + ], + visibility = ["//visibility:public"], +) diff --git a/third_party/tensorrt/BUILD b/third_party/tensorrt/local/BUILD similarity index 100% rename from third_party/tensorrt/BUILD rename to third_party/tensorrt/local/BUILD