Commit ce561f3

mfeliz-cruise authored and bowang007 committed
feat: Add converter for aten::log2 (#1866)
1 parent 71f61c6 commit ce561f3

2 files changed: +46 -0 lines changed


core/conversion/converters/impl/unary.cpp

Lines changed: 30 additions & 0 deletions
@@ -34,6 +34,36 @@ auto reciprocal_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().
        return true;
      }});
 
+auto log2_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
+    {"aten::log2(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+       const static float ln2 = 0.693147180559945309; // same constant onnx uses
+       auto in = args[0].ITensorOrFreeze(ctx);
+       auto tensor_type = util::TRTDataTypeToScalarType(in->getType());
+       if (in->getType() == nvinfer1::DataType::kINT32) {
+         // pytorch implicitly casts to float for aten::log2(int)
+         in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT);
+         tensor_type = at::kFloat;
+       }
+
+       auto log_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kLOG);
+       TORCHTRT_CHECK(log_layer, "Unable to create log layer from node: " << *n);
+       log_layer->setName((util::node_info(n) + "_log").c_str());
+
+       std::vector<int64_t> ln2_dims(in->getDimensions().nbDims, 1);
+       auto ln2_tensor = at::full(ln2_dims, ln2, at::TensorOptions().dtype(tensor_type));
+       auto ln2_itensor = converters::tensor_to_const(ctx, ln2_tensor);
+
+       auto div_layer = add_elementwise(
+           ctx,
+           nvinfer1::ElementWiseOperation::kDIV,
+           log_layer->getOutput(0),
+           ln2_itensor,
+           (util::node_info(n) + "_div").c_str());
+       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], div_layer->getOutput(0));
+       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+       return true;
+     }});
+
 auto logical_not_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
     {"aten::logical_not(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in = args[0].ITensorOrFreeze(ctx);
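
Aside from the diff itself: the converter lowers aten::log2 to TensorRT's natural-log unary layer (kLOG) followed by an elementwise division by ln 2 (kDIV), relying on the identity log2(x) = log(x) / ln(2). The following standalone C++ sketch is an illustration only, not code from this commit or from Torch-TensorRT; it checks that identity against std::log2 for the same values the new test feeds through the engine.

// Illustration only (assumption: not part of this commit); verifies the
// identity log2(x) == log(x) / ln(2) that the converter builds out of
// TensorRT kLOG + kDIV layers.
#include <cassert>
#include <cmath>
#include <initializer_list>
#include <iostream>

int main() {
  const float ln2 = 0.693147180559945309f; // same constant the converter embeds
  for (float x : {1.0f, 2.0f, 7.0f, 25.0f, 50.0f}) {
    float via_div = std::log(x) / ln2;             // what the generated TRT graph computes
    assert(std::fabs(via_div - std::log2(x)) < 1e-5f);
    std::cout << "log2(" << x << ") = " << via_div << "\n";
  }
  return 0;
}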

tests/core/conversion/converters/test_unary.cpp

Lines changed: 16 additions & 0 deletions
@@ -47,6 +47,21 @@ TEST(Converters, ATenReciprocalIntConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0]));
 }
 
+TEST(Converters, ATenLog2IntConvertsCorrectly) {
+  const auto graph = gen_test_graph("log2");
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::tensor({1, 2, 7, 25, 50}, {at::kCUDA}).to(torch::kInt32);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
+
 TEST(Converters, ATenSignConvertsCorrectly) {
   const auto graph = gen_test_graph("sign");
   auto g = std::make_shared<torch::jit::Graph>();
@@ -129,6 +144,7 @@ test_unary(abs, Abs);
 test_unary(floor, Floor);
 test_unary(reciprocal, Reciprocal);
 test_unary(log, Log);
+test_unary(log2, Log2);
 test_unary(ceil, Ceil);
 test_unary(sqrt, Sqrt);
 test_unary(exp, Exp);
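
The new ATenLog2IntConvertsCorrectly test reuses a gen_test_graph helper that already exists in test_unary.cpp and is not shown in this diff; presumably it splices the op name into a one-node TorchScript IR string. A hedged sketch of what such a helper could look like follows; the actual definition lives earlier in the file and may differ.

// Hypothetical sketch only (assumption: the real helper is defined earlier in
// tests/core/conversion/converters/test_unary.cpp and is not part of this
// commit). It builds a single-op TorchScript IR graph string for the given op.
#include <string>

std::string gen_test_graph(const std::string& unary) {
  return R"IR(
      graph(%0 : Tensor):
        %3 : Tensor = aten::)IR" +
      unary + R"IR((%0)
        return (%3))IR";
}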
