6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
9
+ #include < numeric>
10
+
9
11
#include " mlir/Dialect/Arith/IR/Arith.h"
10
12
#include " mlir/Dialect/Utils/StructuredOpsUtils.h"
11
13
#include " mlir/Dialect/Vector/IR/VectorOps.h"
@@ -208,9 +210,6 @@ struct CastAwayTransferReadLeadingOneDim
208
210
if (read .getTransferRank () == 0 )
209
211
return failure ();
210
212
211
- if (read .getMask ())
212
- return failure ();
213
-
214
213
auto shapedType = cast<ShapedType>(read .getSource ().getType ());
215
214
if (shapedType.getElementType () != read .getVectorType ().getElementType ())
216
215
return failure ();
@@ -233,10 +232,18 @@ struct CastAwayTransferReadLeadingOneDim
233
232
inBoundsAttr = rewriter.getArrayAttr (
234
233
read .getInBoundsAttr ().getValue ().take_back (newType.getRank ()));
235
234
235
+ Value mask = Value ();
236
+ if (read .getMask ()) {
237
+ // The mask shape must always match the shape of the written vector, so we
238
+ // can safely use the same extraction indices.
239
+ int64_t dropDim = oldType.getRank () - newType.getRank ();
240
+ mask = rewriter.create <vector::ExtractOp>(read .getLoc (), read .getMask (),
241
+ splatZero (dropDim));
242
+ }
243
+
236
244
auto newRead = rewriter.create <vector::TransferReadOp>(
237
245
read .getLoc (), newType, read .getSource (), read .getIndices (),
238
- AffineMapAttr::get (newMap), read .getPadding (), /* mask=*/ Value (),
239
- inBoundsAttr);
246
+ AffineMapAttr::get (newMap), read .getPadding (), mask, inBoundsAttr);
240
247
rewriter.replaceOpWithNewOp <vector::BroadcastOp>(read , oldType, newRead);
241
248
242
249
return success ();
@@ -256,9 +263,6 @@ struct CastAwayTransferWriteLeadingOneDim
256
263
if (write .getTransferRank () == 0 )
257
264
return failure ();
258
265
259
- if (write .getMask ())
260
- return failure ();
261
-
262
266
auto shapedType = dyn_cast<ShapedType>(write .getSource ().getType ());
263
267
if (shapedType.getElementType () != write .getVectorType ().getElementType ())
264
268
return failure ();
@@ -283,10 +287,21 @@ struct CastAwayTransferWriteLeadingOneDim
283
287
284
288
auto newVector = rewriter.create <vector::ExtractOp>(
285
289
write .getLoc (), write .getVector (), splatZero (dropDim));
290
+
291
+ if (write .getMask ()) {
292
+ // The mask shape must always match the shape of the written vector, so we
293
+ // can safely use the same extraction indices.
294
+ auto newMask = rewriter.create <vector::ExtractOp>(
295
+ write .getLoc (), write .getMask (), splatZero (dropDim));
296
+ rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
297
+ write , newVector, write .getSource (), write .getIndices (),
298
+ AffineMapAttr::get (newMap), newMask, inBoundsAttr);
299
+ return success ();
300
+ }
301
+
286
302
rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
287
303
write , newVector, write .getSource (), write .getIndices (),
288
304
AffineMapAttr::get (newMap), inBoundsAttr);
289
-
290
305
return success ();
291
306
}
292
307
};
@@ -467,14 +482,48 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
467
482
}
468
483
};
469
484
485
+ // Drops leading 1 dimensions from vector.constant_mask and inserts a
486
+ // vector.broadcast back to the original shape.
487
+ struct CastAwayConstantMaskLeadingOneDim
488
+ : public OpRewritePattern<vector::ConstantMaskOp> {
489
+ using OpRewritePattern::OpRewritePattern;
490
+
491
+ LogicalResult matchAndRewrite (vector::ConstantMaskOp mask,
492
+ PatternRewriter &rewriter) const override {
493
+ VectorType oldType = mask.getType ();
494
+ VectorType newType = trimLeadingOneDims (oldType);
495
+
496
+ if (newType == oldType)
497
+ return failure ();
498
+
499
+ int64_t dropDim = oldType.getRank () - newType.getRank ();
500
+ SmallVector<int64_t > dimSizes;
501
+ for (auto attr : mask.getMaskDimSizes ())
502
+ dimSizes.push_back (llvm::cast<IntegerAttr>(attr).getInt ());
503
+
504
+ // If any of the dropped unit dims has a size of `0`, the entire mask is a
505
+ // zero mask, else the unit dim has no effect on the mask.
506
+ int64_t flatLeadingSize =
507
+ std::accumulate (dimSizes.begin (), dimSizes.begin () + dropDim + 1 ,
508
+ static_cast <int64_t >(1 ), std::multiplies<int64_t >());
509
+ SmallVector<int64_t > newDimSizes ({flatLeadingSize});
510
+ newDimSizes.append (dimSizes.begin () + dropDim + 1 , dimSizes.end ());
511
+
512
+ auto newMask = rewriter.create <vector::ConstantMaskOp>(
513
+ mask.getLoc (), newType, rewriter.getI64ArrayAttr (newDimSizes));
514
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(mask, oldType, newMask);
515
+ return success ();
516
+ }
517
+ };
518
+
470
519
} // namespace
471
520
472
521
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns (
473
522
RewritePatternSet &patterns, PatternBenefit benefit) {
474
523
patterns
475
524
.add <CastAwayExtractStridedSliceLeadingOneDim,
476
525
CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
477
- CastAwayTransferReadLeadingOneDim,
526
+ CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
478
527
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
479
528
CastAwayContractionLeadingOneDim>(patterns.getContext (), benefit);
480
529
populateShapeCastFoldingPatterns (patterns, benefit);
0 commit comments