Skip to content

Commit 9120562

Browse files
authored
[Clang][OpenMP] Enable tile/unroll on iterator- and foreach-loops (#91459)
OpenMP loop transformation did not work on a for-loop using an iterator or range-based for-loops. The first reason is that it combined the iterator's type for generated loops with the type of `NumIterations` as generated for any `OMPLoopBasedDirective` which is an integer. Fixed by basing all generated loop variables on `NumIterations`. Second, C++11 range-based for-loops include syntactic sugar that needs to be executed before the loop. This additional code is now added to the construct's Pre-Init lists. Third, C++20 added an initializer statement to range-based for-loops which is also added to the pre-init statement. PreInits used to be a `DeclStmt` which made it difficult to add arbitrary statements from `CXXRangeForStmt`'s syntactic sugar, especially the for-loops init statement which does not need to be a declaration. Change it to be a general `Stmt` that can be a `CompoundStmt` to hold arbitrary Stmts, including DeclStmts. This also avoids the `PointerUnion` workaround used by `checkTransformableLoopNest`. End-to-end tests are added to verify the expected number and order of loop execution and evaluations of expressions (such as iterator dereference). The order and number of evaluations of expressions in canonical loops is explicitly undefined by OpenMP but checked here for clarification and for changes to be noticed.
1 parent 71b1fbd commit 9120562

18 files changed

+2511
-447
lines changed

clang/include/clang/Sema/SemaOpenMP.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,9 +1390,7 @@ class SemaOpenMP : public SemaBase {
13901390
bool checkTransformableLoopNest(
13911391
OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops,
13921392
SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
1393-
Stmt *&Body,
1394-
SmallVectorImpl<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>>
1395-
&OriginalInits);
1393+
Stmt *&Body, SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits);
13961394

13971395
/// Helper to keep information about the current `omp begin/end declare
13981396
/// variant` nesting.

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class OMPTeamsScope final : public OMPLexicalScope {
142142
/// of used expression from loop statement.
143143
class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
144144
void emitPreInitStmt(CodeGenFunction &CGF, const OMPLoopBasedDirective &S) {
145-
const DeclStmt *PreInits;
145+
const Stmt *PreInits;
146146
CodeGenFunction::OMPMapVars PreCondVars;
147147
if (auto *LD = dyn_cast<OMPLoopDirective>(&S)) {
148148
llvm::DenseSet<const VarDecl *> EmittedAsPrivate;
@@ -182,17 +182,34 @@ class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
182182
}
183183
return false;
184184
});
185-
PreInits = cast_or_null<DeclStmt>(LD->getPreInits());
185+
PreInits = LD->getPreInits();
186186
} else if (const auto *Tile = dyn_cast<OMPTileDirective>(&S)) {
187-
PreInits = cast_or_null<DeclStmt>(Tile->getPreInits());
187+
PreInits = Tile->getPreInits();
188188
} else if (const auto *Unroll = dyn_cast<OMPUnrollDirective>(&S)) {
189-
PreInits = cast_or_null<DeclStmt>(Unroll->getPreInits());
189+
PreInits = Unroll->getPreInits();
190190
} else {
191191
llvm_unreachable("Unknown loop-based directive kind.");
192192
}
193193
if (PreInits) {
194-
for (const auto *I : PreInits->decls())
195-
CGF.EmitVarDecl(cast<VarDecl>(*I));
194+
// CompoundStmts and DeclStmts are used as lists of PreInit statements and
195+
// declarations. Since declarations must be visible in the the following
196+
// that they initialize, unpack the ComboundStmt they are nested in.
197+
SmallVector<const Stmt *> PreInitStmts;
198+
if (auto *PreInitCompound = dyn_cast<CompoundStmt>(PreInits))
199+
llvm::append_range(PreInitStmts, PreInitCompound->body());
200+
else
201+
PreInitStmts.push_back(PreInits);
202+
203+
for (const Stmt *S : PreInitStmts) {
204+
// EmitStmt skips any OMPCapturedExprDecls, but needs to be emitted
205+
// here.
206+
if (auto *PreInitDecl = dyn_cast<DeclStmt>(S)) {
207+
for (Decl *I : PreInitDecl->decls())
208+
CGF.EmitVarDecl(cast<VarDecl>(*I));
209+
continue;
210+
}
211+
CGF.EmitStmt(S);
212+
}
196213
}
197214
PreCondVars.restore(CGF);
198215
}

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 140 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -9815,6 +9815,25 @@ static Stmt *buildPreInits(ASTContext &Context,
98159815
return nullptr;
98169816
}
98179817

