diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 60cae77644291..f4b6955823085 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1548,10 +1548,9 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
 
 /// Modify the `linalg.index` operations in the original generic op, to its
 /// value in the collapsed operation.
-void generateCollapsedIndexingRegion(Location loc, Block *block,
-                                     const CollapsingInfo &collapsingInfo,
-                                     ValueRange loopRange,
-                                     RewriterBase &rewriter) {
+static void generateCollapsedIndexingRegion(
+    Location loc, Block *block, const CollapsingInfo &collapsingInfo,
+    ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointToStart(block);
 
@@ -1572,10 +1571,12 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
     Value newIndexVal =
         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
+      Value loopDim =
+          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
       indexReplacementVals[dim] =
-          rewriter.create<arith::RemSIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
       newIndexVal =
-          rewriter.create<arith::DivSIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
     }
     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
   }
@@ -1722,14 +1723,13 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
   LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
 
   Location loc = op->getLoc();
+  SmallVector<OpFoldResult> loopBound =
+      llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
+
   if (collapsedOp.hasIndexSemantics()) {
     // Collect the loop range of the generic op.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(collapsedOp);
-    SmallVector<Value> loopBound =
-        llvm::map_to_vector(loopRanges, [&](Range range) {
-          return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
-        });
     generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                     collapsingInfo, loopBound, rewriter);
   }
@@ -1747,15 +1747,22 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
           op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
+      assert(
+          indexingMap.isProjectedPermutation() &&
+          "Expected indexing map to be a projected permutation for collapsing");
+      SmallVector<OpFoldResult> resultShape =
+          applyPermutationMap(indexingMap, ArrayRef(loopBound));
       Value result;
       if (isa<MemRefType>(collapsedOpResult.getType())) {
         MemRefType expandShapeResultType = MemRefType::get(
            originalResultType.getShape(), originalResultType.getElementType());
         result = rewriter.create<memref::ExpandShapeOp>(
-            loc, expandShapeResultType, collapsedOpResult, reassociation);
+            loc, expandShapeResultType, collapsedOpResult, reassociation,
+            resultShape);
       } else {
         result = rewriter.create<tensor::ExpandShapeOp>(
-            loc, originalResultType, collapsedOpResult, reassociation);
+            loc, originalResultType, collapsedOpResult, reassociation,
+            resultShape);
       }
       results.push_back(result);
     } else {
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 7db997cd4c0b5..89734e7542801 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -225,6 +225,38 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor,
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1: index) -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1] : tensor<?xf32> into tensor<?x?xf32>
+  %init = tensor.empty(%sz1, %sz0) : tensor<?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map0],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0 : tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %out = arith.negf %b0 : f32
+      linalg.yield %out : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @fuse_by_collapsing_dynamic_2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPANDED]], %[[C1]]
+// CHECK: %[[OUT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs(%{{.*}} : tensor<?xf32>)
+// CHECK: %[[EXPANDED_1:.+]] = tensor.expand_shape %[[OUT]]
+// CHECK-SAME: output_shape [%[[DIM0]], %[[DIM1]]]
+// CHECK: return %[[EXPANDED_1]]
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
 func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>, %sz0: index) -> tensor<2x5xf32> {
@@ -425,10 +457,11 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor, %arg1 : tensor<4
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK: func @fuse_only_one_reassociation
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<4x?x?x8xf32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C0]] : tensor<?x4x?x8xf32>
+// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C2]] : tensor<?x4x?x8xf32>
 // CHECK-DAG: %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-DAG: %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-DAG: %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
@@ -437,10 +470,7 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor, %arg1 : tensor<4
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
 // CHECK-SAME: outs(%[[COLLAPSE_ARG1_1]] :
-// CHECK: %[[DIM:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<4x?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C2]] : tensor<4x?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
-// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[VAL_1]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
+// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[DIM_2]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
 // CHECK: return %[[EXPANDED_3]]
 
 // -----
@@ -475,15 +505,16 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor, %sz0: index, %sz1:
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: func @fold_non_consecutive_dims(
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
-// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM_0]], %[[DIM]])
+// CHECK-DAG: %[[DIM_1:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 // CHECK: %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
@@ -502,11 +533,7 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor, %sz0: index, %sz1:
 // CHECK-DAG: %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
 // CHECK-DAG: %[[T7:.+]] = arith.index_cast %[[T6]]
 // CHECK: linalg.yield %[[T7]]
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[GENERIC]], %[[C0]] : tensor<?x?xi32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C4]] : index
-// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
+// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_2]], 8, %[[DIM_1]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
 // CHECK: return %[[EXPANDED_3]]
 
 // -----
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index 7acbd843cd1e7..fd3c321722508 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -5,15 +5,14 @@
 
 // CHECK-LABEL: func @reshape
 // CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>, %[[SZ0:.*]]: index)
-// CHECK: %[[C112:.*]] = arith.constant 112 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[A]]
+// CHECK: %[[DIM:.*]] = tensor.dim %[[EXPANDED]], %[[C0]]
 // CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
 // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
-// CHECK: %[[DIM:.*]] = tensor.dim %[[R]], %[[C0]] : tensor<?x16xf32>
-// CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C112]] : index
-// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[VAL_1]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
+// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[DIM]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
 // CHECK: return %[[RR]] : tensor<?x112x16xf32>
 func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>, %sz0: index) -> tensor<?x112x16xf32> {
   %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [%sz0, 112, 16]