Skip to content

Commit 796d48b

Browse files
authored
[mlir][vector] Add leading unit dim folding patterns for masked transfers (#71466)
This handles `vector.transfer_read`, `vector.transfer_write`, and `vector.constant_mask`. The unit dims are only relevant for masks created by `create_mask` and `constant_mask` if the mask size for the unit dim is zero (the only legal value other than one for a unit dim), in which case all subsequent mask sizes must also be zero and the whole mask is a zero mask. From the perspective of the vector transfers, however, these unit dims can just be dropped directly.
1 parent bdb309c commit 796d48b

File tree

2 files changed

+94
-10
lines changed

2 files changed

+94
-10
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <numeric>
10+
911
#include "mlir/Dialect/Arith/IR/Arith.h"
1012
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1113
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -208,9 +210,6 @@ struct CastAwayTransferReadLeadingOneDim
208210
if (read.getTransferRank() == 0)
209211
return failure();
210212

211-
if (read.getMask())
212-
return failure();
213-
214213
auto shapedType = cast<ShapedType>(read.getSource().getType());
215214
if (shapedType.getElementType() != read.getVectorType().getElementType())
216215
return failure();
@@ -233,10 +232,18 @@ struct CastAwayTransferReadLeadingOneDim
233232
inBoundsAttr = rewriter.getArrayAttr(
234233
read.getInBoundsAttr().getValue().take_back(newType.getRank()));
235234

235+
Value mask = Value();
236+
if (read.getMask()) {
237+
// The mask shape must always match the shape of the transferred vector, so we
238+
// can safely use the same extraction indices.
239+
int64_t dropDim = oldType.getRank() - newType.getRank();
240+
mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
241+
splatZero(dropDim));
242+
}
243+
236244
auto newRead = rewriter.create<vector::TransferReadOp>(
237245
read.getLoc(), newType, read.getSource(), read.getIndices(),
238-
AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
239-
inBoundsAttr);
246+
AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
240247
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
241248

242249
return success();
@@ -256,9 +263,6 @@ struct CastAwayTransferWriteLeadingOneDim
256263
if (write.getTransferRank() == 0)
257264
return failure();
258265

259-
if (write.getMask())
260-
return failure();
261-
262266
auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
263267
if (shapedType.getElementType() != write.getVectorType().getElementType())
264268
return failure();
@@ -283,10 +287,21 @@ struct CastAwayTransferWriteLeadingOneDim
283287

284288
auto newVector = rewriter.create<vector::ExtractOp>(
285289
write.getLoc(), write.getVector(), splatZero(dropDim));
290+
291+
if (write.getMask()) {
292+
// The mask shape must always match the shape of the written vector, so we
293+
// can safely use the same extraction indices.
294+
auto newMask = rewriter.create<vector::ExtractOp>(
295+
write.getLoc(), write.getMask(), splatZero(dropDim));
296+
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
297+
write, newVector, write.getSource(), write.getIndices(),
298+
AffineMapAttr::get(newMap), newMask, inBoundsAttr);
299+
return success();
300+
}
301+
286302
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
287303
write, newVector, write.getSource(), write.getIndices(),
288304
AffineMapAttr::get(newMap), inBoundsAttr);
289-
290305
return success();
291306
}
292307
};
@@ -467,14 +482,48 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
467482
}
468483
};
469484

485+
// Drops leading 1 dimensions from vector.constant_mask and inserts a
486+
// vector.broadcast back to the original shape.
487+
struct CastAwayConstantMaskLeadingOneDim
488+
: public OpRewritePattern<vector::ConstantMaskOp> {
489+
using OpRewritePattern::OpRewritePattern;
490+
491+
LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
492+
PatternRewriter &rewriter) const override {
493+
VectorType oldType = mask.getType();
494+
VectorType newType = trimLeadingOneDims(oldType);
495+
496+
if (newType == oldType)
497+
return failure();
498+
499+
int64_t dropDim = oldType.getRank() - newType.getRank();
500+
SmallVector<int64_t> dimSizes;
501+
for (auto attr : mask.getMaskDimSizes())
502+
dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
503+
504+
// If any of the dropped unit dims has a size of `0`, the entire mask is a
505+
// zero mask, else the unit dim has no effect on the mask.
506+
int64_t flatLeadingSize =
507+
std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
508+
static_cast<int64_t>(1), std::multiplies<int64_t>());
509+
SmallVector<int64_t> newDimSizes({flatLeadingSize});
510+
newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
511+
512+
auto newMask = rewriter.create<vector::ConstantMaskOp>(
513+
mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
514+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
515+
return success();
516+
}
517+
};
518+
470519
} // namespace
471520

472521
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
473522
RewritePatternSet &patterns, PatternBenefit benefit) {
474523
patterns
475524
.add<CastAwayExtractStridedSliceLeadingOneDim,
476525
CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
477-
CastAwayTransferReadLeadingOneDim,
526+
CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
478527
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
479528
CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
480529
populateShapeCastFoldingPatterns(patterns, benefit);

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,20 @@ func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>)
209209
return %0: vector<1x4xf16>
210210
}
211211

212+
// CHECK-LABEL: func @cast_away_masked_transfer_read_leading_one_dims
func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xi1>) -> vector<1x4xf16> {
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
  %f0 = arith.constant 0. : f16
  // The leading unit dim is dropped from both the mask and the read, and the
  // result is broadcast back to the original 2-D vector type.
  // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
  // CHECK: return %[[CAST]]
  return %0: vector<1x4xf16>
}
225+
212226
// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
213227
func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
214228
%c0 = arith.constant 0 : index
@@ -229,6 +243,18 @@ func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>
229243
return
230244
}
231245

246+
// CHECK-LABEL: func @cast_away_masked_transfer_write_leading_one_dims
func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) {
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // The leading unit dim is dropped from both the written vector and the
  // mask with matching extraction indices.
  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
  // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
  return
}
257+
232258
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
233259
func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
234260
%c0 = arith.constant 0 : index
@@ -410,3 +436,12 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x
410436
%0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
411437
return %0: vector<1x1x8x1x[8]xi1>
412438
}
439+
440+
// The two leading unit dims of the mask are folded into the first retained
// dim ([1, 1, 6, ...] -> [6, ...]) and broadcast back to the original shape.
// CHECK-LABEL: func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
// CHECK: %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
// CHECK: return %[[BCAST]] : vector<1x1x8x2x1xi1>
func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
  %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
  return %0: vector<1x1x8x2x1xi1>
}

0 commit comments

Comments
 (0)