9818+
/// Append the \p Item or the content of a CompoundStmt to the list \p
9819+
/// TargetList.
9820+
///
9821+
/// A CompoundStmt is used as container in case multiple statements need to be
9822+
/// stored in lieu of using an explicit list. Flattening is necessary because
9823+
/// contained DeclStmts need to be visible after the execution of the list. Used
9824+
/// for OpenMP pre-init declarations/statements.
9825+
static void appendFlattendedStmtList(SmallVectorImpl<Stmt *> &TargetList,
9826+
Stmt *Item) {
9827+
// nullptr represents an empty list.
9828+
if (!Item)
9829+
return;
9830+
9831+
if (auto *CS = dyn_cast<CompoundStmt>(Item))
9832+
llvm::append_range(TargetList, CS->body());
9833+
else
9834+
TargetList.push_back(Item);
9835+
}
9836+
98189837
/// Build preinits statement for the given declarations.
98199838
static Stmt *
98209839
buildPreInits(ASTContext &Context,
@@ -9828,6 +9847,17 @@ buildPreInits(ASTContext &Context,
98289847
return nullptr;
98299848
}
98309849

9850+
/// Build pre-init statement for the given statements.
9851+
static Stmt *buildPreInits(ASTContext &Context, ArrayRef<Stmt *> PreInits) {
9852+
if (PreInits.empty())
9853+
return nullptr;
9854+
9855+
SmallVector<Stmt *> Stmts;
9856+
for (Stmt *S : PreInits)
9857+
appendFlattendedStmtList(Stmts, S);
9858+
return CompoundStmt::Create(Context, PreInits, FPOptionsOverride(), {}, {});
9859+
}
9860+
98319861
/// Build postupdate expression for the given list of postupdates expressions.
98329862
static Expr *buildPostUpdate(Sema &S, ArrayRef<Expr *> PostUpdates) {
98339863
Expr *PostUpdate = nullptr;
@@ -9924,11 +9954,21 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
99249954
Stmt *DependentPreInits = Transform->getPreInits();
99259955
if (!DependentPreInits)
99269956
return;
9927-
for (Decl *C : cast<DeclStmt>(DependentPreInits)->getDeclGroup()) {
9928-
auto *D = cast<VarDecl>(C);
9929-
DeclRefExpr *Ref = buildDeclRefExpr(SemaRef, D, D->getType(),
9930-
Transform->getBeginLoc());
9931-
Captures[Ref] = Ref;
9957+
9958+
// Search for pre-init declared variables that need to be captured
9959+
// to be referenceable inside the directive.
9960+
SmallVector<Stmt *> Constituents;
9961+
appendFlattendedStmtList(Constituents, DependentPreInits);
9962+
for (Stmt *S : Constituents) {
9963+
if (auto *DC = dyn_cast<DeclStmt>(S)) {
9964+
for (Decl *C : DC->decls()) {
9965+
auto *D = cast<VarDecl>(C);
9966+
DeclRefExpr *Ref = buildDeclRefExpr(
9967+
SemaRef, D, D->getType().getNonReferenceType(),
9968+
Transform->getBeginLoc());
9969+
Captures[Ref] = Ref;
9970+
}
9971+
}
99329972
}
99339973
}))
99349974
return 0;
@@ -15059,9 +15099,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective(
1505915099
bool SemaOpenMP::checkTransformableLoopNest(
1506015100
OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops,
1506115101
SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
15062-
Stmt *&Body,
15063-
SmallVectorImpl<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>>
15064-
&OriginalInits) {
15102+
Stmt *&Body, SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits) {
1506515103
OriginalInits.emplace_back();
1506615104
bool Result = OMPLoopBasedDirective::doForAllLoops(
1506715105
AStmt->IgnoreContainers(), /*TryImperfectlyNestedLoops=*/false, NumLoops,
@@ -15095,16 +15133,70 @@ bool SemaOpenMP::checkTransformableLoopNest(
1509515133
DependentPreInits = Dir->getPreInits();
1509615134
else
1509715135
llvm_unreachable("Unhandled loop transformation");
15098-
if (!DependentPreInits)
15099-
return;
15100-
llvm::append_range(OriginalInits.back(),
15101-
cast<DeclStmt>(DependentPreInits)->getDeclGroup());
15136+
15137+
appendFlattendedStmtList(OriginalInits.back(), DependentPreInits);
1510215138
});
1510315139
assert(OriginalInits.back().empty() && "No preinit after innermost loop");
1510415140
OriginalInits.pop_back();
1510515141
return Result;
1510615142
}
1510715143

15144+
/// Add preinit statements that need to be propageted from the selected loop.
15145+
static void addLoopPreInits(ASTContext &Context,
15146+
OMPLoopBasedDirective::HelperExprs &LoopHelper,
15147+
Stmt *LoopStmt, ArrayRef<Stmt *> OriginalInit,
15148+
SmallVectorImpl<Stmt *> &PreInits) {
15149+
15150+
// For range-based for-statements, ensure that their syntactic sugar is
15151+
// executed by adding them as pre-init statements.
15152+
if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt)) {
15153+
Stmt *RangeInit = CXXRangeFor->getInit();
15154+
if (RangeInit)
15155+
PreInits.push_back(RangeInit);
15156+
15157+
DeclStmt *RangeStmt = CXXRangeFor->getRangeStmt();
15158+
PreInits.push_back(new (Context) DeclStmt(RangeStmt->getDeclGroup(),
15159+
RangeStmt->getBeginLoc(),
15160+
RangeStmt->getEndLoc()));
15161+
15162+
DeclStmt *RangeEnd = CXXRangeFor->getEndStmt();
15163+
PreInits.push_back(new (Context) DeclStmt(RangeEnd->getDeclGroup(),
15164+
RangeEnd->getBeginLoc(),
15165+
RangeEnd->getEndLoc()));
15166+
}
15167+
15168+
llvm::append_range(PreInits, OriginalInit);
15169+
15170+
// List of OMPCapturedExprDecl, for __begin, __end, and NumIterations
15171+
if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits)) {
15172+
PreInits.push_back(new (Context) DeclStmt(
15173+
PI->getDeclGroup(), PI->getBeginLoc(), PI->getEndLoc()));
15174+
}
15175+
15176+
// Gather declarations for the data members used as counters.
15177+
for (Expr *CounterRef : LoopHelper.Counters) {
15178+
auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
15179+
if (isa<OMPCapturedExprDecl>(CounterDecl))
15180+
PreInits.push_back(new (Context) DeclStmt(
15181+
DeclGroupRef(CounterDecl), SourceLocation(), SourceLocation()));
15182+
}
15183+
}
15184+
15185+
/// Collect the loop statements (ForStmt or CXXRangeForStmt) of the affected
15186+
/// loop of a construct.
15187+
static void collectLoopStmts(Stmt *AStmt, MutableArrayRef<Stmt *> LoopStmts) {
15188+
size_t NumLoops = LoopStmts.size();
15189+
OMPLoopBasedDirective::doForAllLoops(
15190+
AStmt, /*TryImperfectlyNestedLoops=*/false, NumLoops,
15191+
[LoopStmts](unsigned Cnt, Stmt *CurStmt) {
15192+
assert(!LoopStmts[Cnt] && "Loop statement must not yet be assigned");
15193+
LoopStmts[Cnt] = CurStmt;
15194+
return false;
15195+
});
15196+
assert(!is_contained(LoopStmts, nullptr) &&
15197+
"Expecting a loop statement for each affected loop");
15198+
}
15199+
1510815200
StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1510915201
Stmt *AStmt,
1511015202
SourceLocation StartLoc,
@@ -15126,8 +15218,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1512615218
// Verify and diagnose loop nest.
1512715219
SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
1512815220
Stmt *Body = nullptr;
15129-
SmallVector<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>, 4>
15130-
OriginalInits;
15221+
SmallVector<SmallVector<Stmt *, 0>, 4> OriginalInits;
1513115222
if (!checkTransformableLoopNest(OMPD_tile, AStmt, NumLoops, LoopHelpers, Body,
1513215223
OriginalInits))
1513315224
return StmtError();
@@ -15144,7 +15235,11 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1514415235
"Expecting loop iteration space dimensionality to match number of "
1514515236
"affected loops");
1514615237

