Skip to content

Llitchev mlir add sqrt #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions mlir/docs/Dialects/Standard.md
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,32 @@ operand and returns one result of the same type. This type may be a float
scalar type, a vector whose element type is float, or a tensor of floats. It
has no standard attributes.

### 'sqrt' operation

Syntax:

```
operation ::= ssa-id `=` `sqrt` ssa-use `:` type
```

Examples:

```mlir
// Scalar square root value.
%a = sqrt %b : f64

// SIMD vector element-wise square root value.
%f = sqrt %g : vector<4xf32>

// Tensor element-wise square root value.
%x = sqrt %y : tensor<4x?xf8>
```

The `sqrt` operation computes the square root. It takes one operand and
returns one result of the same type. This type may be a float scalar type, a
vector whose element type is float, or a tensor of floats. It has no standard
attributes.

### 'tanh' operation

Syntax:
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ using UnaryPointwiseOpBuilder = function_ref<Value(ValueHandle)>;
Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
StructuredIndexed I, StructuredIndexed O);

/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = sqrt(I)`. The client is responsible for specifying the proper
/// indexings when creating the StructuredIndexed.
Operation *linalg_pointwise_sqrt(StructuredIndexed I, StructuredIndexed O);

/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = tanh(I)`. The client is responsible for specifying the proper
/// indexings when creating the StructuredIndexed.
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/StandardOps/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,16 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
let hasCanonicalizer = 1;
}

def SqrtOp : FloatUnaryOp<"sqrt"> {
let summary = "square root of the specified value";
let description = [{
The `sqrt` operation computes the square root. It takes one operand
and returns one result of the same type. This type may be a float scalar
type, a vector whose element type is float, or a tensor of floats. It has
no standard attributes.
}];
}

def TanhOp : FloatUnaryOp<"tanh"> {
let summary = "hyperbolic tangent of the specified value";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/EDSC/Intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ using mulf = ValueBuilder<MulFOp>;
using memref_cast = ValueBuilder<MemRefCastOp>;
using ret = OperationBuilder<ReturnOp>;
using select = ValueBuilder<SelectOp>;
using sqrt = ValueBuilder<SqrtOp>;
using std_load = ValueBuilder<LoadOp>;
using std_store = OperationBuilder<StoreOp>;
using subi = ValueBuilder<SubIOp>;
Expand Down
51 changes: 51 additions & 0 deletions mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,56 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
}
};

// A `sqrt` is converted into a call to the `sqrtf/sqrt` function.
struct SqrtOpLowering : public LLVMLegalizationPattern<SqrtOp> {
using LLVMLegalizationPattern<SqrtOp>::LLVMLegalizationPattern;

PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {

using LLVMFuncOpT = LLVM::LLVMFuncOp;
using LLVMTypeT = LLVM::LLVMType;

OperandAdaptor<SqrtOp> transformed(operands);
LLVMTypeT operandType =
transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();

if (!operandType)
return matchFailure();

std::string functionName;
if (operandType.isFloatTy())
functionName = "sqrtf";
else if (operandType.isDoubleTy())
functionName = "sqrt";
else
return matchFailure();

// Get a reference to the sqrt function, inserting it if necessary.
Operation *sqrtFunc =
SymbolTable::lookupNearestSymbolFrom(op, functionName);

LLVMFuncOpT sqrtLLVMFunc;
if (sqrtFunc) {
sqrtLLVMFunc = cast<LLVMFuncOpT>(sqrtFunc);
} else {
PatternRewriter::InsertionGuard insertGuard(rewriter);
auto module = op->getParentOfType<ModuleOp>();
rewriter.setInsertionPointToStart(module.getBody());
sqrtLLVMFunc = rewriter.create<LLVMFuncOpT>(
module.getLoc(), functionName,
LLVMTypeT::getFunctionTy(operandType, operandType,
/*isVarArg=*/false));
}

rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, operandType, rewriter.getSymbolRefAttr(sqrtLLVMFunc),
transformed.operand());
return matchSuccess();
}
};

struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;

Expand Down Expand Up @@ -2109,6 +2159,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
SignedShiftRightOpLowering,
SplatOpLowering,
SplatNdOpLowering,
SqrtOpLowering,
SubFOpLowering,
SubIOpLowering,
TanhOpLowering,
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/AffineOps/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,9 @@ bool mlir::isValidDim(Value value) {
return false;
}
// This value has to be a block argument for a FuncOp or an affine.for.
auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
return isa<FuncOp>(parentOp) || isa<AffineForOp>(parentOp);
// auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
// return isa<FuncOp>(parentOp) || isa<AffineForOp>(parentOp);
return true;
}

/// Returns true if the 'index' dimension of the `memref` defined by
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
}

Operation *mlir::edsc::ops::linalg_pointwise_sqrt(StructuredIndexed I,
StructuredIndexed O) {
;
using edsc::intrinsics::sqrt;
UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return sqrt(a); });
return linalg_pointwise(unOp, I, O);
}

Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
StructuredIndexed O) {
;
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,12 @@ func @ops(f32, f32, i32, i32) -> (f32, i32) {
%19 = shift_right_signed %arg2, %arg3 : i32
// CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32
%20 = shift_right_unsigned %arg2, %arg3 : i32
// CHECK-NEXT: %20 = llvm.call @sqrtf(%arg0) : (!llvm.float) -> !llvm.float
%21 = std.sqrt %arg0 : f32
// CHECK-NEXT: %21 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
%22 = constant 7.9e-01 : f64
// CHECK-NEXT: %22 = llvm.call @sqrt(%21) : (!llvm.double) -> !llvm.double
%23 = std.sqrt %22 : f64

return %0, %4 : f32, i32
}
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/EDSC/builder-api-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,7 @@ TEST_FUNC(affine_if_op) {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK: tanh
// CHECK: sqrt
// CHECK: }: memref<?x?xf32>, memref<?x?xf32>
// clang-format on
TEST_FUNC(linalg_pointwise_test) {
Expand All @@ -866,6 +867,7 @@ TEST_FUNC(linalg_pointwise_test) {
linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_pointwise_tanh(SA({i, j}), SC({i, j}));
linalg_pointwise_sqrt(SA({i, j}), SC({i, j}));

f.print(llvm::outs());
f.erase();
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/IR/core-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,18 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
// CHECK: %{{[0-9]+}} = shift_right_unsigned %cst_4, %cst_4 : tensor<42xi32>
%138 = shift_right_unsigned %tci32, %tci32 : tensor<42 x i32>

// CHECK: %{{[0-9]+}} = sqrt %arg1 : f32
%139 = "std.sqrt"(%f) : (f32) -> f32

// CHECK: %{{[0-9]+}} = sqrt %arg1 : f32
%140 = sqrt %f : f32

// CHECK: %{{[0-9]+}} = sqrt %cst_8 : vector<4xf32>
%141 = sqrt %vcf32 : vector<4xf32>

// CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32>
%142 = sqrt %t : tensor<4x4x?xf32>

return
}

Expand Down