@@ -46,7 +46,68 @@ using namespace Fortran::lower::omp;
46
46
// Code generation helper functions
47
47
// ===----------------------------------------------------------------------===//
48
48
49
- static bool evalHasSiblings (lower::pft::Evaluation &eval) {
49
+ // / Get the directive enumeration value corresponding to the given OpenMP
50
+ // / construct PFT node.
51
+ llvm::omp::Directive
52
+ extractOmpDirective (const parser::OpenMPConstruct &ompConstruct) {
53
+ return common::visit (
54
+ common::visitors{
55
+ [](const parser::OpenMPAllocatorsConstruct &c) {
56
+ return llvm::omp::OMPD_allocators;
57
+ },
58
+ [](const parser::OpenMPAtomicConstruct &c) {
59
+ return llvm::omp::OMPD_atomic;
60
+ },
61
+ [](const parser::OpenMPBlockConstruct &c) {
62
+ return std::get<parser::OmpBlockDirective>(
63
+ std::get<parser::OmpBeginBlockDirective>(c.t ).t )
64
+ .v ;
65
+ },
66
+ [](const parser::OpenMPCriticalConstruct &c) {
67
+ return llvm::omp::OMPD_critical;
68
+ },
69
+ [](const parser::OpenMPDeclarativeAllocate &c) {
70
+ return llvm::omp::OMPD_allocate;
71
+ },
72
+ [](const parser::OpenMPExecutableAllocate &c) {
73
+ return llvm::omp::OMPD_allocate;
74
+ },
75
+ [](const parser::OpenMPLoopConstruct &c) {
76
+ return std::get<parser::OmpLoopDirective>(
77
+ std::get<parser::OmpBeginLoopDirective>(c.t ).t )
78
+ .v ;
79
+ },
80
+ [](const parser::OpenMPSectionConstruct &c) {
81
+ return llvm::omp::OMPD_section;
82
+ },
83
+ [](const parser::OpenMPSectionsConstruct &c) {
84
+ return std::get<parser::OmpSectionsDirective>(
85
+ std::get<parser::OmpBeginSectionsDirective>(c.t ).t )
86
+ .v ;
87
+ },
88
+ [](const parser::OpenMPStandaloneConstruct &c) {
89
+ return common::visit (
90
+ common::visitors{
91
+ [](const parser::OpenMPSimpleStandaloneConstruct &c) {
92
+ return std::get<parser::OmpSimpleStandaloneDirective>(c.t )
93
+ .v ;
94
+ },
95
+ [](const parser::OpenMPFlushConstruct &c) {
96
+ return llvm::omp::OMPD_flush;
97
+ },
98
+ [](const parser::OpenMPCancelConstruct &c) {
99
+ return llvm::omp::OMPD_cancel;
100
+ },
101
+ [](const parser::OpenMPCancellationPointConstruct &c) {
102
+ return llvm::omp::OMPD_cancellation_point;
103
+ }},
104
+ c.u );
105
+ }},
106
+ ompConstruct.u );
107
+ }
108
+
109
+ // / Check whether the parent of the given evaluation contains other evaluations.
110
+ static bool evalHasSiblings (const lower::pft::Evaluation &eval) {
50
111
auto checkSiblings = [&eval](const lower::pft::EvaluationList &siblings) {
51
112
for (auto &sibling : siblings)
52
113
if (&sibling != &eval && !sibling.isEndStmt ())
@@ -67,6 +128,80 @@ static bool evalHasSiblings(lower::pft::Evaluation &eval) {
67
128
}});
68
129
}
69
130
131
+ // / Check whether a given evaluation points to an OpenMP loop construct that
132
+ // / represents a target SPMD kernel. For this to be true, it must be a `target
133
+ // / teams distribute parallel do [simd]` or equivalent construct.
134
+ // /
135
+ // / Currently, this is limited to cases where all relevant OpenMP constructs are
136
+ // / either combined or directly nested within the same function. Also, the
137
+ // / composite `distribute parallel do` is not identified if split into two
138
+ // / explicit nested loops (a `distribute` loop and a `parallel do` loop).
139
+ static bool isTargetSPMDLoop (const lower::pft::Evaluation &eval) {
140
+ using namespace llvm ::omp;
141
+
142
+ const auto *ompEval = eval.getIf <parser::OpenMPConstruct>();
143
+ if (!ompEval)
144
+ return false ;
145
+
146
+ switch (extractOmpDirective (*ompEval)) {
147
+ case OMPD_distribute_parallel_do:
148
+ case OMPD_distribute_parallel_do_simd: {
149
+ // It will return true only if one of these are true:
150
+ // - It has a 'target teams' parent and no siblings.
151
+ // - It has a 'teams' parent and no siblings, and the 'teams' has a
152
+ // 'target' parent and no siblings.
153
+ if (evalHasSiblings (eval))
154
+ return false ;
155
+
156
+ const auto *parentEval = eval.parent .getIf <lower::pft::Evaluation>();
157
+ if (!parentEval)
158
+ return false ;
159
+
160
+ const auto *parentOmpEval = parentEval->getIf <parser::OpenMPConstruct>();
161
+ if (!parentOmpEval)
162
+ return false ;
163
+
164
+ auto parentDir = extractOmpDirective (*parentOmpEval);
165
+ if (parentDir == OMPD_target_teams)
166
+ return true ;
167
+
168
+ if (parentDir != OMPD_teams)
169
+ return false ;
170
+
171
+ if (evalHasSiblings (*parentEval))
172
+ return false ;
173
+
174
+ const auto *parentOfParentEval =
175
+ parentEval->parent .getIf <lower::pft::Evaluation>();
176
+ if (!parentEval)
177
+ return false ;
178
+
179
+ const auto *parentOfParentOmpEval =
180
+ parentOfParentEval->getIf <parser::OpenMPConstruct>();
181
+ return parentOfParentOmpEval &&
182
+ extractOmpDirective (*parentOfParentOmpEval) == OMPD_target;
183
+ }
184
+ case OMPD_teams_distribute_parallel_do:
185
+ case OMPD_teams_distribute_parallel_do_simd: {
186
+ // Check there's a 'target' parent and no siblings.
187
+ if (evalHasSiblings (eval))
188
+ return false ;
189
+
190
+ const auto *parentEval = eval.parent .getIf <lower::pft::Evaluation>();
191
+ if (!parentEval)
192
+ return false ;
193
+
194
+ const auto *parentOmpEval = parentEval->getIf <parser::OpenMPConstruct>();
195
+ return parentOmpEval && extractOmpDirective (*parentOmpEval) == OMPD_target;
196
+ }
197
+ case OMPD_target_teams_distribute_parallel_do:
198
+ case OMPD_target_teams_distribute_parallel_do_simd:
199
+ return true ;
200
+ default :
201
+ return false ;
202
+ }
203
+ }
204
+
70
205
static mlir::omp::TargetOp findParentTargetOp (mlir::OpBuilder &builder) {
71
206
mlir::Operation *parentOp = builder.getBlock ()->getParentOp ();
72
207
if (!parentOp)
@@ -113,8 +248,9 @@ static void genNestedEvaluations(lower::AbstractConverter &converter,
113
248
converter.genEval (e);
114
249
}
115
250
116
- static bool mustEvalTeamsThreadsOutsideTarget (lower::pft::Evaluation &eval,
117
- mlir::omp::TargetOp targetOp) {
251
+ static bool
252
+ mustEvalTeamsThreadsOutsideTarget (const lower::pft::Evaluation &eval,
253
+ mlir::omp::TargetOp targetOp) {
118
254
if (!targetOp)
119
255
return false ;
120
256
@@ -123,25 +259,8 @@ static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval,
123
259
if (offloadModOp.getIsTargetDevice ())
124
260
return false ;
125
261
126
- auto dir = Fortran::common::visit (
127
- common::visitors{
128
- [&](const parser::OpenMPBlockConstruct &c) {
129
- return std::get<parser::OmpBlockDirective>(
130
- std::get<parser::OmpBeginBlockDirective>(c.t ).t )
131
- .v ;
132
- },
133
- [&](const parser::OpenMPLoopConstruct &c) {
134
- return std::get<parser::OmpLoopDirective>(
135
- std::get<parser::OmpBeginLoopDirective>(c.t ).t )
136
- .v ;
137
- },
138
- [&](const auto &) {
139
- llvm_unreachable (" Unexpected OpenMP construct" );
140
- return llvm::omp::OMPD_unknown;
141
- },
142
- },
143
- eval.get <parser::OpenMPConstruct>().u );
144
-
262
+ llvm::omp::Directive dir =
263
+ extractOmpDirective (eval.get <parser::OpenMPConstruct>());
145
264
return llvm::omp::allTargetSet.test (dir) || !evalHasSiblings (eval);
146
265
}
147
266
@@ -1722,25 +1841,20 @@ genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
1722
1841
firOpBuilder.getModule ().getOperation ());
1723
1842
auto targetOp = loopNestOp->getParentOfType <mlir::omp::TargetOp>();
1724
1843
1725
- if (offloadMod && targetOp && !offloadMod.getIsTargetDevice ()) {
1726
- if (targetOp.isTargetSPMDLoop ()) {
1727
- // Lower loop bounds and step, and process collapsing again, putting
1728
- // lowered values outside of omp.target this time. This enables
1729
- // calculating and accessing the trip count in the host, which is needed
1730
- // when lowering to LLVM IR via the OMPIRBuilder.
1731
- HostClausesInsertionGuard guard (firOpBuilder);
1732
- mlir::omp::LoopRelatedOps loopRelatedOps;
1733
- llvm::SmallVector<const semantics::Symbol *> iv;
1734
- ClauseProcessor cp (converter, semaCtx, item->clauses );
1735
- cp.processCollapse (loc, eval, loopRelatedOps, iv);
1736
- targetOp.getTripCountMutable ().assign (
1737
- calculateTripCount (converter.getFirOpBuilder (), loc, loopRelatedOps));
1738
- } else if (targetOp.getTripCountMutable ().size ()) {
1739
- // The MLIR target operation was updated during PFT lowering,
1740
- // and it is no longer an SPMD kernel. Erase the trip count because
1741
- // as it is now invalid.
1742
- targetOp.getTripCountMutable ().erase (0 );
1743
- }
1844
+ if (offloadMod && !offloadMod.getIsTargetDevice () && isTargetSPMDLoop (eval)) {
1845
+ assert (targetOp && " must have omp.target parent" );
1846
+
1847
+ // Lower loop bounds and step, and process collapsing again, putting lowered
1848
+ // values outside of omp.target this time. This enables calculating and
1849
+ // accessing the trip count in the host, which is needed when lowering to
1850
+ // LLVM IR via the OMPIRBuilder.
1851
+ HostClausesInsertionGuard guard (firOpBuilder);
1852
+ mlir::omp::LoopRelatedOps loopRelatedOps;
1853
+ llvm::SmallVector<const semantics::Symbol *> iv;
1854
+ ClauseProcessor cp (converter, semaCtx, item->clauses );
1855
+ cp.processCollapse (loc, eval, loopRelatedOps, iv);
1856
+ targetOp.getTripCountMutable ().assign (
1857
+ calculateTripCount (converter.getFirOpBuilder (), loc, loopRelatedOps));
1744
1858
}
1745
1859
return loopNestOp;
1746
1860
}
0 commit comments