Skip to content

Commit 4ea5257

Browse files
committed
[Flang][OpenMP] PFT-based detection of target SPMD
This patch improves the fix in llvm#125 to detect target SPMD kernels during Flang lowering to MLIR. It transitions from an MLIR-based check to a PFT-based check, which is a more resilient alternative because the MLIR representation is still in the process of being built at the point where the check is performed.
1 parent 3d730dc commit 4ea5257

File tree

2 files changed

+346
-41
lines changed

2 files changed

+346
-41
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 155 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,68 @@ using namespace Fortran::lower::omp;
4646
// Code generation helper functions
4747
//===----------------------------------------------------------------------===//
4848

49-
static bool evalHasSiblings(lower::pft::Evaluation &eval) {
49+
/// Get the directive enumeration value corresponding to the given OpenMP
50+
/// construct PFT node.
51+
llvm::omp::Directive
52+
extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
53+
return common::visit(
54+
common::visitors{
55+
[](const parser::OpenMPAllocatorsConstruct &c) {
56+
return llvm::omp::OMPD_allocators;
57+
},
58+
[](const parser::OpenMPAtomicConstruct &c) {
59+
return llvm::omp::OMPD_atomic;
60+
},
61+
[](const parser::OpenMPBlockConstruct &c) {
62+
return std::get<parser::OmpBlockDirective>(
63+
std::get<parser::OmpBeginBlockDirective>(c.t).t)
64+
.v;
65+
},
66+
[](const parser::OpenMPCriticalConstruct &c) {
67+
return llvm::omp::OMPD_critical;
68+
},
69+
[](const parser::OpenMPDeclarativeAllocate &c) {
70+
return llvm::omp::OMPD_allocate;
71+
},
72+
[](const parser::OpenMPExecutableAllocate &c) {
73+
return llvm::omp::OMPD_allocate;
74+
},
75+
[](const parser::OpenMPLoopConstruct &c) {
76+
return std::get<parser::OmpLoopDirective>(
77+
std::get<parser::OmpBeginLoopDirective>(c.t).t)
78+
.v;
79+
},
80+
[](const parser::OpenMPSectionConstruct &c) {
81+
return llvm::omp::OMPD_section;
82+
},
83+
[](const parser::OpenMPSectionsConstruct &c) {
84+
return std::get<parser::OmpSectionsDirective>(
85+
std::get<parser::OmpBeginSectionsDirective>(c.t).t)
86+
.v;
87+
},
88+
[](const parser::OpenMPStandaloneConstruct &c) {
89+
return common::visit(
90+
common::visitors{
91+
[](const parser::OpenMPSimpleStandaloneConstruct &c) {
92+
return std::get<parser::OmpSimpleStandaloneDirective>(c.t)
93+
.v;
94+
},
95+
[](const parser::OpenMPFlushConstruct &c) {
96+
return llvm::omp::OMPD_flush;
97+
},
98+
[](const parser::OpenMPCancelConstruct &c) {
99+
return llvm::omp::OMPD_cancel;
100+
},
101+
[](const parser::OpenMPCancellationPointConstruct &c) {
102+
return llvm::omp::OMPD_cancellation_point;
103+
}},
104+
c.u);
105+
}},
106+
ompConstruct.u);
107+
}
108+
109+
/// Check whether the parent of the given evaluation contains other evaluations.
110+
static bool evalHasSiblings(const lower::pft::Evaluation &eval) {
50111
auto checkSiblings = [&eval](const lower::pft::EvaluationList &siblings) {
51112
for (auto &sibling : siblings)
52113
if (&sibling != &eval && !sibling.isEndStmt())
@@ -67,6 +128,80 @@ static bool evalHasSiblings(lower::pft::Evaluation &eval) {
67128
}});
68129
}
69130

