diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index cbe6da31addf2..a46987a554b2a 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1368,6 +1368,44 @@ struct ReOpLowering : public ConvertOpToLLVMPattern { } }; +struct TanhOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (operands.size() != 1) + return failure(); + Type resultType = op->getResult(0).getType(); + const char* funcName; + LLVM::LLVMType llvmResultType; + if (resultType.isF32()) { + funcName = static_cast("tanhf"); + llvmResultType = LLVM::LLVMType::getFloatTy(&getDialect()); + } + else if (resultType.isF64()) { + funcName = static_cast("tanh"); + llvmResultType = LLVM::LLVMType::getDoubleTy(&getDialect()); + } + else + return failure(); + + // Insert the appropriate tanh declaration if it is not already present. + auto tanhFunc = + op->getParentOfType().lookupSymbol(funcName); + if (!tanhFunc) { + OpBuilder moduleBuilder(op->getParentOfType().getBodyRegion()); + tanhFunc = moduleBuilder.create( + op->getLoc(), funcName, + LLVM::LLVMType::getFunctionTy(llvmResultType, {llvmResultType}, + /*isVarArg=*/false)); + } + rewriter.replaceOpWithNewOp( + op, ArrayRef(llvmResultType), rewriter.getSymbolRefAttr(tanhFunc), operands); + return success(); + } +}; + struct ImOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -2977,6 +3015,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns( SubCFOpLowering, SubFOpLowering, SubIOpLowering, + TanhOpLowering, TruncateIOpLowering, UnsignedDivIOpLowering, UnsignedRemIOpLowering, @@ -3147,7 +3186,6 @@ mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { this->addLegalDialect(); this->addIllegalOp(); - this->addIllegalOp(); } std::unique_ptr> diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir index b8ebdfbf35f1c..948a9bd1639c2 100644 --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -87,3 +87,23 @@ func @unknown_source() -> i32 { // expected-error@+1 {{must be LLVM dialect type}} return %1 : i32 } + +// ----- + +// CHECK-LABEL: @tanh_float +func @tanh_float() { + %c0 = constant 1.0 : f32 + // CHECK: %[[.*]] = llvm.call @tanh(%[[.*]]) : (!llvm.float) -> !llvm.float + %1 = tanh %c0 : f32 + return +} + +// ----- + +// CHECK-LABEL: @tanh_double +func @tanh_double() { + %c0 = constant 1.0 : f64 + // CHECK: %[[.*]] = llvm.call @tanh(%[[.*]]) : (!llvm.double) -> !llvm.double + %1 = tanh %c0 : f64 + return +}