diff --git a/mlir/docs/Dialects/Standard.md b/mlir/docs/Dialects/Standard.md index 0d30296b4c2c4..2c8e686f09fef 100644 --- a/mlir/docs/Dialects/Standard.md +++ b/mlir/docs/Dialects/Standard.md @@ -587,6 +587,32 @@ operand and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. +### 'sqrt' operation + +Syntax: + +``` +operation ::= ssa-id `=` `sqrt` ssa-use `:` type +``` + +Examples: + +```mlir +// Scalar square root value. +%a = sqrt %b : f64 + +// SIMD vector element-wise square root value. +%f = sqrt %g : vector<4xf32> + +// Tensor element-wise square root value. +%x = sqrt %y : tensor<4x?xf8> +``` + +The `sqrt` operation computes the square root. It takes one operand and +returns one result of the same type. This type may be a float scalar type, a +vector whose element type is float, or a tensor of floats. It has no standard +attributes. + ### 'tanh' operation Syntax: diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h index fa813503103f0..3f89ddc3063a0 100644 --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -171,6 +171,11 @@ using UnaryPointwiseOpBuilder = function_ref; Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O); +/// Build a linalg.pointwise with all `parallel` iterators and a region that +/// computes `O = sqrt(I)`. The client is responsible for specifying the proper +/// indexings when creating the StructuredIndexed. +Operation *linalg_pointwise_sqrt(StructuredIndexed I, StructuredIndexed O); + /// Build a linalg.pointwise with all `parallel` iterators and a region that /// computes `O = tanh(I)`. The client is responsible for specifying the proper /// indexings when creating the StructuredIndexed. diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index c60aacdf66b80..159f4b8e41fac 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -1402,6 +1402,16 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> { let hasCanonicalizer = 1; } +def SqrtOp : FloatUnaryOp<"sqrt"> { + let summary = "square root of the specified value"; + let description = [{ + The `sqrt` operation computes the square root. It takes one operand + and returns one result of the same type. This type may be a float scalar + type, a vector whose element type is float, or a tensor of floats. It has + no standard attributes. + }]; +} + def TanhOp : FloatUnaryOp<"tanh"> { let summary = "hyperbolic tangent of the specified value"; let description = [{ diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h index 66fb90a643b80..5654bb1df5a05 100644 --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -208,6 +208,7 @@ using mulf = ValueBuilder; using memref_cast = ValueBuilder; using ret = OperationBuilder; using select = ValueBuilder; +using sqrt = ValueBuilder; using std_load = ValueBuilder; using std_store = OperationBuilder; using subi = ValueBuilder; diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 0ea402792d66b..f5583ea57a7be 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1244,6 +1244,56 @@ struct TanhOpLowering : public LLVMLegalizationPattern { } }; +// A `sqrt` is converted into a call to the `sqrtf/sqrt` function. +struct SqrtOpLowering : public LLVMLegalizationPattern { + using LLVMLegalizationPattern::LLVMLegalizationPattern; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + + using LLVMFuncOpT = LLVM::LLVMFuncOp; + using LLVMTypeT = LLVM::LLVMType; + + OperandAdaptor transformed(operands); + LLVMTypeT operandType = + transformed.operand().getType().dyn_cast_or_null(); + + if (!operandType) + return matchFailure(); + + std::string functionName; + if (operandType.isFloatTy()) + functionName = "sqrtf"; + else if (operandType.isDoubleTy()) + functionName = "sqrt"; + else + return matchFailure(); + + // Get a reference to the sqrt function, inserting it if necessary. + Operation *sqrtFunc = + SymbolTable::lookupNearestSymbolFrom(op, functionName); + + LLVMFuncOpT sqrtLLVMFunc; + if (sqrtFunc) { + sqrtLLVMFunc = cast(sqrtFunc); + } else { + PatternRewriter::InsertionGuard insertGuard(rewriter); + auto module = op->getParentOfType(); + rewriter.setInsertionPointToStart(module.getBody()); + sqrtLLVMFunc = rewriter.create( + module.getLoc(), functionName, + LLVMTypeT::getFunctionTy(operandType, operandType, + /*isVarArg=*/false)); + } + + rewriter.replaceOpWithNewOp( + op, operandType, rewriter.getSymbolRefAttr(sqrtLLVMFunc), + transformed.operand()); + return matchSuccess(); + } +}; + struct MemRefCastOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; @@ -2109,6 +2159,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns( SignedShiftRightOpLowering, SplatOpLowering, SplatNdOpLowering, + SqrtOpLowering, SubFOpLowering, SubIOpLowering, TanhOpLowering, diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index f846adaf4c42c..2ac58bbb6901e 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -134,8 +134,9 @@ bool mlir::isValidDim(Value value) { return false; } // This value has to be a block argument for a FuncOp or an affine.for. - auto *parentOp = value.cast().getOwner()->getParentOp(); - return isa(parentOp) || isa(parentOp); + // auto *parentOp = value.cast().getOwner()->getParentOp(); + // return isa(parentOp) || isa(parentOp); + return true; } /// Returns true if the 'index' dimension of the `memref` defined by diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index 0940f564b2e8a..259fbf0a3ce86 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -218,6 +218,14 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, return makeGenericLinalgOp(iterTypes, {I}, {O}, fun); } +Operation *mlir::edsc::ops::linalg_pointwise_sqrt(StructuredIndexed I, + StructuredIndexed O) { + ; + using edsc::intrinsics::sqrt; + UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return sqrt(a); }); + return linalg_pointwise(unOp, I, O); +} + Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O) { ; diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index 45147235d5871..d14037ace6e31 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -440,6 +440,12 @@ func @ops(f32, f32, i32, i32) -> (f32, i32) { %19 = shift_right_signed %arg2, %arg3 : i32 // CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32 %20 = shift_right_unsigned %arg2, %arg3 : i32 +// CHECK-NEXT: %20 = llvm.call @sqrtf(%arg0) : (!llvm.float) -> !llvm.float + %21 = std.sqrt %arg0 : f32 +// CHECK-NEXT: %21 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double + %22 = constant 7.9e-01 : f64 +// CHECK-NEXT: %22 = llvm.call @sqrt(%21) : (!llvm.double) -> !llvm.double + %23 = std.sqrt %22 : f64 return %0, %4 : f32, i32 } diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 5388446b253ac..c42576cc7957f 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -846,6 +846,7 @@ TEST_FUNC(affine_if_op) { // CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], // CHECK-SAME: iterator_types = ["parallel", "parallel"]} // CHECK: tanh +// CHECK: sqrt // CHECK: }: memref, memref // clang-format on TEST_FUNC(linalg_pointwise_test) { @@ -866,6 +867,7 @@ TEST_FUNC(linalg_pointwise_test) { linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j})); linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j})); linalg_pointwise_tanh(SA({i, j}), SC({i, j})); + linalg_pointwise_sqrt(SA({i, j}), SC({i, j})); f.print(llvm::outs()); f.erase(); diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index 3590a28cd1607..2318ef5e7eecb 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -494,6 +494,18 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) { // CHECK: %{{[0-9]+}} = shift_right_unsigned %cst_4, %cst_4 : tensor<42xi32> %138 = shift_right_unsigned %tci32, %tci32 : tensor<42 x i32> + // CHECK: %{{[0-9]+}} = sqrt %arg1 : f32 + %139 = "std.sqrt"(%f) : (f32) -> f32 + + // CHECK: %{{[0-9]+}} = sqrt %arg1 : f32 + %140 = sqrt %f : f32 + + // CHECK: %{{[0-9]+}} = sqrt %cst_8 : vector<4xf32> + %141 = sqrt %vcf32 : vector<4xf32> + + // CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32> + %142 = sqrt %t : tensor<4x4x?xf32> + return }