@@ -34,6 +34,36 @@ auto reciprocal_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().
      return true;
    }});

+auto log2_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
+    {"aten::log2(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+      const static float ln2 = 0.693147180559945309; // same constant onnx uses
+      auto in = args[0].ITensorOrFreeze(ctx);
+      auto tensor_type = util::TRTDataTypeToScalarType(in->getType());
+      if (in->getType() == nvinfer1::DataType::kINT32) {
+        // pytorch implicitly casts to float for aten::log2(int)
+        in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT);
+        tensor_type = at::kFloat;
+      }
+
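+      // TensorRT's UnaryOperation::kLOG is the natural logarithm, so log2(x) is computed as ln(x) / ln(2)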
+      auto log_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kLOG);
+      TORCHTRT_CHECK(log_layer, "Unable to create log layer from node: " << *n);
+      log_layer->setName((util::node_info(n) + "_log").c_str());
+
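+      // materialize ln(2) as a constant with the same rank as the input (all dims 1) so the division broadcasts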
+      std::vector<int64_t> ln2_dims(in->getDimensions().nbDims, 1);
+      auto ln2_tensor = at::full(ln2_dims, ln2, at::TensorOptions().dtype(tensor_type));
+      auto ln2_itensor = converters::tensor_to_const(ctx, ln2_tensor);
+
+      auto div_layer = add_elementwise(
+          ctx,
+          nvinfer1::ElementWiseOperation::kDIV,
+          log_layer->getOutput(0),
+          ln2_itensor,
+          (util::node_info(n) + "_div").c_str());
+      auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], div_layer->getOutput(0));
+      LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+      return true;
+    }});
+
auto logical_not_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::logical_not(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
      auto in = args[0].ITensorOrFreeze(ctx);