
Commit 5c6db8c
[MLIR] TosaToLinalg: Prefer to emit identity maps (#123295)
When deciding whether to emit a map like `#map = affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>` or `#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` for an operand of a `linalg.generic` when lowering element-wise TOSA ops, prefer the latter unless the operand actually needs to be broadcast. This helps later transformations, which often require the affine map to be a projected permutation.
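For illustration, a minimal sketch of the non-broadcast case (function name and shapes are assumptions, not taken from this patch):

// Hypothetical input: both operands have the same static shape as the result.
func.func @no_broadcast(%a: tensor<1x4xf32>, %b: tensor<1x4xf32>) -> tensor<1x4xf32> {
  %0 = tosa.add %a, %b : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32>
  return %0 : tensor<1x4xf32>
}
// With this change, the emitted linalg.generic indexes both operands with the
// identity map affine_map<(d0, d1) -> (d0, d1)>. Previously the size-1 leading
// dimension forced affine_map<(d0, d1) -> (0, d1)>, even though the result's
// leading dimension is also 1 and no broadcast takes place.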

2 files changed: 33 additions, 9 deletions


mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 8 additions & 2 deletions
@@ -882,8 +882,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
   auto shape = cast<ShapedType>(operand.getType()).getShape();
   SmallVector<AffineExpr> affineExprs;
   for (auto it : llvm::enumerate(shape)) {
-    auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
-                                      : rewriter.getAffineDimExpr(it.index());
+    // Prefer producing identity maps whenever possible (i.e. no broadcasting
+    // needed) because some transforms (like reshape folding)
+    // do not support affine constant exprs.
+    bool requiresBroadcast =
+        (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
+    auto affineExpr = requiresBroadcast
+                          ? rewriter.getAffineConstantExpr(0)
+                          : rewriter.getAffineDimExpr(it.index());
     affineExprs.push_back(affineExpr);
   }
   return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
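To make the new predicate concrete, here is a sketch of a case that still requires broadcasting (function name and shapes are assumptions, not from this patch):

// The first operand's size-1 leading dim is expanded to 2 in the result, so
// that dimension still receives the constant-0 expression.
func.func @with_broadcast(%a: tensor<1x4xf32>, %b: tensor<2x4xf32>) -> tensor<2x4xf32> {
  %0 = tosa.add %a, %b : (tensor<1x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
  return %0 : tensor<2x4xf32>
}
// %a is indexed with affine_map<(d0, d1) -> (0, d1)> (requiresBroadcast is
// true for d0), while %b and the result use the identity map
// affine_map<(d0, d1) -> (d0, d1)>.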

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 25 additions & 7 deletions
@@ -100,16 +100,15 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {

 // -----

-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, 0)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, d1)>

 // CHECK-LABEL: func.func @test_add_0d_broadcast(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<2x1xf32>,
 // CHECK-SAME: %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
 // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
 // CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
-// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
+// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
 // CHECK: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
 // CHECK: linalg.yield %[[ADD]] : f32
@@ -253,6 +252,26 @@ func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: t

 // -----

+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_add_1d_matching_no_broadcast
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_1d_matching_no_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
+
+  // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1xf32>
+  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<1xf32>, tensor<1xf32>) outs(%[[VAL_0]] : tensor<1xf32>) {
+  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
+  // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
+  // CHECK: linalg.yield %[[VAL_4]] : f32
+  // CHECK: } -> tensor<1xf32>
+  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: return %[[RESULT]] : tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: @test_add_1d_matching_static
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
@@ -1969,13 +1988,12 @@ func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>

 // -----

-// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>

 // CHECK-LABEL: func.func @test_cast_fp32_i64(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<1xf32>) -> tensor<1xi64> {
 // CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<1xi64>
-// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<1xi64>) {
+// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<1xi64>) {
 // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: i64):
 // CHECK: %[[ROUND_EVEN:.*]] = math.roundeven %[[IN]] : f32
 // CHECK: %[[FP_INT_MIN:.*]] = arith.constant -9.22337203E+18 : f32
