Skip to content

Commit 22778d2

Browse files
sitio-coutolanza
authored andcommitted
[CIR] Yield boolean value in cir.loop condition region
Before this patch, the loop operation condition block yielded either empty or continue. This was replaced by a yield of a boolean value. This change simplifies both codegen and lowering, while also being semantically closer to the C language. It also refactors loop op codegen tests to validate only the lowering related to the cir.loop operation. Fixes llvm#161
1 parent a2c7584 commit 22778d2

File tree

12 files changed

+235
-357
lines changed

12 files changed

+235
-357
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,8 +1114,7 @@ def LoopOp : CIR_Op<"loop",
11141114
let description = [{
11151115
`cir.loop` represents C/C++ loop forms. It defines 3 blocks:
11161116
- `cond`: region can contain multiple blocks, terminated by regular
1117-
`cir.yield` when control should yield back to the parent, and
1118-
`cir.yield continue` when execution continues to another region.
1117+
`cir.yield %x` where `%x` is the boolean value to be evaluated.
11191118
The region destination depends on the loop form specified.
11201119
- `step`: region with one block, containing code to compute the
11211120
loop step, must be terminated with `cir.yield`.
@@ -1130,7 +1129,8 @@ def LoopOp : CIR_Op<"loop",
11301129
// i = i + 1;
11311130
// }
11321131
cir.loop while(cond : {
1133-
cir.yield continue
1132+
%2 = cir.const(#cir.bool<true>) : !cir.bool
1133+
cir.yield %2 : !cir.bool
11341134
}, step : {
11351135
cir.yield
11361136
}) {

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -652,26 +652,6 @@ mlir::LogicalResult CIRGenFunction::buildDefaultStmt(const DefaultStmt &S,
652652
return res;
653653
}
654654

655-
static mlir::LogicalResult buildLoopCondYield(mlir::OpBuilder &builder,
656-
mlir::Location loc,
657-
mlir::Value cond) {
658-
mlir::Block *trueBB = nullptr, *falseBB = nullptr;
659-
{
660-
mlir::OpBuilder::InsertionGuard guard(builder);
661-
trueBB = builder.createBlock(builder.getBlock()->getParent());
662-
builder.create<mlir::cir::YieldOp>(loc, YieldOpKind::Continue);
663-
}
664-
{
665-
mlir::OpBuilder::InsertionGuard guard(builder);
666-
falseBB = builder.createBlock(builder.getBlock()->getParent());
667-
builder.create<mlir::cir::YieldOp>(loc);
668-
}
669-
670-
assert((trueBB && falseBB) && "expected both blocks to exist");
671-
builder.create<mlir::cir::BrCondOp>(loc, cond, trueBB, falseBB);
672-
return mlir::success();
673-
}
674-
675655
mlir::LogicalResult
676656
CIRGenFunction::buildCXXForRangeStmt(const CXXForRangeStmt &S,
677657
ArrayRef<const Attr *> ForAttrs) {
@@ -705,8 +685,7 @@ CIRGenFunction::buildCXXForRangeStmt(const CXXForRangeStmt &S,
705685
assert(!UnimplementedFeature::createProfileWeightsForLoop());
706686
assert(!UnimplementedFeature::emitCondLikelihoodViaExpectIntrinsic());
707687
mlir::Value condVal = evaluateExprAsBool(S.getCond());
708-
if (buildLoopCondYield(b, loc, condVal).failed())
709-
loopRes = mlir::failure();
688+
builder.create<mlir::cir::YieldOp>(loc, condVal);
710689
},
711690
/*bodyBuilder=*/
712691
[&](mlir::OpBuilder &b, mlir::Location loc) {
@@ -793,8 +772,7 @@ mlir::LogicalResult CIRGenFunction::buildForStmt(const ForStmt &S) {
793772
loc, boolTy,
794773
mlir::cir::BoolAttr::get(b.getContext(), boolTy, true));
795774
}
796-
if (buildLoopCondYield(b, loc, condVal).failed())
797-
loopRes = mlir::failure();
775+
builder.create<mlir::cir::YieldOp>(loc, condVal);
798776
},
799777
/*bodyBuilder=*/
800778
[&](mlir::OpBuilder &b, mlir::Location loc) {
@@ -862,8 +840,7 @@ mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) {
862840
// expression compares unequal to 0. The condition must be a
863841
// scalar type.
864842
mlir::Value condVal = evaluateExprAsBool(S.getCond());
865-
if (buildLoopCondYield(b, loc, condVal).failed())
866-
loopRes = mlir::failure();
843+
builder.create<mlir::cir::YieldOp>(loc, condVal);
867844
},
868845
/*bodyBuilder=*/
869846
[&](mlir::OpBuilder &b, mlir::Location loc) {
@@ -927,8 +904,7 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
927904
// expression compares unequal to 0. The condition must be a
928905
// scalar type.
929906
condVal = evaluateExprAsBool(S.getCond());
930-
if (buildLoopCondYield(b, loc, condVal).failed())
931-
loopRes = mlir::failure();
907+
builder.create<mlir::cir::YieldOp>(loc, condVal);
932908
},
933909
/*bodyBuilder=*/
934910
[&](mlir::OpBuilder &b, mlir::Location loc) {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
1515
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
1616
#include "clang/CIR/Dialect/IR/CIRTypes.h"
17+
#include "llvm/ADT/SmallVector.h"
1718

1819
#include "mlir/Dialect/Func/IR/FuncOps.h"
1920
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -1098,12 +1099,9 @@ void LoopOp::build(OpBuilder &builder, OperationState &result,
10981099
/// operand is not a constant.
10991100
void LoopOp::getSuccessorRegions(mlir::RegionBranchPoint point,
11001101
SmallVectorImpl<RegionSuccessor> &regions) {
1101-
// If any index all the underlying regions branch back to the parent
1102-
// operation.
1103-
if (!point.isParent()) {
1104-
regions.push_back(RegionSuccessor());
1102+
// If any index, do nothing.
1103+
if (!point.isParent())
11051104
return;
1106-
}
11071105

11081106
// FIXME: we want to look at cond region for getting more accurate results
11091107
// if the other regions will get a chance to execute.
@@ -1115,26 +1113,29 @@ void LoopOp::getSuccessorRegions(mlir::RegionBranchPoint point,
11151113
llvm::SmallVector<Region *> LoopOp::getLoopRegions() { return {&getBody()}; }
11161114

11171115
LogicalResult LoopOp::verify() {
1118-
// Cond regions should only terminate with plain 'cir.yield' or
1119-
// 'cir.yield continue'.
1120-
auto terminateError = [&]() {
1121-
return emitOpError() << "cond region must be terminated with "
1122-
"'cir.yield' or 'cir.yield continue'";
1123-
};
11241116

1125-
auto &blocks = getCond().getBlocks();
1126-
for (Block &block : blocks) {
1127-
if (block.empty())
1128-
continue;
1129-
auto &op = block.back();
1130-
if (isa<BrCondOp>(op))
1131-
continue;
1132-
if (!isa<YieldOp>(op))
1133-
terminateError();
1134-
auto y = cast<YieldOp>(op);
1135-
if (!(y.isPlain() || y.isContinue()))
1136-
terminateError();
1137-
}
1117+
if (getCond().empty() || getStep().empty() || getBody().empty())
1118+
return emitOpError("regions must not be empty");
1119+
1120+
auto condYield = dyn_cast<YieldOp>(getCond().back().getTerminator());
1121+
auto stepYield = dyn_cast<YieldOp>(getStep().back().getTerminator());
1122+
1123+
if (!condYield || !stepYield)
1124+
return emitOpError(
1125+
"cond and step regions must be terminated with 'cir.yield'");
1126+
1127+
if (condYield.getNumOperands() != 1 ||
1128+
!condYield.getOperand(0).getType().isa<cir::BoolType>())
1129+
return emitOpError("cond region must yield a single boolean value");
1130+
1131+
if (stepYield.getNumOperands() != 0)
1132+
return emitOpError("step region should not yield values");
1133+
1134+
// Body may yield or return.
1135+
auto *bodyTerminator = getBody().back().getTerminator();
1136+
1137+
if (isa<YieldOp>(bodyTerminator) && bodyTerminator->getNumOperands() != 0)
1138+
return emitOpError("body region must not yield values");
11381139

11391140
return success();
11401141
}
@@ -1261,8 +1262,8 @@ void GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
12611262

12621263
LogicalResult
12631264
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1264-
// Verify that the result type underlying pointer type matches the type of the
1265-
// referenced cir.global or cir.func op.
1265+
// Verify that the result type underlying pointer type matches the type of
1266+
// the referenced cir.global or cir.func op.
12661267
auto op = symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
12671268
if (!(isa<GlobalOp>(op) || isa<FuncOp>(op)))
12681269
return emitOpError("'")
@@ -1296,8 +1297,8 @@ VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
12961297
return success();
12971298
auto name = *getName();
12981299

1299-
// Verify that the result type underlying pointer type matches the type of the
1300-
// referenced cir.global or cir.func op.
1300+
// Verify that the result type underlying pointer type matches the type of
1301+
// the referenced cir.global or cir.func op.
13011302
auto op = dyn_cast_or_null<GlobalOp>(
13021303
symbolTable.lookupNearestSymbolFrom(*this, getNameAttr()));
13031304
if (!op)
@@ -1555,7 +1556,6 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
15551556
getFunctionTypeAttrName(), getLinkageAttrName(), getBuiltinAttrName(),
15561557
getNoProtoAttrName(), getExtraAttrsAttrName()});
15571558

1558-
15591559
if (auto aliaseeName = getAliasee()) {
15601560
p << " alias(";
15611561
p.printSymbolName(*aliaseeName);
@@ -1785,7 +1785,8 @@ LogicalResult UnaryOp::verify() {
17851785
case cir::UnaryOpKind::Inc:
17861786
LLVM_FALLTHROUGH;
17871787
case cir::UnaryOpKind::Dec: {
1788-
// TODO: Consider looking at the memory interface instead of LoadOp/StoreOp.
1788+
// TODO: Consider looking at the memory interface instead of
1789+
// LoadOp/StoreOp.
17891790
auto loadOp = getInput().getDefiningOp<cir::LoadOp>();
17901791
if (!loadOp)
17911792
return emitOpError() << "requires input to be defined by a memory load";

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -118,27 +118,6 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
118118
using mlir::OpConversionPattern<mlir::cir::LoopOp>::OpConversionPattern;
119119
using LoopKind = mlir::cir::LoopOpKind;
120120

121-
mlir::LogicalResult
122-
fetchCondRegionYields(mlir::Region &condRegion,
123-
mlir::cir::YieldOp &yieldToBody,
124-
mlir::cir::YieldOp &yieldToCont) const {
125-
for (auto &bb : condRegion) {
126-
if (auto yieldOp = dyn_cast<mlir::cir::YieldOp>(bb.getTerminator())) {
127-
if (!yieldOp.getKind().has_value())
128-
yieldToCont = yieldOp;
129-
else if (yieldOp.getKind() == mlir::cir::YieldOpKind::Continue)
130-
yieldToBody = yieldOp;
131-
else
132-
return mlir::failure();
133-
}
134-
}
135-
136-
// Succeed only if both yields are found.
137-
if (!yieldToBody || !yieldToCont)
138-
return mlir::failure();
139-
return mlir::success();
140-
}
141-
142121
mlir::LogicalResult
143122
matchAndRewrite(mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
144123
mlir::ConversionPatternRewriter &rewriter) const override {
@@ -150,9 +129,8 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
150129
// Fetch required info from the condition region.
151130
auto &condRegion = loopOp.getCond();
152131
auto &condFrontBlock = condRegion.front();
153-
mlir::cir::YieldOp yieldToBody, yieldToCont;
154-
if (fetchCondRegionYields(condRegion, yieldToBody, yieldToCont).failed())
155-
return loopOp.emitError("failed to fetch yields in cond region");
132+
auto condYield =
133+
cast<mlir::cir::YieldOp>(condRegion.back().getTerminator());
156134

157135
// Fetch required info from the body region.
158136
auto &bodyRegion = loopOp.getBody();
@@ -165,7 +143,7 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
165143
auto &stepRegion = loopOp.getStep();
166144
auto &stepFrontBlock = stepRegion.front();
167145
auto stepYield =
168-
dyn_cast<mlir::cir::YieldOp>(stepRegion.back().getTerminator());
146+
cast<mlir::cir::YieldOp>(stepRegion.back().getTerminator());
169147

170148
// Move loop op region contents to current CFG.
171149
rewriter.inlineRegionBefore(condRegion, continueBlock);
@@ -178,13 +156,10 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
178156
auto &entry = (kind != LoopKind::DoWhile ? condFrontBlock : bodyFrontBlock);
179157
rewriter.create<mlir::cir::BrOp>(loopOp.getLoc(), &entry);
180158

181-
// Set loop exit point to continue block.
182-
rewriter.setInsertionPoint(yieldToCont);
183-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldToCont, continueBlock);
184-
185-
// Branch from condition to body.
186-
rewriter.setInsertionPoint(yieldToBody);
187-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldToBody, &bodyFrontBlock);
159+
// Branch to body when true and to exit when false.
160+
rewriter.setInsertionPoint(condYield);
161+
rewriter.replaceOpWithNewOp<mlir::cir::BrCondOp>(
162+
condYield, condYield.getOperand(0), &bodyFrontBlock, continueBlock);
188163

189164
// Branch from body to condition or to step on for-loop cases.
190165
rewriter.setInsertionPoint(bodyYield);

clang/test/CIR/CodeGen/loop.cpp

Lines changed: 19 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ void l0() {
88

99
// CHECK: cir.func @_Z2l0v
1010
// CHECK: cir.loop for(cond : {
11-
// CHECK-NEXT: cir.yield continue
11+
// CHECK-NEXT: %0 = cir.const(#true) : !cir.bool
12+
// CHECK-NEXT: cir.yield %0
1213
// CHECK-NEXT: }, step : {
1314
// CHECK-NEXT: cir.yield
1415
// CHECK-NEXT: }) {
@@ -27,11 +28,7 @@ void l1() {
2728
// CHECK-NEXT: %4 = cir.load %2 : cir.ptr <!s32i>, !s32i
2829
// CHECK-NEXT: %5 = cir.const(#cir.int<10> : !s32i) : !s32i
2930
// CHECK-NEXT: %6 = cir.cmp(lt, %4, %5) : !s32i, !cir.bool
30-
// CHECK-NEXT: cir.brcond %6 ^bb1, ^bb2
31-
// CHECK-NEXT: ^bb1:
32-
// CHECK-NEXT: cir.yield continue
33-
// CHECK-NEXT: ^bb2:
34-
// CHECK-NEXT: cir.yield
31+
// CHECK-NEXT: cir.yield %6 : !cir.bool
3532
// CHECK-NEXT: }, step : {
3633
// CHECK-NEXT: %4 = cir.load %2 : cir.ptr <!s32i>, !s32i
3734
// CHECK-NEXT: %5 = cir.const(#cir.int<1> : !s32i) : !s32i
@@ -62,12 +59,8 @@ void l2(bool cond) {
6259
// CHECK: cir.func @_Z2l2b
6360
// CHECK: cir.scope {
6461
// CHECK-NEXT: cir.loop while(cond : {
65-
// CHECK-NEXT: %3 = cir.load %0 : cir.ptr <!cir.bool>, !cir.bool
66-
// CHECK-NEXT: cir.brcond %3 ^bb1, ^bb2
67-
// CHECK-NEXT: ^bb1:
68-
// CHECK-NEXT: cir.yield continue
69-
// CHECK-NEXT: ^bb2:
70-
// CHECK-NEXT: cir.yield
62+
// CHECK-NEXT: %3 = cir.load %0 : cir.ptr <!cir.bool>, !cir.bool
63+
// CHECK-NEXT: cir.yield %3 : !cir.bool
7164
// CHECK-NEXT: }, step : {
7265
// CHECK-NEXT: cir.yield
7366
// CHECK-NEXT: }) {
@@ -80,7 +73,8 @@ void l2(bool cond) {
8073
// CHECK-NEXT: }
8174
// CHECK-NEXT: cir.scope {
8275
// CHECK-NEXT: cir.loop while(cond : {
83-
// CHECK-NEXT: cir.yield continue
76+
// CHECK-NEXT: %3 = cir.const(#true) : !cir.bool
77+
// CHECK-NEXT: cir.yield %3 : !cir.bool
8478
// CHECK-NEXT: }, step : {
8579
// CHECK-NEXT: cir.yield
8680
// CHECK-NEXT: }) {
@@ -93,13 +87,9 @@ void l2(bool cond) {
9387
// CHECK-NEXT: }
9488
// CHECK-NEXT: cir.scope {
9589
// CHECK-NEXT: cir.loop while(cond : {
96-
// CHECK-NEXT: %3 = cir.const(#cir.int<1> : !s32i) : !s32i
97-
// CHECK-NEXT: %4 = cir.cast(int_to_bool, %3 : !s32i), !cir.bool
98-
// CHECK-NEXT: cir.brcond %4 ^bb1, ^bb2
99-
// CHECK-NEXT: ^bb1:
100-
// CHECK-NEXT: cir.yield continue
101-
// CHECK-NEXT: ^bb2:
102-
// CHECK-NEXT: cir.yield
90+
// CHECK-NEXT: %3 = cir.const(#cir.int<1> : !s32i) : !s32i
91+
// CHECK-NEXT: %4 = cir.cast(int_to_bool, %3 : !s32i), !cir.bool
92+
// CHECK-NEXT: cir.yield %4 : !cir.bool
10393
// CHECK-NEXT: }, step : {
10494
// CHECK-NEXT: cir.yield
10595
// CHECK-NEXT: }) {
@@ -128,11 +118,7 @@ void l3(bool cond) {
128118
// CHECK: cir.scope {
129119
// CHECK-NEXT: cir.loop dowhile(cond : {
130120
// CHECK-NEXT: %3 = cir.load %0 : cir.ptr <!cir.bool>, !cir.bool
131-
// CHECK-NEXT: cir.brcond %3 ^bb1, ^bb2
132-
// CHECK-NEXT: ^bb1:
133-
// CHECK-NEXT: cir.yield continue
134-
// CHECK-NEXT: ^bb2:
135-
// CHECK-NEXT: cir.yield
121+
// CHECK-NEXT: cir.yield %3
136122
// CHECK-NEXT: }, step : {
137123
// CHECK-NEXT: cir.yield
138124
// CHECK-NEXT: }) {
@@ -145,7 +131,8 @@ void l3(bool cond) {
145131
// CHECK-NEXT: }
146132
// CHECK-NEXT: cir.scope {
147133
// CHECK-NEXT: cir.loop dowhile(cond : {
148-
// CHECK-NEXT: cir.yield continue
134+
// CHECK-NEXT: %3 = cir.const(#true) : !cir.bool
135+
// CHECK-NEXT: cir.yield %3 : !cir.bool
149136
// CHECK-NEXT: }, step : {
150137
// CHECK-NEXT: cir.yield
151138
// CHECK-NEXT: }) {
@@ -160,11 +147,7 @@ void l3(bool cond) {
160147
// CHECK-NEXT: cir.loop dowhile(cond : {
161148
// CHECK-NEXT: %3 = cir.const(#cir.int<1> : !s32i) : !s32i
162149
// CHECK-NEXT: %4 = cir.cast(int_to_bool, %3 : !s32i), !cir.bool
163-
// CHECK-NEXT: cir.brcond %4 ^bb1, ^bb2
164-
// CHECK-NEXT: ^bb1:
165-
// CHECK-NEXT: cir.yield continue
166-
// CHECK-NEXT: ^bb2:
167-
// CHECK-NEXT: cir.yield
150+
// CHECK-NEXT: cir.yield %4 : !cir.bool
168151
// CHECK-NEXT: }, step : {
169152
// CHECK-NEXT: cir.yield
170153
// CHECK-NEXT: }) {
@@ -188,7 +171,8 @@ void l4() {
188171

189172
// CHECK: cir.func @_Z2l4v
190173
// CHECK: cir.loop while(cond : {
191-
// CHECK-NEXT: cir.yield continue
174+
// CHECK-NEXT: %4 = cir.const(#true) : !cir.bool
175+
// CHECK-NEXT: cir.yield %4 : !cir.bool
192176
// CHECK-NEXT: }, step : {
193177
// CHECK-NEXT: cir.yield
194178
// CHECK-NEXT: }) {
@@ -215,11 +199,7 @@ void l5() {
215199
// CHECK-NEXT: cir.loop dowhile(cond : {
216200
// CHECK-NEXT: %0 = cir.const(#cir.int<0> : !s32i) : !s32i
217201
// CHECK-NEXT: %1 = cir.cast(int_to_bool, %0 : !s32i), !cir.bool
218-
// CHECK-NEXT: cir.brcond %1 ^bb1, ^bb2
219-
// CHECK-NEXT: ^bb1:
220-
// CHECK-NEXT: cir.yield continue
221-
// CHECK-NEXT: ^bb2:
222-
// CHECK-NEXT: cir.yield
202+
// CHECK-NEXT: cir.yield %1 : !cir.bool
223203
// CHECK-NEXT: }, step : {
224204
// CHECK-NEXT: cir.yield
225205
// CHECK-NEXT: }) {
@@ -238,7 +218,8 @@ void l6() {
238218
// CHECK: cir.func @_Z2l6v()
239219
// CHECK-NEXT: cir.scope {
240220
// CHECK-NEXT: cir.loop while(cond : {
241-
// CHECK-NEXT: cir.yield continue
221+
// CHECK-NEXT: %0 = cir.const(#true) : !cir.bool
222+
// CHECK-NEXT: cir.yield %0 : !cir.bool
242223
// CHECK-NEXT: }, step : {
243224
// CHECK-NEXT: cir.yield
244225
// CHECK-NEXT: }) {

0 commit comments

Comments
 (0)