fix: Bugfix in convNd_to_convolution lowering pass #1693

Merged (2 commits) on Mar 7, 2023
84 changes: 49 additions & 35 deletions core/lowering/passes/convNd_to_convolution.cpp
@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include "torch/csrc/jit/ir/irparser.h"
 
 #include "core/util/prelude.h"
 
@@ -7,78 +8,91 @@ namespace core {
 namespace lowering {
 namespace passes {
 
-void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv1d_pattern = R"IR(
-        graph(%x, %w, %b, %s, %p, %d, %g):
-            %4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
-            return (%4))IR";
+void replaceConv(
+    torch::jit::Block* block,
+    const std::string& node_kind,
+    const std::string& unwrapped_conv,
+    const size_t num_input_args) {
+  // Iterate through nodes in block, searching for aten::conv*
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
+    auto n = *it;
+
+    // Recursively explore nested blocks, such as those arising from prim::If
+    for (auto nested_block : n->blocks()) {
+      replaceConv(nested_block, node_kind, unwrapped_conv, num_input_args);
+    }
+
+    // If node matches desired kind and number of input arguments, replace it
+    if ((n->kind().toQualString() == node_kind) && (n->inputs().size() == num_input_args)) {
+      // Establish insert point within block
+      torch::jit::WithInsertPoint guard(*it);
+
+      // Initialize new fused subgraph from IR code provided
+      auto fused_g = std::make_shared<torch::jit::Graph>();
+      torch::jit::parseIR(unwrapped_conv, fused_g.get());
+
+      // Insert subgraph in place of aten::conv*, replacing inputs and outputs accordingly
+      torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *fused_g, it->inputs()).at(0);
+      new_output->setType(it->output()->type());
+      it->output()->replaceAllUsesWith(new_output);
+      it.destroyCurrent();
+    }
+  }
+}
 
-  std::string convolution_pattern = R"IR(
+void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  const std::string conv1d_node_kind = "aten::conv1d";
+  const std::string convolution_pattern = R"IR(
         graph(%x, %w, %b, %s, %p, %d, %g):
             %1 : bool = prim::Constant[value=0]()
             %2 : int[] = prim::Constant[value=[0]]()
            %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
            return (%4))IR";
 
-  torch::jit::SubgraphRewriter map_conv1d_to_convolution;
-  map_conv1d_to_convolution.RegisterRewritePattern(conv1d_pattern, convolution_pattern);
-  map_conv1d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv1d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv1d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv1d -> _convolution: " << *graph);
 }
 
 void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv_transpose1d_pattern = R"IR(
-        graph(%x, %w, %b, %s, %p, %o, %g, %d):
-            %4 : Tensor = aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d)
-            return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv_transpose1d_node_kind = "aten::conv_transpose1d";
+  const std::string convolution_pattern = R"IR(
         graph(%x, %w, %b, %s, %p, %o, %g, %d):
            %1 : bool = prim::Constant[value=1]()
            %2 : bool = prim::Constant[value=1]()
            %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
            return (%4))IR";
 
-  torch::jit::SubgraphRewriter map_conv_transpose1d_to_convolution;
-  map_conv_transpose1d_to_convolution.RegisterRewritePattern(conv_transpose1d_pattern, convolution_pattern);
-  map_conv_transpose1d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
+  replaceConv(graph->block(), conv_transpose1d_node_kind, convolution_pattern, 8);
   LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph);
 }
 
 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv2d_pattern = R"IR(
-        graph(%x, %w, %b, %s, %p, %d, %g):
-            %4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
-            return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv2d_node_kind = "aten::conv2d";
+  const std::string convolution_pattern = R"IR(
         graph(%x, %w, %b, %s, %p, %d, %g):
            %1 : bool = prim::Constant[value=0]()
            %2 : int[] = prim::Constant[value=[0, 0]]()
            %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
            return (%4))IR";
 
-  // replace matmul + add pattern to linear
-  torch::jit::SubgraphRewriter map_conv2d_to_convolution;
-  map_conv2d_to_convolution.RegisterRewritePattern(conv2d_pattern, convolution_pattern);
-  map_conv2d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv2d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv2d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
 }
 
 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv3d_pattern = R"IR(
-        graph(%x, %w, %b, %s, %p, %d, %g):
-            %4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
-            return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv3d_node_kind = "aten::conv3d";
+  const std::string convolution_pattern = R"IR(
         graph(%x, %w, %b, %s, %p, %d, %g):
            %1 : bool = prim::Constant[value=0]()
            %2 : int[] = prim::Constant[value=[0, 0, 0]]()
            %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
            return (%4))IR";
 
-  // replace matmul + add pattern to linear
-  torch::jit::SubgraphRewriter map_conv3d_to_convolution;
-  map_conv3d_to_convolution.RegisterRewritePattern(conv3d_pattern, convolution_pattern);
-  map_conv3d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv3d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv3d_node_kind, convolution_pattern, 7);
  LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
 }

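Note on the change above: the previous SubgraphRewriter-based rewrite did not correctly handle aten::conv* calls nested inside prim::If branches (hence the new conditional test below), so the pass now walks nested blocks explicitly. The traversal idiom can be sketched in isolation; the helper below only counts matches rather than rewriting them, and the name countConv1dNodes is illustrative, not part of this PR.

#include <torch/csrc/jit/ir/ir.h>

#include <cstddef>
#include <string>

// Count aten::conv1d nodes anywhere in a graph, including those nested in the
// blocks of prim::If / prim::Loop nodes (illustrative sketch only).
size_t countConv1dNodes(torch::jit::Block* block) {
  size_t count = 0;
  for (auto n : block->nodes()) {
    // Descend into nested blocks first, mirroring replaceConv above
    for (auto nested : n->blocks()) {
      count += countConv1dNodes(nested);
    }
    if (std::string(n->kind().toQualString()) == "aten::conv1d") {
      count++;
    }
  }
  return count;
}

replaceConv applies the same walk, but splices the parsed aten::_convolution subgraph over each match instead of counting it.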
83 changes: 83 additions & 0 deletions tests/core/lowering/test_conv1d_pass.cpp
@@ -5,6 +5,8 @@
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
 #include "torch/csrc/jit/ir/subgraph_matcher.h"
+#include "torch/csrc/jit/passes/canonicalize.h"
+#include "torch/csrc/jit/passes/constant_pooling.h"
 
 TEST(LoweringPasses, Conv1dCorrectly) {
   const auto source_graph = R"IR(
@@ -117,3 +119,84 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) {
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
 }
+
+TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+          %2 : Float(3)):
+      %4 : int = prim::Constant[value=0]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=1]()
+      %stride : int[] = prim::ListConstruct(%6)
+      %padding : int[] = prim::ListConstruct(%4)
+      %dilation : int[] = prim::ListConstruct(%5)
+
+      # Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
+      %true : bool = prim::Constant[value=1]()
+      %invalid_weight : Tensor = aten::transpose(%0, %4, %5)
+      %12 : Tensor = prim::If(%true)
+        block0():
+          %res: Tensor = aten::conv1d(%0, %1, %2, %stride, %padding, %dilation, %6)
+          -> (%res)
+        block1():
+          %res: Tensor = aten::conv1d(%invalid_weight, %1, %2, %stride, %padding, %dilation, %6)
+          -> (%res)
+      return (%12))IR";
+
+  std::string target_graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+          %2 : Float(3)):
+      %4 : int = prim::Constant[value=0]()
+      %5 : int = prim::Constant[value=1]()
+      %true : bool = prim::Constant[value=1]()
+      %3 : bool = prim::Constant[value=0]()
+      %output_padding : int[] = prim::Constant[value=[0]]()
+      %6 : int = prim::Constant[value=1]()
+      %stride : int[] = prim::ListConstruct(%6)
+      %padding : int[] = prim::ListConstruct(%4)
+      %dilation : int[] = prim::ListConstruct(%5)
+
+      # Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
+      %invalid_weight : Tensor = aten::transpose(%0, %4, %5)
+      %12 : Tensor = prim::If(%true)
+        block0():
+          %res: Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
+          -> (%res)
+        block1():
+          %res: Tensor = aten::_convolution(%invalid_weight, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
+          -> (%res)
+      return (%12))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+
+  auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {4}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
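Beyond the gtest above, the lowering pass can also be exercised in a small standalone program when iterating locally. A minimal sketch, assuming the pass declarations live in core/lowering/passes/passes.h (the header path is an assumption based on the repository layout) and the binary links against libtorch and the Torch-TensorRT core library; no GPU is needed, since the graph is only lowered and printed:

#include <iostream>
#include <memory>
#include <string>

#include "core/lowering/passes/passes.h"  // assumed header path for the pass declarations
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

int main() {
  // conv1d nested inside a prim::If branch, the case this PR fixes
  const std::string src = R"IR(
    graph(%x : Tensor, %w : Tensor, %b : Tensor):
      %one : int = prim::Constant[value=1]()
      %zero : int = prim::Constant[value=0]()
      %stride : int[] = prim::ListConstruct(%one)
      %padding : int[] = prim::ListConstruct(%zero)
      %dilation : int[] = prim::ListConstruct(%one)
      %cond : bool = prim::Constant[value=1]()
      %out : Tensor = prim::If(%cond)
        block0():
          %res : Tensor = aten::conv1d(%x, %w, %b, %stride, %padding, %dilation, %one)
          -> (%res)
        block1():
          -> (%x)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, g.get());

  // After the pass, the conv1d inside block0 should appear as aten::_convolution
  torch_tensorrt::core::lowering::passes::Conv1DToConvolution(g);
  std::cout << *g << std::endl;
  return 0;
}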
3 changes: 2 additions & 1 deletion tests/core/partitioning/test_conditionals.cpp
@@ -35,12 +35,13 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
   auto g = mod.get_method("forward").graph();
   torch_tensorrt::core::CompileSpec cfg(inputs);
   cfg.partitioning_info.enabled = true;
+  cfg.partitioning_info.forced_fallback_operators.push_back("aten::log_sigmoid");
   torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
   auto new_g = new_mod.get_method("forward").graph();
 
   auto conditional_engines_count = count_trt_engines_in_conditionals(new_g);
 
-  ASSERT_TRUE(conditional_engines_count == 1);
+  ASSERT_TRUE(conditional_engines_count == 2);
 }
 
 TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
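The expected count in this test rises from 1 to 2, presumably because forcing aten::log_sigmoid to fall back to Torch splits the TensorRT-compatible region inside the conditional into two engines. A helper in the spirit of count_trt_engines_in_conditionals could be sketched as below; the "tensorrt::execute_engine" node kind is an assumption based on the Torch-TensorRT runtime op name, and this is not the test's actual implementation.

#include <torch/csrc/jit/ir/ir.h>

#include <cstddef>
#include <string>

// Count TensorRT engine-execution nodes that live inside conditional blocks.
// Illustrative sketch only; the engine node kind string is assumed (see note above).
size_t countEnginesInConditionals(torch::jit::Block* block, bool in_conditional = false) {
  size_t count = 0;
  for (auto n : block->nodes()) {
    const std::string kind = n->kind().toQualString();
    // Nodes under a prim::If (directly or transitively) count as conditional
    for (auto nested : n->blocks()) {
      count += countEnginesInConditionals(nested, in_conditional || kind == "prim::If");
    }
    if (in_conditional && kind == "tensorrt::execute_engine") {
      count++;
    }
  }
  return count;
}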