Skip to content

Commit 7a6ff56

Browse files
jbruestle and tzerrell authored
Add affine canonicalization (#11)
* Add canonicalizer removing rank 0 affine.parallel
* Add canonicalizer removing range 1 indexes from affine.parallel loops.

Co-authored-by: Tim Zerrell <[email protected]>
1 parent e41f9b1 commit 7a6ff56

File tree

3 files changed

+101
-0
lines changed

3 files changed

+101
-0
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,8 @@ def AffineParallelOp : Affine_Op<"parallel", [ImplicitAffineTerminator]> {
593593
static StringRef getUpperBoundsMapAttrName() { return "upperBoundsMap"; }
594594
static StringRef getStepsAttrName() { return "steps"; }
595595
}];
596+
597+
let hasCanonicalizer = 1;
596598
}
597599

598600
def AffinePrefetchOp : Affine_Op<"prefetch"> {

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2686,6 +2686,90 @@ static LogicalResult verify(AffineVectorStoreOp op) {
26862686
return success();
26872687
}
26882688

2689+
namespace {
2690+
/// This pattern removes affine.parallel ops with no induction variables
2691+
struct AffineParallelRank0LoopRemover
2692+
: public OpRewritePattern<AffineParallelOp> {
2693+
using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
2694+
2695+
LogicalResult matchAndRewrite(AffineParallelOp op,
2696+
PatternRewriter &rewriter) const override {
2697+
// Check that there are no induction variables
2698+
if (op.lowerBoundsMap().getNumResults() != 0)
2699+
return failure();
2700+
// Remove the affine.parallel wrapper, retain the body in the same location
2701+
auto &parentOps = rewriter.getInsertionBlock()->getOperations();
2702+
auto &parallelBodyOps = op.region().front().getOperations();
2703+
auto yield = mlir::cast<AffineYieldOp>(std::prev(parallelBodyOps.end()));
2704+
for (auto it : zip(op.getResults(), yield.results())) {
2705+
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
2706+
}
2707+
parentOps.splice(mlir::Block::iterator(op), parallelBodyOps,
2708+
parallelBodyOps.begin(), std::prev(parallelBodyOps.end()));
2709+
rewriter.eraseOp(op);
2710+
return success();
2711+
}
2712+
};
2713+
2714+
/// This pattern removes indexs that go over an empty range
2715+
struct AffineParallelRange1IndexRemover
2716+
: public OpRewritePattern<AffineParallelOp> {
2717+
using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
2718+
2719+
LogicalResult matchAndRewrite(AffineParallelOp op,
2720+
PatternRewriter &rewriter) const override {
2721+
auto ranges = op.getRangesValueMap();
2722+
auto origNumArgs = op.getBody()->getArguments().size();
2723+
size_t curArgNum = 0;
2724+
SmallVector<AffineExpr, 6> newLowerBounds;
2725+
SmallVector<AffineExpr, 6> newUpperBounds;
2726+
SmallVector<int64_t, 6> newSteps;
2727+
for (unsigned i = 0; i < origNumArgs; i++) {
2728+
// Is the range a constant value of 1?
2729+
auto const_expr = ranges.getResult(i).dyn_cast<AffineConstantExpr>();
2730+
if (const_expr && const_expr.getValue() == 1) {
2731+
// Remove arcument and replace with 0
2732+
auto curArg = op.getBody()->getArgument(curArgNum);
2733+
auto lowerBoundValue = rewriter.create<AffineApplyOp>(
2734+
op.getLoc(), op.lowerBoundsMap().getSubMap({i}), op.getLowerBoundsOperands());
2735+
curArg.replaceAllUsesWith(lowerBoundValue);
2736+
op.getBody()->eraseArgument(curArgNum);
2737+
} else {
2738+
// Keep argument
2739+
newLowerBounds.push_back(op.lowerBoundsMap().getResult(i));
2740+
newUpperBounds.push_back(op.upperBoundsMap().getResult(i));
2741+
newSteps.push_back(
2742+
op.steps()[i].template cast<IntegerAttr>().getInt());
2743+
curArgNum++;
2744+
}
2745+
}
2746+
// If no arguments were removed, return failur to match
2747+
if (newLowerBounds.size() == op.lowerBoundsMap().getNumResults())
2748+
return failure();
2749+
// Update attributes and return success
2750+
auto newLower = AffineMap::get(op.lowerBoundsMap().getNumDims(),
2751+
op.lowerBoundsMap().getNumSymbols(),
2752+
newLowerBounds, op.getContext());
2753+
auto newUpper = AffineMap::get(op.upperBoundsMap().getNumDims(),
2754+
op.upperBoundsMap().getNumSymbols(),
2755+
newUpperBounds, op.getContext());
2756+
op.setAttr(AffineParallelOp::getLowerBoundsMapAttrName(),
2757+
AffineMapAttr::get(newLower));
2758+
op.setAttr(AffineParallelOp::getUpperBoundsMapAttrName(),
2759+
AffineMapAttr::get(newUpper));
2760+
op.setAttr(AffineParallelOp::getStepsAttrName(),
2761+
rewriter.getI64ArrayAttr(newSteps));
2762+
return success();
2763+
}
2764+
};
2765+
2766+
} // end anonymous namespace
2767+
2768+
/// Registers the canonicalization patterns for affine.parallel: removal of
/// loops without induction variables, and removal of individual indexes whose
/// range is the constant 1.
void AffineParallelOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<AffineParallelRank0LoopRemover>(context);
  results.insert<AffineParallelRange1IndexRemover>(context);
}
2772+
26892773
//===----------------------------------------------------------------------===//
26902774
// TableGen'd op method definitions
26912775
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,3 +604,18 @@ func @drop_duplicate_bounds(%N : index) {
604604
}
605605
return
606606
}
607+
608+
// -----
609+
610+
// Two nested affine.parallel ops without induction variables wrap a single
// store. The expected output is the constant immediately followed by the
// store, i.e. both loop wrappers are gone after canonicalization.
// CHECK: func @remove_rank0_affine_parallel(%[[OUT:.*]]: memref<f32>)
611+
func @remove_rank0_affine_parallel(%out: memref<f32>) {
612+
// CHECK-NEXT: %[[CST:.*]] = constant
613+
%cst = constant 0.0 : f32
614+
// CHECK-NEXT: affine.store %[[CST]], %[[OUT]][] : memref<f32>
615+
affine.parallel () = () to () {
616+
affine.parallel () = () to () {
617+
affine.store %cst, %out[] : memref<f32>
618+
}
619+
}
620+
return
621+
}

0 commit comments

Comments
 (0)