Skip to content

Commit 67b9d3f

Browse files
authored
[mlir] computeSliceParameters: Fix offset when m(0) != 0 (llvm#122492)
For affine maps where `m(0) != 0`, like `affine_map<(d0) -> (d0 + 3)` in ``` %generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0 + 3)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0: tensor<9xf32>) outs(%empty : tensor<6xf32>) { ^bb0(%in : f32, %out: f32): linalg.yield %in : f32 } -> tensor<6xf32> ``` tiling currently computes the wrong slice offsets. When tiling above example with a size of 3, it would compute ``` scf.for %i = ... %slice = tensor.extract_slice %arg0[%i + 3] [6] [1] linalg.generic {indexing_maps = [affine_map<(d0) -> (d0 + 3)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%slice: tensor<6xf32>) ``` and thus apply the `+3` twice (once in the extract slice and a second time in the linalg.generic). This PR fixes this to yield an offset of `tensor.extract_slice %arg0[%i] [6] [1]` instead.
1 parent 2a8c12b commit 67b9d3f

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,8 +596,17 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
596596
auto m = map.getSubMap({r});
597597
LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
598598
IRRewriter rewriter(builder);
599-
OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
599+
// The offset of the slice is m(lbs) - m(0).
600+
SmallVector<Attribute> zeros(lbs.size(), rewriter.getIndexAttr(0));
601+
SmallVector<Attribute> mAtZero;
602+
[[maybe_unused]] auto res = m.constantFold(zeros, mAtZero);
603+
assert(succeeded(res) && "affine_map must be evaluatable (not symbols)");
604+
int64_t mAtZeroInt =
605+
cast<IntegerAttr>(mAtZero[0]).getValue().getSExtValue();
606+
OpFoldResult offset = makeComposedFoldedAffineApply(
607+
rewriter, loc, m.getResult(0) - mAtZeroInt, lbs);
600608
sliceParams.offsets.push_back(offset);
609+
601610
OpFoldResult closedIntSize =
602611
makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
603612
// Resulting size needs to be made half open interval again.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func @tile_offset
4+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]:
5+
func.func @tile_offset(%arg0 : tensor<9xf32>) -> tensor<6xf32> {
6+
%empty = tensor.empty() : tensor<6xf32>
7+
// CHECK: scf.for %[[ITER:[a-zA-Z0-9_]+]] =
8+
// CHECK: tensor.extract_slice %[[ARG0]][%[[ITER]]] [6] [1]
9+
%generic = linalg.generic
10+
{indexing_maps = [affine_map<(d0) -> (d0 + 3)>,
11+
affine_map<(d0) -> (d0)>],
12+
iterator_types = ["parallel"]} ins(%arg0: tensor<9xf32>) outs(%empty : tensor<6xf32>) {
13+
^bb0(%in : f32, %out: f32):
14+
linalg.yield %in : f32
15+
} -> tensor<6xf32>
16+
return %generic : tensor<6xf32>
17+
}
18+
19+
module attributes {transform.with_named_sequence} {
20+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
21+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
22+
%1, %loop = transform.structured.tile_using_for %0 tile_sizes [3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
23+
transform.yield
24+
}
25+
}

0 commit comments

Comments
 (0)