@@ -2686,6 +2686,90 @@ static LogicalResult verify(AffineVectorStoreOp op) {
2686
2686
return success ();
2687
2687
}
2688
2688
2689
+ namespace {
2690
+ // / This pattern removes affine.parallel ops with no induction variables
2691
+ struct AffineParallelRank0LoopRemover
2692
+ : public OpRewritePattern<AffineParallelOp> {
2693
+ using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
2694
+
2695
+ LogicalResult matchAndRewrite (AffineParallelOp op,
2696
+ PatternRewriter &rewriter) const override {
2697
+ // Check that there are no induction variables
2698
+ if (op.lowerBoundsMap ().getNumResults () != 0 )
2699
+ return failure ();
2700
+ // Remove the affine.parallel wrapper, retain the body in the same location
2701
+ auto &parentOps = rewriter.getInsertionBlock ()->getOperations ();
2702
+ auto ¶llelBodyOps = op.region ().front ().getOperations ();
2703
+ auto yield = mlir::cast<AffineYieldOp>(std::prev (parallelBodyOps.end ()));
2704
+ for (auto it : zip (op.getResults (), yield.results ())) {
2705
+ std::get<0 >(it).replaceAllUsesWith (std::get<1 >(it));
2706
+ }
2707
+ parentOps.splice (mlir::Block::iterator (op), parallelBodyOps,
2708
+ parallelBodyOps.begin (), std::prev (parallelBodyOps.end ()));
2709
+ rewriter.eraseOp (op);
2710
+ return success ();
2711
+ }
2712
+ };
2713
+
2714
+ // / This pattern removes indexs that go over an empty range
2715
+ struct AffineParallelRange1IndexRemover
2716
+ : public OpRewritePattern<AffineParallelOp> {
2717
+ using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
2718
+
2719
+ LogicalResult matchAndRewrite (AffineParallelOp op,
2720
+ PatternRewriter &rewriter) const override {
2721
+ auto ranges = op.getRangesValueMap ();
2722
+ auto origNumArgs = op.getBody ()->getArguments ().size ();
2723
+ size_t curArgNum = 0 ;
2724
+ SmallVector<AffineExpr, 6 > newLowerBounds;
2725
+ SmallVector<AffineExpr, 6 > newUpperBounds;
2726
+ SmallVector<int64_t , 6 > newSteps;
2727
+ for (unsigned i = 0 ; i < origNumArgs; i++) {
2728
+ // Is the range a constant value of 1?
2729
+ auto const_expr = ranges.getResult (i).dyn_cast <AffineConstantExpr>();
2730
+ if (const_expr && const_expr.getValue () == 1 ) {
2731
+ // Remove arcument and replace with 0
2732
+ auto curArg = op.getBody ()->getArgument (curArgNum);
2733
+ auto lowerBoundValue = rewriter.create <AffineApplyOp>(
2734
+ op.getLoc (), op.lowerBoundsMap ().getSubMap ({i}), op.getLowerBoundsOperands ());
2735
+ curArg.replaceAllUsesWith (lowerBoundValue);
2736
+ op.getBody ()->eraseArgument (curArgNum);
2737
+ } else {
2738
+ // Keep argument
2739
+ newLowerBounds.push_back (op.lowerBoundsMap ().getResult (i));
2740
+ newUpperBounds.push_back (op.upperBoundsMap ().getResult (i));
2741
+ newSteps.push_back (
2742
+ op.steps ()[i].template cast <IntegerAttr>().getInt ());
2743
+ curArgNum++;
2744
+ }
2745
+ }
2746
+ // If no arguments were removed, return failur to match
2747
+ if (newLowerBounds.size () == op.lowerBoundsMap ().getNumResults ())
2748
+ return failure ();
2749
+ // Update attributes and return success
2750
+ auto newLower = AffineMap::get (op.lowerBoundsMap ().getNumDims (),
2751
+ op.lowerBoundsMap ().getNumSymbols (),
2752
+ newLowerBounds, op.getContext ());
2753
+ auto newUpper = AffineMap::get (op.upperBoundsMap ().getNumDims (),
2754
+ op.upperBoundsMap ().getNumSymbols (),
2755
+ newUpperBounds, op.getContext ());
2756
+ op.setAttr (AffineParallelOp::getLowerBoundsMapAttrName (),
2757
+ AffineMapAttr::get (newLower));
2758
+ op.setAttr (AffineParallelOp::getUpperBoundsMapAttrName (),
2759
+ AffineMapAttr::get (newUpper));
2760
+ op.setAttr (AffineParallelOp::getStepsAttrName (),
2761
+ rewriter.getI64ArrayAttr (newSteps));
2762
+ return success ();
2763
+ }
2764
+ };
2765
+
2766
+ } // end anonymous namespace
2767
+
2768
+ void AffineParallelOp::getCanonicalizationPatterns (
2769
+ OwningRewritePatternList &results, MLIRContext *context) {
2770
+ results.insert <AffineParallelRank0LoopRemover, AffineParallelRange1IndexRemover>(context);
2771
+ }
2772
+
2689
2773
// ===----------------------------------------------------------------------===//
2690
2774
// TableGen'd op method definitions
2691
2775
// ===----------------------------------------------------------------------===//
0 commit comments