15147-
SmallVector<Decl *, 4> PreInits;
15238+
// Collect all affected loop statements.
15239+
SmallVector<Stmt *> LoopStmts(NumLoops, nullptr);
15240+
collectLoopStmts(AStmt, LoopStmts);
15241+
15242+
SmallVector<Stmt *, 4> PreInits;
1514815243
CaptureVars CopyTransformer(SemaRef);
1514915244

1515015245
// Create iteration variables for the generated loops.
@@ -15184,20 +15279,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1518415279
&SemaRef.PP.getIdentifierTable().get(TileCntName));
1518515280
TileIndVars[I] = TileCntDecl;
1518615281
}
15187-
for (auto &P : OriginalInits[I]) {
15188-
if (auto *D = P.dyn_cast<Decl *>())
15189-
PreInits.push_back(D);
15190-
else if (auto *PI = dyn_cast_or_null<DeclStmt>(P.dyn_cast<Stmt *>()))
15191-
PreInits.append(PI->decl_begin(), PI->decl_end());
15192-
}
15193-
if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits))
15194-
PreInits.append(PI->decl_begin(), PI->decl_end());
15195-
// Gather declarations for the data members used as counters.
15196-
for (Expr *CounterRef : LoopHelper.Counters) {
15197-
auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
15198-
if (isa<OMPCapturedExprDecl>(CounterDecl))
15199-
PreInits.push_back(CounterDecl);
15200-
}
15282+
15283+
addLoopPreInits(Context, LoopHelper, LoopStmts[I], OriginalInits[I],
15284+
PreInits);
1520115285
}
1520215286

