Skip to content

Commit 440f02e

Browse files
gitoleglanza
authored andcommitted
[CIR][Codegen] IfOp flattening (llvm#537)
This PR perform flattening for `cir::IfOp` Basically, we just move the code from `LowerToLLVM.cpp` to `FlattenCFG.cpp`. There are several important things though I would like to highlight. 1) Consider the next code from the tests: ``` cir.func @foo(%arg0: !s32i) -> !s32i { %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool cir.if %4 { %5 = cir.const(#cir.int<1> : !s32i) : !s32i cir.return %5 : !s32i } else { %5 = cir.const(#cir.int<0> : !s32i) : !s32i cir.return %5 : !s32i } cir.return %arg0 : !s32i } ``` The last `cir.return` becomes unreachable after flattening and hence is not reachable in the lowering. So we got the next error: ``` error: 'cir.return' op expects parent op to be one of 'cir.func, cir.scope, cir.if, cir.switch, cir.do, cir.while, cir.for' cir.return %arg0 : !s32i ``` the parent after lowering is `llvm.func`. And this is only the beginning - the more operations will be flatten, the more similar fails will happen. Thus, I added lowering for the unreachable code as well in `LowerToLLVM.cpp`. But may be you have another solution in your mind. 2) Please, pay attention on the flattening pass - I'm not that familiar with `mlir` builders as you are, so may be I'm doing something wrong. The idea was to start flattening from the most nested operations. 3) As you requested in llvm#516, `cir-to-llvm-internal` is renamed to `cir-flat-to-llvm`. The only thing remain undone is related to the following: > Since it would be wrong to run cir-flat-to-llvm without running cir-flatten-cfg, we should make cir-flat-to-llvm pass to require cir-flatten-cfg pass to be run before. And I'm not sure I know how to do it exactly - is there something similar to pass dependencies from LLVM IR? 4) The part of `IfOp` lowering related to elimination of the vain casts for condition branch moved directly to the lowering of `BrCondOp` with some refactoring and guarding. 5) Just note, that now `cir-opt` is able to dump the flat cir as well: `cir-opt -cir-flat-cfg`
1 parent 84ebc8c commit 440f02e

File tree

5 files changed

+203
-98
lines changed

5 files changed

+203
-98
lines changed

clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,70 @@ struct FlattenCFGPass : public FlattenCFGBase<FlattenCFGPass> {
3030
void runOnOperation() override;
3131
};
3232

33+
struct CIRIfFlattening : public OpRewritePattern<IfOp> {
34+
using OpRewritePattern<IfOp>::OpRewritePattern;
35+
36+
mlir::LogicalResult
37+
matchAndRewrite(mlir::cir::IfOp ifOp,
38+
mlir::PatternRewriter &rewriter) const override {
39+
mlir::OpBuilder::InsertionGuard guard(rewriter);
40+
auto loc = ifOp.getLoc();
41+
auto emptyElse = ifOp.getElseRegion().empty();
42+
43+
auto *currentBlock = rewriter.getInsertionBlock();
44+
auto *remainingOpsBlock =
45+
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
46+
mlir::Block *continueBlock;
47+
if (ifOp->getResults().size() == 0)
48+
continueBlock = remainingOpsBlock;
49+
else
50+
llvm_unreachable("NYI");
51+
52+
// Inline then region
53+
auto *thenBeforeBody = &ifOp.getThenRegion().front();
54+
auto *thenAfterBody = &ifOp.getThenRegion().back();
55+
rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);
56+
57+
rewriter.setInsertionPointToEnd(thenAfterBody);
58+
if (auto thenYieldOp =
59+
dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator())) {
60+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
61+
thenYieldOp, thenYieldOp.getArgs(), continueBlock);
62+
}
63+
64+
rewriter.setInsertionPointToEnd(continueBlock);
65+
66+
// Has else region: inline it.
67+
mlir::Block *elseBeforeBody = nullptr;
68+
mlir::Block *elseAfterBody = nullptr;
69+
if (!emptyElse) {
70+
elseBeforeBody = &ifOp.getElseRegion().front();
71+
elseAfterBody = &ifOp.getElseRegion().back();
72+
rewriter.inlineRegionBefore(ifOp.getElseRegion(), thenAfterBody);
73+
} else {
74+
elseBeforeBody = elseAfterBody = continueBlock;
75+
}
76+
77+
rewriter.setInsertionPointToEnd(currentBlock);
78+
rewriter.create<mlir::cir::BrCondOp>(loc, ifOp.getCondition(),
79+
thenBeforeBody, elseBeforeBody);
80+
81+
if (!emptyElse) {
82+
rewriter.setInsertionPointToEnd(elseAfterBody);
83+
if (auto elseYieldOp =
84+
dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator())) {
85+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
86+
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
87+
}
88+
}
89+
90+
rewriter.replaceOp(ifOp, continueBlock->getArguments());
91+
return mlir::success();
92+
}
93+
};
94+
3395
void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
34-
// TODO: add patterns here
96+
patterns.add<CIRIfFlattening>(patterns.getContext());
3597
}
3698