131+
/// Check whether a given evaluation points to an OpenMP loop construct that
132+
/// represents a target SPMD kernel. For this to be true, it must be a `target
133+
/// teams distribute parallel do [simd]` or equivalent construct.
134+
///
135+
/// Currently, this is limited to cases where all relevant OpenMP constructs are
136+
/// either combined or directly nested within the same function. Also, the
137+
/// composite `distribute parallel do` is not identified if split into two
138+
/// explicit nested loops (a `distribute` loop and a `parallel do` loop).
139+
static bool isTargetSPMDLoop(const lower::pft::Evaluation &eval) {
140+
using namespace llvm::omp;
141+
142+
const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
143+
if (!ompEval)
144+
return false;
145+
146+
switch (extractOmpDirective(*ompEval)) {
147+
case OMPD_distribute_parallel_do:
148+
case OMPD_distribute_parallel_do_simd: {
149+
// It will return true only if one of these are true:
150+
// - It has a 'target teams' parent and no siblings.
151+
// - It has a 'teams' parent and no siblings, and the 'teams' has a
152+
// 'target' parent and no siblings.
153+
if (evalHasSiblings(eval))
154+
return false;
155+
156+
const auto *parentEval = eval.parent.getIf<lower::pft::Evaluation>();
157+
if (!parentEval)
158+
return false;
159+
160+
const auto *parentOmpEval = parentEval->getIf<parser::OpenMPConstruct>();
161+
if (!parentOmpEval)
162+
return false;
163+
164+
auto parentDir = extractOmpDirective(*parentOmpEval);
165+
if (parentDir == OMPD_target_teams)
166+
return true;
167+
168+
if (parentDir != OMPD_teams)
169+
return false;
170+
171+
if (evalHasSiblings(*parentEval))
172+
return false;
173+
174+
const auto *parentOfParentEval =
175+
parentEval->parent.getIf<lower::pft::Evaluation>();
176+
if (!parentEval)
177+
return false;
178+
179+
const auto *parentOfParentOmpEval =
180+
parentOfParentEval->getIf<parser::OpenMPConstruct>();
181+
return parentOfParentOmpEval &&
182+
extractOmpDirective(*parentOfParentOmpEval) == OMPD_target;
183+
}
184+
case OMPD_teams_distribute_parallel_do:
185+
case OMPD_teams_distribute_parallel_do_simd: {
186+
// Check there's a 'target' parent and no siblings.
187+
if (evalHasSiblings(eval))
188+
return false;
189+
190+
const auto *parentEval = eval.parent.getIf<lower::pft::Evaluation>();
191+
if (!parentEval)
192+
return false;
193+
194+
const auto *parentOmpEval = parentEval->getIf<parser::OpenMPConstruct>();
195+
return parentOmpEval && extractOmpDirective(*parentOmpEval) == OMPD_target;
196+
}
197+
case OMPD_target_teams_distribute_parallel_do:
198+
case OMPD_target_teams_distribute_parallel_do_simd:
199+
return true;
200+
default:
201+
return false;
202+
}
203+
}
204+
70205
static mlir::omp::TargetOp findParentTargetOp(mlir::OpBuilder &builder) {
71206
mlir::Operation *parentOp = builder.getBlock()->getParentOp();
72207
if (!parentOp)
@@ -113,8 +248,9 @@ static void genNestedEvaluations(lower::AbstractConverter &converter,
113248
converter.genEval(e);
114249
}
115250

116-
static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval,
117-
mlir::omp::TargetOp targetOp) {
251+
static bool
252+
mustEvalTeamsThreadsOutsideTarget(const lower::pft::Evaluation &eval,
253+
mlir::omp::TargetOp targetOp) {
118254
if (!targetOp)
119255
return false;
120256

@@ -123,25 +259,8 @@ static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval,
123259
if (offloadModOp.getIsTargetDevice())
124260
return false;
125261

126-
auto dir = Fortran::common::visit(
127-
common::visitors{
128-
[&](const parser::OpenMPBlockConstruct &c) {
129-
return std::get<parser::OmpBlockDirective>(
130-
std::get<parser::OmpBeginBlockDirective>(c.t).t)
131-
.v;
132-
},
133-
[&](const parser::OpenMPLoopConstruct &c) {
134-
return std::get<parser::OmpLoopDirective>(
135-
std::get<parser::OmpBeginLoopDirective>(c.t).t)
136-
.v;
137-
},
138-
[&](const auto &) {
139-
llvm_unreachable("Unexpected OpenMP construct");
140-
return llvm::omp::OMPD_unknown;
141-
},
142-
},
143-
eval.get<parser::OpenMPConstruct>().u);
144-
262+
llvm::omp::Directive dir =
263+
extractOmpDirective(eval.get<parser::OpenMPConstruct>());
145264
return llvm::omp::allTargetSet.test(dir) || !evalHasSiblings(eval);
146265
}
147266

@@ -1722,25 +1841,20 @@ genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
17221841
firOpBuilder.getModule().getOperation());
17231842
auto targetOp = loopNestOp->getParentOfType<mlir::omp::TargetOp>();
17241843

1725-
if (offloadMod && targetOp && !offloadMod.getIsTargetDevice()) {
1726-
if (targetOp.isTargetSPMDLoop()) {
1727-
// Lower loop bounds and step, and process collapsing again, putting
1728-
// lowered values outside of omp.target this time. This enables
1729-
// calculating and accessing the trip count in the host, which is needed
1730-
// when lowering to LLVM IR via the OMPIRBuilder.
1731-
HostClausesInsertionGuard guard(firOpBuilder);
1732-
mlir::omp::LoopRelatedOps loopRelatedOps;
1733-
llvm::SmallVector<const semantics::Symbol *> iv;
1734-
ClauseProcessor cp(converter, semaCtx, item->clauses);
1735-
cp.processCollapse(loc, eval, loopRelatedOps, iv);
1736-
targetOp.getTripCountMutable().assign(
1737-
calculateTripCount(converter.getFirOpBuilder(), loc, loopRelatedOps));
1738-
} else if (targetOp.getTripCountMutable().size()) {
1739-
// The MLIR target operation was updated during PFT lowering,
1740-
// and it is no longer an SPMD kernel. Erase the trip count because
1741-
// as it is now invalid.
1742-
targetOp.getTripCountMutable().erase(0);
1743-
}
1844+
if (offloadMod && !offloadMod.getIsTargetDevice() && isTargetSPMDLoop(eval)) {
1845+
assert(targetOp && "must have omp.target parent");
1846+
1847+
// Lower loop bounds and step, and process collapsing again, putting lowered
1848+
// values outside of omp.target this time. This enables calculating and
1849+
// accessing the trip count in the host, which is needed when lowering to
1850+
// LLVM IR via the OMPIRBuilder.
1851+
HostClausesInsertionGuard guard(firOpBuilder);
1852+
mlir::omp::LoopRelatedOps loopRelatedOps;
1853+
llvm::SmallVector<const semantics::Symbol *> iv;
1854+
ClauseProcessor cp(converter, semaCtx, item->clauses);
1855+
cp.processCollapse(loc, eval, loopRelatedOps, iv);
1856+
targetOp.getTripCountMutable().assign(
1857+
calculateTripCount(converter.getFirOpBuilder(), loc, loopRelatedOps));
17441858
}
17451859
return loopNestOp;
17461860
}

0 commit comments

Comments
 (0)