1520315287
// Once the original iteration values are set, append the innermost body.
@@ -15246,19 +15330,20 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1524615330
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I];
1524715331
Expr *NumIterations = LoopHelper.NumIterations;
1524815332
auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]);
15249-
QualType CntTy = OrigCntVar->getType();
15333+
QualType IVTy = NumIterations->getType();
15334+
Stmt *LoopStmt = LoopStmts[I];
1525015335

1525115336
// Commonly used variables. One of the constraints of an AST is that every
1525215337
// node object must appear at most once, hence we define lamdas that create
1525315338
// a new AST node at every use.
15254-
auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, CntTy,
15339+
auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, IVTy,
1525515340
OrigCntVar]() {
15256-
return buildDeclRefExpr(SemaRef, TileIndVars[I], CntTy,
15341+
return buildDeclRefExpr(SemaRef, TileIndVars[I], IVTy,
1525715342
OrigCntVar->getExprLoc());
1525815343
};
15259-
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy,
15344+
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, IVTy,
1526015345
OrigCntVar]() {
15261-
return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy,
15346+
return buildDeclRefExpr(SemaRef, FloorIndVars[I], IVTy,
1526215347
OrigCntVar->getExprLoc());
1526315348
};
1526415349

@@ -15320,6 +15405,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1532015405
// further into the inner loop.
1532115406
SmallVector<Stmt *, 4> BodyParts;
1532215407
BodyParts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end());
15408+
if (auto *SourceCXXFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
15409+
BodyParts.push_back(SourceCXXFor->getLoopVarStmt());
1532315410
BodyParts.push_back(Inner);
1532415411
Inner = CompoundStmt::Create(Context, BodyParts, FPOptionsOverride(),
1532515412
Inner->getBeginLoc(), Inner->getEndLoc());
@@ -15334,12 +15421,14 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1533415421
auto &LoopHelper = LoopHelpers[I];
1533515422
Expr *NumIterations = LoopHelper.NumIterations;
1533615423
DeclRefExpr *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]);
15337-
QualType CntTy = OrigCntVar->getType();
15424+
QualType IVTy = NumIterations->getType();
1533815425