3799
void FlattenCFGPass::runOnOperation() {
@@ -41,7 +103,8 @@ void FlattenCFGPass::runOnOperation() {
41103
// Collect operations to apply patterns.
42104
SmallVector<Operation *, 16> ops;
43105
getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
44-
// TODO: push back operations here
106+
if (isa<IfOp>(op))
107+
ops.push_back(op);
45108
});
46109

47110
// Apply patterns.

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 85 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@
6262
#include "llvm/Support/Casting.h"
6363
#include "llvm/Support/ErrorHandling.h"
6464
#include <cstdint>
65+
#include <deque>
6566
#include <optional>
67+
#include <set>
6668

6769
using namespace cir;
6870
using namespace llvm;
@@ -561,13 +563,25 @@ class CIRBrCondOpLowering
561563
mlir::LogicalResult
562564
matchAndRewrite(mlir::cir::BrCondOp brOp, OpAdaptor adaptor,
563565
mlir::ConversionPatternRewriter &rewriter) const override {
564-
auto condition = adaptor.getCond();
565-
auto i1Condition = rewriter.create<mlir::LLVM::TruncOp>(
566-
brOp.getLoc(), rewriter.getI1Type(), condition);
566+
mlir::Value i1Condition;
567+
568+
if (auto defOp = adaptor.getCond().getDefiningOp()) {
569+
if (auto zext = dyn_cast<mlir::LLVM::ZExtOp>(defOp)) {
570+
if (zext->use_empty() &&
571+
zext->getOperand(0).getType() == rewriter.getI1Type()) {
572+
i1Condition = zext->getOperand(0);
573+
rewriter.eraseOp(zext);
574+
}
575+
}
576+
}
577+
578+
if (!i1Condition)
579+
i1Condition = rewriter.create<mlir::LLVM::TruncOp>(
580+
brOp.getLoc(), rewriter.getI1Type(), adaptor.getCond());
581+
567582
rewriter.replaceOpWithNewOp<mlir::LLVM::CondBrOp>(
568-
brOp, i1Condition.getResult(), brOp.getDestTrue(),
569-
adaptor.getDestOperandsTrue(), brOp.getDestFalse(),
570-
adaptor.getDestOperandsFalse());
583+
brOp, i1Condition, brOp.getDestTrue(), adaptor.getDestOperandsTrue(),
584+
brOp.getDestFalse(), adaptor.getDestOperandsFalse());
571585

572586
return mlir::success();
573587
}
@@ -771,90 +785,6 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
771785
}
772786
};
773787

