@@ -46,7 +46,17 @@ class OpenACCClauseCIREmitter final
46
46
// diagnostics are gone.
47
47
SourceLocation dirLoc;
48
48
49
- const OpenACCDeviceTypeClause *lastDeviceTypeClause = nullptr ;
49
+ llvm::SmallVector<mlir::acc::DeviceType> lastDeviceTypeValues;
50
+
51
+ void setLastDeviceTypeClause (const OpenACCDeviceTypeClause &clause) {
52
+ lastDeviceTypeValues.clear ();
53
+
54
+ llvm::for_each (clause.getArchitectures (),
55
+ [this ](const DeviceTypeArgument &arg) {
56
+ lastDeviceTypeValues.push_back (
57
+ decodeDeviceType (arg.getIdentifierInfo ()));
58
+ });
59
+ }
50
60
51
61
void clauseNotImplemented (const OpenACCClause &c) {
52
62
cgf.cgm .errorNYI (c.getSourceRange (), " OpenACC Clause" , c.getClauseKind ());
@@ -95,114 +105,6 @@ class OpenACCClauseCIREmitter final
95
105
.CaseLower (" radeon" , mlir::acc::DeviceType::Radeon);
96
106
}
97
107
98
- // Overload of this function that only returns the device-types list.
99
- mlir::ArrayAttr
100
- handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes) {
101
- mlir::ValueRange argument;
102
- mlir::MutableOperandRange range{operation};
103
-
104
- return handleDeviceTypeAffectedClause (existingDeviceTypes, argument, range);
105
- }
106
- // Overload of this function for when 'segments' aren't necessary.
107
- mlir::ArrayAttr
108
- handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes,
109
- mlir::ValueRange argument,
110
- mlir::MutableOperandRange argCollection) {
111
- llvm::SmallVector<int32_t > segments;
112
- assert (argument.size () <= 1 &&
113
- " Overload only for cases where segments don't need to be added" );
114
- return handleDeviceTypeAffectedClause (existingDeviceTypes, argument,
115
- argCollection, segments);
116
- }
117
-
118
- // Handle a clause affected by the 'device_type' to the point that they need
119
- // to have attributes added in the correct/corresponding order, such as
120
- // 'num_workers' or 'vector_length' on a compute construct. The 'argument' is
121
- // a collection of operands that need to be appended to the `argCollection` as
122
- // we're adding a 'device_type' entry. If there is more than 0 elements in
123
- // the 'argument', the collection must be non-null, as it is needed to add to
124
- // it.
125
- // As some clauses, such as 'num_gangs' or 'wait' require a 'segments' list to
126
- // be maintained, this takes a list of segments that will be updated with the
127
- // proper counts as 'argument' elements are added.
128
- //
129
- // In MLIR, the 'operands' are stored as a large array, with a separate array
130
- // of 'segments' that show which 'operand' applies to which 'operand-kind'.
131
- // That is, a 'num_workers' operand-kind or 'num_vectors' operand-kind.
132
- //
133
- // So the operands array might have 4 elements, but the 'segments' array will
134
- // be something like:
135
- //
136
- // {0, 0, 0, 2, 0, 1, 1, 0, 0...}
137
- //
138
- // Where each position belongs to a specific 'operand-kind'. So that
139
- // specifies that whichever operand-kind corresponds with index '3' has 2
140
- // elements, and should take the 1st 2 operands off the list (since all
141
- // preceding values are 0). operand-kinds corresponding to 5 and 6 each have
142
- // 1 element.
143
- //
144
- // Fortunately, the `MutableOperandRange` append function actually takes care
145
- // of that for us at the 'top level'.
146
- //
147
- // However, in cases like `num_gangs' or 'wait', where each individual
148
- // 'element' might be itself array-like, there is a separate 'segments' array
149
- // for them. So in the case of:
150
- //
151
- // device_type(nvidia, radeon) num_gangs(1, 2, 3)
152
- //
153
- // We have to emit that as TWO arrays into the IR (where the device_type is an
154
- // attribute), so they look like:
155
- //
156
- // num_gangs({One : i32, Two : i32, Three : i32} [#acc.device_type<nvidia>],\
157
- // {One : i32, Two : i32, Three : i32} [#acc.device_type<radeon>])
158
- //
159
- // When stored in the 'operands' list, the top-level 'segment' for
160
- // 'num_gangs' just shows 6 elements. In order to get the array-like
161
- // apperance, the 'numGangsSegments' list is kept as well. In the above case,
162
- // we've inserted 6 operands, so the 'numGangsSegments' must contain 2
163
- // elements, 1 per array, and each will have a value of 3. The verifier will
164
- // ensure that the collections counts are correct.
165
- mlir::ArrayAttr
166
- handleDeviceTypeAffectedClause (mlir::ArrayAttr existingDeviceTypes,
167
- mlir::ValueRange argument,
168
- mlir::MutableOperandRange argCollection,
169
- llvm::SmallVector<int32_t > &segments) {
170
- llvm::SmallVector<mlir::Attribute> deviceTypes;
171
-
172
- // Collect the 'existing' device-type attributes so we can re-create them
173
- // and insert them.
174
- if (existingDeviceTypes) {
175
- for (const mlir::Attribute &Attr : existingDeviceTypes)
176
- deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
177
- builder.getContext (),
178
- cast<mlir::acc::DeviceTypeAttr>(Attr).getValue ()));
179
- }
180
-
181
- // Insert 1 version of the 'expr' to the NumWorkers list per-current
182
- // device type.
183
- if (lastDeviceTypeClause) {
184
- for (const DeviceTypeArgument &arch :
185
- lastDeviceTypeClause->getArchitectures ()) {
186
- deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
187
- builder.getContext (), decodeDeviceType (arch.getIdentifierInfo ())));
188
- if (!argument.empty ()) {
189
- argCollection.append (argument);
190
- segments.push_back (argument.size ());
191
- }
192
- }
193
- } else {
194
- // Else, we just add a single for 'none'.
195
- deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
196
- builder.getContext (), mlir::acc::DeviceType::None));
197
- if (!argument.empty ()) {
198
- argCollection.append (argument);
199
- segments.push_back (argument.size ());
200
- }
201
- }
202
-
203
- return mlir::ArrayAttr::get (builder.getContext (), deviceTypes);
204
- }
205
-
206
108
public:
207
109
OpenACCClauseCIREmitter (OpTy &operation, CIRGenFunction &cgf,
208
110
CIRGenBuilderTy &builder,
@@ -236,7 +138,8 @@ class OpenACCClauseCIREmitter final
236
138
}
237
139
238
140
void VisitDeviceTypeClause (const OpenACCDeviceTypeClause &clause) {
239
- lastDeviceTypeClause = &clause;
141
+ setLastDeviceTypeClause (clause);
142
+
240
143
if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) {
241
144
llvm::for_each (
242
145
clause.getArchitectures (), [this ](const DeviceTypeArgument &arg) {
@@ -253,8 +156,8 @@ class OpenACCClauseCIREmitter final
253
156
} else if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp,
254
157
DataOp>) {
255
158
// Nothing to do here, these constructs don't have any IR for these, as
256
- // they just modify the other clauses IR. So setting of `lastDeviceType`
257
- // (done above) is all we need.
159
+ // they just modify the other clauses IR. So setting of
160
+ // `lastDeviceTypeValues` (done above) is all we need.
258
161
} else {
259
162
// TODO: When we've implemented this for everything, switch this to an
260
163
// unreachable. update, data, loop, routine, combined constructs remain.
@@ -264,10 +167,9 @@ class OpenACCClauseCIREmitter final
264
167
265
168
void VisitNumWorkersClause (const OpenACCNumWorkersClause &clause) {
266
169
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
267
- mlir::MutableOperandRange range = operation.getNumWorkersMutable ();
268
- operation.setNumWorkersDeviceTypeAttr (handleDeviceTypeAffectedClause (
269
- operation.getNumWorkersDeviceTypeAttr (),
270
- createIntExpr (clause.getIntExpr ()), range));
170
+ operation.addNumWorkersOperand (builder.getContext (),
171
+ createIntExpr (clause.getIntExpr ()),
172
+ lastDeviceTypeValues);
271
173
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
272
174
llvm_unreachable (" num_workers not valid on serial" );
273
175
} else {
@@ -279,10 +181,9 @@ class OpenACCClauseCIREmitter final
279
181
280
182
void VisitVectorLengthClause (const OpenACCVectorLengthClause &clause) {
281
183
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
282
- mlir::MutableOperandRange range = operation.getVectorLengthMutable ();
283
- operation.setVectorLengthDeviceTypeAttr (handleDeviceTypeAffectedClause (
284
- operation.getVectorLengthDeviceTypeAttr (),
285
- createIntExpr (clause.getIntExpr ()), range));
184
+ operation.addVectorLengthOperand (builder.getContext (),
185
+ createIntExpr (clause.getIntExpr ()),
186
+ lastDeviceTypeValues);
286
187
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
287
188
llvm_unreachable (" vector_length not valid on serial" );
288
189
} else {
@@ -294,15 +195,12 @@ class OpenACCClauseCIREmitter final
294
195
295
196
void VisitAsyncClause (const OpenACCAsyncClause &clause) {
296
197
if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) {
297
- if (!clause.hasIntExpr ()) {
298
- operation.setAsyncOnlyAttr (
299
- handleDeviceTypeAffectedClause (operation.getAsyncOnlyAttr ()));
300
- } else {
301
- mlir::MutableOperandRange range = operation.getAsyncOperandsMutable ();
302
- operation.setAsyncOperandsDeviceTypeAttr (handleDeviceTypeAffectedClause (
303
- operation.getAsyncOperandsDeviceTypeAttr (),
304
- createIntExpr (clause.getIntExpr ()), range));
305
- }
198
+ if (!clause.hasIntExpr ())
199
+ operation.addAsyncOnly (builder.getContext (), lastDeviceTypeValues);
200
+ else
201
+ operation.addAsyncOperand (builder.getContext (),
202
+ createIntExpr (clause.getIntExpr ()),
203
+ lastDeviceTypeValues);
306
204
} else if constexpr (isOneOfTypes<OpTy, WaitOp>) {
307
205
// Wait doesn't have a device_type, so its handling here is slightly
308
206
// different.
@@ -366,19 +264,11 @@ class OpenACCClauseCIREmitter final
366
264
void VisitNumGangsClause (const OpenACCNumGangsClause &clause) {
367
265
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
368
266
llvm::SmallVector<mlir::Value> values;
369
-
370
267
for (const Expr *E : clause.getIntExprs ())
371
268
values.push_back (createIntExpr (E));
372
269
373
- llvm::SmallVector<int32_t > segments;
374
- if (operation.getNumGangsSegments ())
375
- llvm::copy (*operation.getNumGangsSegments (),
376
- std::back_inserter (segments));
377
-
378
- mlir::MutableOperandRange range = operation.getNumGangsMutable ();
379
- operation.setNumGangsDeviceTypeAttr (handleDeviceTypeAffectedClause (
380
- operation.getNumGangsDeviceTypeAttr (), values, range, segments));
381
- operation.setNumGangsSegments (llvm::ArrayRef<int32_t >{segments});
270
+ operation.addNumGangsOperands (builder.getContext (), values,
271
+ lastDeviceTypeValues);
382
272
} else {
383
273
// TODO: When we've implemented this for everything, switch this to an
384
274
// unreachable. Combined constructs remain.
@@ -389,42 +279,15 @@ class OpenACCClauseCIREmitter final
389
279
void VisitWaitClause (const OpenACCWaitClause &clause) {
390
280
if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) {
391
281
if (!clause.hasExprs ()) {
392
- operation.setWaitOnlyAttr (
393
- handleDeviceTypeAffectedClause (operation.getWaitOnlyAttr ()));
282
+ operation.addWaitOnly (builder.getContext (), lastDeviceTypeValues);
394
283
} else {
395
284
llvm::SmallVector<mlir::Value> values;
396
-
397
285
if (clause.hasDevNumExpr ())
398
286
values.push_back (createIntExpr (clause.getDevNumExpr ()));
399
287
for (const Expr *E : clause.getQueueIdExprs ())
400
288
values.push_back (createIntExpr (E));
401
-
402
- llvm::SmallVector<int32_t > segments;
403
- if (operation.getWaitOperandsSegments ())
404
- llvm::copy (*operation.getWaitOperandsSegments (),
405
- std::back_inserter (segments));
406
-
407
- unsigned beforeSegmentSize = segments.size ();
408
-
409
- mlir::MutableOperandRange range = operation.getWaitOperandsMutable ();
410
- operation.setWaitOperandsDeviceTypeAttr (handleDeviceTypeAffectedClause (
411
- operation.getWaitOperandsDeviceTypeAttr (), values, range,
412
- segments));
413
- operation.setWaitOperandsSegments (segments);
414
-
415
- // In addition to having to set the 'segments', wait also has a list of
416
- // bool attributes whether it is annotated with 'devnum'. We can use
417
- // our knowledge of how much the 'segments' array grew to determine how
418
- // many we need to add.
419
- llvm::SmallVector<bool > hasDevNums;
420
- if (operation.getHasWaitDevnumAttr ())
421
- for (mlir::Attribute A : operation.getHasWaitDevnumAttr ())
422
- hasDevNums.push_back (cast<mlir::BoolAttr>(A).getValue ());
423
-
424
- hasDevNums.insert (hasDevNums.end (), segments.size () - beforeSegmentSize,
425
- clause.hasDevNumExpr ());
426
-
427
- operation.setHasWaitDevnumAttr (builder.getBoolArrayAttr (hasDevNums));
289
+ operation.addWaitOperands (builder.getContext (), clause.hasDevNumExpr (),
290
+ values, lastDeviceTypeValues);
428
291
}
429
292
} else {
430
293
// TODO: When we've implemented this for everything, switch this to an
@@ -589,7 +452,7 @@ CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {
589
452
if (s.hasDevNumExpr ())
590
453
waitOp.getWaitDevnumMutable ().append (createIntExpr (s.getDevNumExpr ()));
591
454
592
- for (Expr *QueueExpr : s.getQueueIdExprs ())
455
+ for (Expr *QueueExpr : s.getQueueIdExprs ())
593
456
waitOp.getWaitOperandsMutable ().append (createIntExpr (QueueExpr));
594
457
}
595
458
0 commit comments