15339-
// Commonly used variables.
15340-
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy,
15426+
// Commonly used variables. One of the constraints of an AST is that every
15427+
// node object must appear at most once, hence we define lamdas that create
15428+
// a new AST node at every use.
15429+
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, IVTy,
1534115430
OrigCntVar]() {
15342-
return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy,
15431+
return buildDeclRefExpr(SemaRef, FloorIndVars[I], IVTy,
1534315432
OrigCntVar->getExprLoc());
1534415433
};
1534515434

@@ -15405,8 +15494,7 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
1540515494
Stmt *Body = nullptr;
1540615495
SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
1540715496
NumLoops);
15408-
SmallVector<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>, NumLoops + 1>
15409-
OriginalInits;
15497+
SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits;
1541015498
if (!checkTransformableLoopNest(OMPD_unroll, AStmt, NumLoops, LoopHelpers,
1541115499
Body, OriginalInits))
1541215500
return StmtError();
@@ -15418,6 +15506,10 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
1541815506
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
1541915507
NumGeneratedLoops, nullptr, nullptr);
1542015508

15509+
assert(LoopHelpers.size() == NumLoops &&
15510+
"Expecting a single-dimensional loop iteration space");
15511+
assert(OriginalInits.size() == NumLoops &&
15512+
"Expecting a single-dimensional loop iteration space");
1542115513
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
1542215514

1542315515
if (FullClause) {
@@ -15481,24 +15573,13 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
1548115573
// of a canonical loop nest where these PreInits are emitted before the
1548215574
// outermost directive.
1548315575

15576+
// Find the loop statement.
15577+
Stmt *LoopStmt = nullptr;
15578+
collectLoopStmts(AStmt, {LoopStmt});
15579+
1548415580
// Determine the PreInit declarations.
15485-
SmallVector<Decl *, 4> PreInits;
15486-
assert(OriginalInits.size() == 1 &&
15487-
"Expecting a single-dimensional loop iteration space");
15488-
for (auto &P : OriginalInits[0]) {
15489-
if (auto *D = P.dyn_cast<Decl *>())
15490-
PreInits.push_back(D);
15491-
else if (auto *PI = dyn_cast_or_null<DeclStmt>(P.dyn_cast<Stmt *>()))
15492-
PreInits.append(PI->decl_begin(), PI->decl_end());
15493-
}
15494-
if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits))
15495-
PreInits.append(PI->decl_begin(), PI->decl_end());
15496-
// Gather declarations for the data members used as counters.
15497-
for (Expr *CounterRef : LoopHelper.Counters) {
15498-
auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
15499-
if (isa<OMPCapturedExprDecl>(CounterDecl))
15500-
PreInits.push_back(CounterDecl);
15501-
}
15581+
SmallVector<Stmt *, 4> PreInits;
15582+
addLoopPreInits(Context, LoopHelper, LoopStmt, OriginalInits[0], PreInits);
1550215583

1550315584
auto *IterationVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef);
1550415585
QualType IVTy = IterationVarRef->getType();
@@ -15604,6 +15685,8 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
1560415685
// Inner For statement.
1560515686
SmallVector<Stmt *> InnerBodyStmts;
1560615687
InnerBodyStmts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end());
15688+
if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
15689+
InnerBodyStmts.push_back(CXXRangeFor->getLoopVarStmt());
1560715690
InnerBodyStmts.push_back(Body);
1560815691
CompoundStmt *InnerBody =
1560915692
CompoundStmt::Create(getASTContext(), InnerBodyStmts, FPOptionsOverride(),

0 commit comments

Comments
 (0)