774-
class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
775-
public:
776-
using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
777-
778-
mlir::LogicalResult
779-
matchAndRewrite(mlir::cir::IfOp ifOp, OpAdaptor adaptor,
780-
mlir::ConversionPatternRewriter &rewriter) const override {
781-
mlir::OpBuilder::InsertionGuard guard(rewriter);
782-
auto loc = ifOp.getLoc();
783-
auto emptyElse = ifOp.getElseRegion().empty();
784-
785-
auto *currentBlock = rewriter.getInsertionBlock();
786-
auto *remainingOpsBlock =
787-
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
788-
mlir::Block *continueBlock;
789-
if (ifOp->getResults().size() == 0)
790-
continueBlock = remainingOpsBlock;
791-
else
792-
llvm_unreachable("NYI");
793-
794-
// Inline then region
795-
auto *thenBeforeBody = &ifOp.getThenRegion().front();
796-
auto *thenAfterBody = &ifOp.getThenRegion().back();
797-
rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);
798-
799-
rewriter.setInsertionPointToEnd(thenAfterBody);
800-
if (auto thenYieldOp =
801-
dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator())) {
802-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
803-
thenYieldOp, thenYieldOp.getArgs(), continueBlock);
804-
}
805-
806-
rewriter.setInsertionPointToEnd(continueBlock);
807-
808-
// Has else region: inline it.
809-
mlir::Block *elseBeforeBody = nullptr;
810-
mlir::Block *elseAfterBody = nullptr;
811-
if (!emptyElse) {
812-
elseBeforeBody = &ifOp.getElseRegion().front();
813-
elseAfterBody = &ifOp.getElseRegion().back();
814-
rewriter.inlineRegionBefore(ifOp.getElseRegion(), thenAfterBody);
815-
} else {
816-
elseBeforeBody = elseAfterBody = continueBlock;
817-
}
818-
819-
rewriter.setInsertionPointToEnd(currentBlock);
820-
821-
// FIXME: CIR always lowers !cir.bool to i8 type.
822-
// In this reason CIR CodeGen often emits the redundant zext + trunc
823-
// sequence that prevents lowering of llvm.expect in
824-
// LowerExpectIntrinsicPass.
825-
// We should fix that in a more appropriate way. But as a temporary solution
826-
// just avoid the redundant casts here.
827-
mlir::Value condition;
828-
auto zext =
829-
dyn_cast<mlir::LLVM::ZExtOp>(adaptor.getCondition().getDefiningOp());
830-
if (zext && zext->getOperand(0).getType() == rewriter.getI1Type()) {
831-
condition = zext->getOperand(0);
832-
if (zext->use_empty())
833-
rewriter.eraseOp(zext);
834-
} else {
835-
auto trunc = rewriter.create<mlir::LLVM::TruncOp>(
836-
loc, rewriter.getI1Type(), adaptor.getCondition());
837-
condition = trunc.getRes();
838-
}
839-
840-
rewriter.create<mlir::LLVM::CondBrOp>(loc, condition, thenBeforeBody,
841-
elseBeforeBody);
842-
843-
if (!emptyElse) {
844-
rewriter.setInsertionPointToEnd(elseAfterBody);
845-
if (auto elseYieldOp =
846-
dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator())) {
847-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
848-
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
849-
}
850-
}
851-
852-
rewriter.replaceOp(ifOp, continueBlock->getArguments());
853-
854-
return mlir::success();
855-
}
856-
};
857-
858788
class CIRScopeOpLowering
859789
: public mlir::OpConversionPattern<mlir::cir::ScopeOp> {
860790
public:
@@ -937,9 +867,7 @@ struct ConvertCIRToLLVMPass
937867
}
938868
void runOnOperation() final;
939869

940-
virtual StringRef getArgument() const override {
941-
return "cir-to-llvm-internal";
942-
}
870+
virtual StringRef getArgument() const override { return "cir-flat-to-llvm"; }
943871
};
944872

945873
class CIRCallLowering : public mlir::OpConversionPattern<mlir::cir::CallOp> {
@@ -3081,7 +3009,7 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
30813009
CIRLoopOpInterfaceLowering, CIRBrCondOpLowering, CIRPtrStrideOpLowering,
30823010
CIRCallLowering, CIRUnaryOpLowering, CIRBinOpLowering, CIRShiftOpLowering,
30833011
CIRLoadLowering, CIRConstantLowering, CIRStoreLowering, CIRAllocaLowering,
3084-
CIRFuncLowering, CIRScopeOpLowering, CIRCastOpLowering, CIRIfLowering,
3012+
CIRFuncLowering, CIRScopeOpLowering, CIRCastOpLowering,
30853013
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRVAStartLowering,
30863014
CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering, CIRBrOpLowering,
30873015
CIRTernaryOpLowering, CIRGetMemberOpLowering, CIRSwitchOpLowering,
@@ -3241,6 +3169,64 @@ static void buildCtorDtorList(
32413169
builder.create<mlir::LLVM::ReturnOp>(loc, result);
32423170
}
32433171

3172+
// The unreachable code is not lowered by applyPartialConversion function
3173+
// since it traverses blocks in the dominance order. At the same time we
3174+
// do need to lower such code - otherwise verification errors occur.
3175+
// For instance, the next CIR code:
3176+
//
3177+
// cir.func @foo(%arg0: !s32i) -> !s32i {
3178+
// %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool
3179+
// cir.if %4 {
3180+
// %5 = cir.const(#cir.int<1> : !s32i) : !s32i
3181+
// cir.return %5 : !s32i
3182+
// } else {
3183+
// %5 = cir.const(#cir.int<0> : !s32i) : !s32i
3184+
// cir.return %5 : !s32i
3185+
// }
3186+
// cir.return %arg0 : !s32i
3187+
// }
3188+
//
3189+
// contains an unreachable return operation (the last one). After the flattening
3190+
// pass it will be placed into the unreachable block. And the possible error
3191+
// after the lowering pass is: error: 'cir.return' op expects parent op to be
3192+
// one of 'cir.func, cir.scope, cir.if ... The reason that this operation was
3193+
// not lowered and the new parent is lllvm.func.
3194+
//
3195+
// In the future we may want to get rid of this function and use DCE pass or
3196+
// something similar. But now we need to guarantee the absence of the dialect
3197+
// verification errors.
3198+
void collect_unreachable(mlir::Operation *parent,
3199+
llvm::SmallVector<mlir::Operation *> &ops) {
3200+
3201+
llvm::SmallVector<mlir::Block *> unreachable_blocks;
3202+
parent->walk([&](mlir::Block *blk) { // check
3203+
if (blk->hasNoPredecessors() && !blk->isEntryBlock())
3204+
unreachable_blocks.push_back(blk);
3205+
});
3206+
3207+
std::set<mlir::Block *> visited;
3208+
for (auto *root : unreachable_blocks) {
3209+
// We create a work list for each unreachable block.
3210+
// Thus we traverse operations in some order.
3211+
std::deque<mlir::Block *> workList;
3212+
workList.push_back(root);
3213+
3214+
while (!workList.empty()) {
3215+
auto *blk = workList.back();
3216+
workList.pop_back();
3217+
if (visited.count(blk))
3218+
continue;
3219+
visited.emplace(blk);
3220+
3221+
for (auto &op : *blk)
3222+
ops.push_back(&op);
3223+
3224+
for (auto it = blk->succ_begin(); it != blk->succ_end(); ++it)
3225+
workList.push_back(*it);
3226+
}
3227+
}
3228+
}
3229+
32443230
void ConvertCIRToLLVMPass::runOnOperation() {
32453231
auto module = getOperation();
32463232
mlir::DataLayout dataLayout(module);
@@ -3280,7 +3266,11 @@ void ConvertCIRToLLVMPass::runOnOperation() {
32803266
getOperation()->removeAttr("cir.sob");
32813267
getOperation()->removeAttr("cir.lang");
32823268

3283-
if (failed(applyPartialConversion(module, target, std::move(patterns))))
3269+
llvm::SmallVector<mlir::Operation *> ops;
3270+
ops.push_back(module);
3271+
collect_unreachable(module, ops);
3272+
3273+
if (failed(applyPartialConversion(ops, target, std::move(patterns))))
32843274
signalPassFailure();
32853275

32863276
// Emit the llvm.global_ctors array.

clang/test/CIR/CodeGen/if.cir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
cir.func @foo(%arg0: !s32i) -> !s32i {
7+
%4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool
8+
cir.if %4 {
9+
%5 = cir.const(#cir.int<1> : !s32i) : !s32i
10+
cir.return %5 : !s32i
11+
} else {
12+
%5 = cir.const(#cir.int<0> : !s32i) : !s32i
13+
cir.return %5 : !s32i
14+
}
15+
cir.return %arg0 : !s32i
16+
}
17+
// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i {
18+
// CHECK-NEXT: %0 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool
19+
// CHECK-NEXT: cir.brcond %0 ^bb2, ^bb1
20+
// CHECK-NEXT: ^bb1: // pred: ^bb0
21+
// CHECK-NEXT: %1 = cir.const(#cir.int<0> : !s32i) : !s32i
22+
// CHECK-NEXT: cir.return %1 : !s32i
23+
// CHECK-NEXT: ^bb2: // pred: ^bb0
24+
// CHECK-NEXT: %2 = cir.const(#cir.int<1> : !s32i) : !s32i
25+
// CHECK-NEXT: cir.return %2 : !s32i
26+
// CHECK-NEXT: ^bb3: // no predecessors
27+
// CHECK-NEXT: cir.return %arg0 : !s32i
28+
// CHECK-NEXT: }
29+
30+
cir.func @onlyIf(%arg0: !s32i) -> !s32i {
31+
%4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool
32+
cir.if %4 {
33+
%5 = cir.const(#cir.int<1> : !s32i) : !s32i
34+
cir.return %5 : !s32i
35+
}
36+
cir.return %arg0 : !s32i
37+
}
38+
// CHECK: cir.func @onlyIf(%arg0: !s32i) -> !s32i {
39+
// CHECK-NEXT: %0 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool
40+
// CHECK-NEXT: cir.brcond %0 ^bb1, ^bb2
41+
// CHECK-NEXT: ^bb1: // pred: ^bb0
42+
// CHECK-NEXT: %1 = cir.const(#cir.int<1> : !s32i) : !s32i
43+
// CHECK-NEXT: cir.return %1 : !s32i
44+
// CHECK-NEXT: ^bb2: // pred: ^bb0
45+
// CHECK-NEXT: cir.return %arg0 : !s32i
46+
// CHECK-NEXT: }
47+
48+
}

clang/test/CIR/mlirprint.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ int foo(void) {
2424
// CIRFLAT: IR Dump After FlattenCFG (cir-flatten-cfg)
2525
// CIRFLAT: IR Dump After DropAST (cir-drop-ast)
2626
// CIRFLAT: cir.func @foo() -> !s32i
27-
// LLVM: IR Dump After cir::direct::ConvertCIRToLLVMPass (cir-to-llvm-internal)
27+
// LLVM: IR Dump After cir::direct::ConvertCIRToLLVMPass (cir-flat-to-llvm)
2828
// LLVM: llvm.func @foo() -> i32
2929
// LLVM: IR Dump After
3030
// LLVM: define i32 @foo()

clang/tools/cir-opt/cir-opt.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ int main(int argc, char **argv) {
5151
cir::direct::populateCIRToLLVMPasses(pm);
5252
});
5353

54+
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
55+
return mlir::createFlattenCFGPass();
56+
});
57+
5458
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
5559
return mlir::createReconcileUnrealizedCastsPass();
5660
});

0 commit comments

Comments
 (0)