Skip to content

Commit abfb2ce

Browse files
[OpenACC][NFCI] Implement 'helpers' for all of the clauses I've used so far (llvm#137396)
As a follow up to 3c4dff3 I audited all uses of 'process clause and use additive methods', and added explicit functions to the construct to make it easier for the next project to attempt to use this mechanism (vs construct all operands/etc in advance, then add all at once). I've only done ones that I have attempted to use so far(as a catch-up, so no var-list clauses, and no constructs that can't be used without a var-list, and no loop, and no compound constructs). I intend to do those "as I go" with the lowering of each of those things instead. --------- Co-authored-by: Andy Kaylor <[email protected]>
1 parent 98b895d commit abfb2ce

File tree

3 files changed

+402
-170
lines changed

3 files changed

+402
-170
lines changed

clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp

Lines changed: 33 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,17 @@ class OpenACCClauseCIREmitter final
4646
// diagnostics are gone.
4747
SourceLocation dirLoc;
4848

49-
const OpenACCDeviceTypeClause *lastDeviceTypeClause = nullptr;
49+
llvm::SmallVector<mlir::acc::DeviceType> lastDeviceTypeValues;
50+
51+
void setLastDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
52+
lastDeviceTypeValues.clear();
53+
54+
llvm::for_each(clause.getArchitectures(),
55+
[this](const DeviceTypeArgument &arg) {
56+
lastDeviceTypeValues.push_back(
57+
decodeDeviceType(arg.getIdentifierInfo()));
58+
});
59+
}
5060

5161
void clauseNotImplemented(const OpenACCClause &c) {
5262
cgf.cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
@@ -95,114 +105,6 @@ class OpenACCClauseCIREmitter final
95105
.CaseLower("radeon", mlir::acc::DeviceType::Radeon);
96106
}
97107

98-
// Overload of this function that only returns the device-types list.
99-
mlir::ArrayAttr
100-
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes) {
101-
mlir::ValueRange argument;
102-
mlir::MutableOperandRange range{operation};
103-
104-
return handleDeviceTypeAffectedClause(existingDeviceTypes, argument, range);
105-
}
106-
// Overload of this function for when 'segments' aren't necessary.
107-
mlir::ArrayAttr
108-
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
109-
mlir::ValueRange argument,
110-
mlir::MutableOperandRange argCollection) {
111-
llvm::SmallVector<int32_t> segments;
112-
assert(argument.size() <= 1 &&
113-
"Overload only for cases where segments don't need to be added");
114-
return handleDeviceTypeAffectedClause(existingDeviceTypes, argument,
115-
argCollection, segments);
116-
}
117-
118-
// Handle a clause affected by the 'device_type' to the point that they need
119-
// to have attributes added in the correct/corresponding order, such as
120-
// 'num_workers' or 'vector_length' on a compute construct. The 'argument' is
121-
// a collection of operands that need to be appended to the `argCollection` as
122-
// we're adding a 'device_type' entry. If there is more than 0 elements in
123-
// the 'argument', the collection must be non-null, as it is needed to add to
124-
// it.
125-
// As some clauses, such as 'num_gangs' or 'wait' require a 'segments' list to
126-
// be maintained, this takes a list of segments that will be updated with the
127-
// proper counts as 'argument' elements are added.
128-
//
129-
// In MLIR, the 'operands' are stored as a large array, with a separate array
130-
// of 'segments' that show which 'operand' applies to which 'operand-kind'.
131-
// That is, a 'num_workers' operand-kind or 'num_vectors' operand-kind.
132-
//
133-
// So the operands array might have 4 elements, but the 'segments' array will
134-
// be something like:
135-
//
136-
// {0, 0, 0, 2, 0, 1, 1, 0, 0...}
137-
//
138-
// Where each position belongs to a specific 'operand-kind'. So that
139-
// specifies that whichever operand-kind corresponds with index '3' has 2
140-
// elements, and should take the 1st 2 operands off the list (since all
141-
// preceding values are 0). operand-kinds corresponding to 5 and 6 each have
142-
// 1 element.
143-
//
144-
// Fortunately, the `MutableOperandRange` append function actually takes care
145-
// of that for us at the 'top level'.
146-
//
147-
// However, in cases like `num_gangs' or 'wait', where each individual
148-
// 'element' might be itself array-like, there is a separate 'segments' array
149-
// for them. So in the case of:
150-
//
151-
// device_type(nvidia, radeon) num_gangs(1, 2, 3)
152-
//
153-
// We have to emit that as TWO arrays into the IR (where the device_type is an
154-
// attribute), so they look like:
155-
//
156-
// num_gangs({One : i32, Two : i32, Three : i32} [#acc.device_type<nvidia>],\
157-
// {One : i32, Two : i32, Three : i32} [#acc.device_type<radeon>])
158-
//
159-
// When stored in the 'operands' list, the top-level 'segment' for
160-
// 'num_gangs' just shows 6 elements. In order to get the array-like
161-
// apperance, the 'numGangsSegments' list is kept as well. In the above case,
162-
// we've inserted 6 operands, so the 'numGangsSegments' must contain 2
163-
// elements, 1 per array, and each will have a value of 3. The verifier will
164-
// ensure that the collections counts are correct.
165-
mlir::ArrayAttr
166-
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
167-
mlir::ValueRange argument,
168-
mlir::MutableOperandRange argCollection,
169-
llvm::SmallVector<int32_t> &segments) {
170-
llvm::SmallVector<mlir::Attribute> deviceTypes;
171-
172-
// Collect the 'existing' device-type attributes so we can re-create them
173-
// and insert them.
174-
if (existingDeviceTypes) {
175-
for (const mlir::Attribute &Attr : existingDeviceTypes)
176-
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
177-
builder.getContext(),
178-
cast<mlir::acc::DeviceTypeAttr>(Attr).getValue()));
179-
}
180-
181-
// Insert 1 version of the 'expr' to the NumWorkers list per-current
182-
// device type.
183-
if (lastDeviceTypeClause) {
184-
for (const DeviceTypeArgument &arch :
185-
lastDeviceTypeClause->getArchitectures()) {
186-
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
187-
builder.getContext(), decodeDeviceType(arch.getIdentifierInfo())));
188-
if (!argument.empty()) {
189-
argCollection.append(argument);
190-
segments.push_back(argument.size());
191-
}
192-
}
193-
} else {
194-
// Else, we just add a single for 'none'.
195-
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
196-
builder.getContext(), mlir::acc::DeviceType::None));
197-
if (!argument.empty()) {
198-
argCollection.append(argument);
199-
segments.push_back(argument.size());
200-
}
201-
}
202-
203-
return mlir::ArrayAttr::get(builder.getContext(), deviceTypes);
204-
}
205-
206108
public:
207109
OpenACCClauseCIREmitter(OpTy &operation, CIRGenFunction &cgf,
208110
CIRGenBuilderTy &builder,
@@ -236,7 +138,8 @@ class OpenACCClauseCIREmitter final
236138
}
237139

238140
void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
239-
lastDeviceTypeClause = &clause;
141+
setLastDeviceTypeClause(clause);
142+
240143
if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) {
241144
llvm::for_each(
242145
clause.getArchitectures(), [this](const DeviceTypeArgument &arg) {
@@ -253,8 +156,8 @@ class OpenACCClauseCIREmitter final
253156
} else if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp,
254157
DataOp>) {
255158
// Nothing to do here, these constructs don't have any IR for these, as
256-
// they just modify the other clauses IR. So setting of `lastDeviceType`
257-
// (done above) is all we need.
159+
// they just modify the other clauses IR. So setting of
160+
// `lastDeviceTypeValues` (done above) is all we need.
258161
} else {
259162
// TODO: When we've implemented this for everything, switch this to an
260163
// unreachable. update, data, loop, routine, combined constructs remain.
@@ -264,10 +167,9 @@ class OpenACCClauseCIREmitter final
264167

265168
void VisitNumWorkersClause(const OpenACCNumWorkersClause &clause) {
266169
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
267-
mlir::MutableOperandRange range = operation.getNumWorkersMutable();
268-
operation.setNumWorkersDeviceTypeAttr(handleDeviceTypeAffectedClause(
269-
operation.getNumWorkersDeviceTypeAttr(),
270-
createIntExpr(clause.getIntExpr()), range));
170+
operation.addNumWorkersOperand(builder.getContext(),
171+
createIntExpr(clause.getIntExpr()),
172+
lastDeviceTypeValues);
271173
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
272174
llvm_unreachable("num_workers not valid on serial");
273175
} else {
@@ -279,10 +181,9 @@ class OpenACCClauseCIREmitter final
279181

280182
void VisitVectorLengthClause(const OpenACCVectorLengthClause &clause) {
281183
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
282-
mlir::MutableOperandRange range = operation.getVectorLengthMutable();
283-
operation.setVectorLengthDeviceTypeAttr(handleDeviceTypeAffectedClause(
284-
operation.getVectorLengthDeviceTypeAttr(),
285-
createIntExpr(clause.getIntExpr()), range));
184+
operation.addVectorLengthOperand(builder.getContext(),
185+
createIntExpr(clause.getIntExpr()),
186+
lastDeviceTypeValues);
286187
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
287188
llvm_unreachable("vector_length not valid on serial");
288189
} else {
@@ -294,15 +195,12 @@ class OpenACCClauseCIREmitter final
294195

295196
void VisitAsyncClause(const OpenACCAsyncClause &clause) {
296197
if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) {
297-
if (!clause.hasIntExpr()) {
298-
operation.setAsyncOnlyAttr(
299-
handleDeviceTypeAffectedClause(operation.getAsyncOnlyAttr()));
300-
} else {
301-
mlir::MutableOperandRange range = operation.getAsyncOperandsMutable();
302-
operation.setAsyncOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause(
303-
operation.getAsyncOperandsDeviceTypeAttr(),
304-
createIntExpr(clause.getIntExpr()), range));
305-
}
198+
if (!clause.hasIntExpr())
199+
operation.addAsyncOnly(builder.getContext(), lastDeviceTypeValues);
200+
else
201+
operation.addAsyncOperand(builder.getContext(),
202+
createIntExpr(clause.getIntExpr()),
203+
lastDeviceTypeValues);
306204
} else if constexpr (isOneOfTypes<OpTy, WaitOp>) {
307205
// Wait doesn't have a device_type, so its handling here is slightly
308206
// different.
@@ -366,19 +264,11 @@ class OpenACCClauseCIREmitter final
366264
void VisitNumGangsClause(const OpenACCNumGangsClause &clause) {
367265
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
368266
llvm::SmallVector<mlir::Value> values;
369-
370267
for (const Expr *E : clause.getIntExprs())
371268
values.push_back(createIntExpr(E));
372269

373-
llvm::SmallVector<int32_t> segments;
374-
if (operation.getNumGangsSegments())
375-
llvm::copy(*operation.getNumGangsSegments(),
376-
std::back_inserter(segments));
377-
378-
mlir::MutableOperandRange range = operation.getNumGangsMutable();
379-
operation.setNumGangsDeviceTypeAttr(handleDeviceTypeAffectedClause(
380-
operation.getNumGangsDeviceTypeAttr(), values, range, segments));
381-
operation.setNumGangsSegments(llvm::ArrayRef<int32_t>{segments});
270+
operation.addNumGangsOperands(builder.getContext(), values,
271+
lastDeviceTypeValues);
382272
} else {
383273
// TODO: When we've implemented this for everything, switch this to an
384274
// unreachable. Combined constructs remain.
@@ -389,42 +279,15 @@ class OpenACCClauseCIREmitter final
389279
void VisitWaitClause(const OpenACCWaitClause &clause) {
390280
if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) {
391281
if (!clause.hasExprs()) {
392-
operation.setWaitOnlyAttr(
393-
handleDeviceTypeAffectedClause(operation.getWaitOnlyAttr()));
282+
operation.addWaitOnly(builder.getContext(), lastDeviceTypeValues);
394283
} else {
395284
llvm::SmallVector<mlir::Value> values;
396-
397285
if (clause.hasDevNumExpr())
398286
values.push_back(createIntExpr(clause.getDevNumExpr()));
399287
for (const Expr *E : clause.getQueueIdExprs())
400288
values.push_back(createIntExpr(E));
401-
402-
llvm::SmallVector<int32_t> segments;
403-
if (operation.getWaitOperandsSegments())
404-
llvm::copy(*operation.getWaitOperandsSegments(),
405-
std::back_inserter(segments));
406-
407-
unsigned beforeSegmentSize = segments.size();
408-
409-
mlir::MutableOperandRange range = operation.getWaitOperandsMutable();
410-
operation.setWaitOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause(
411-
operation.getWaitOperandsDeviceTypeAttr(), values, range,
412-
segments));
413-
operation.setWaitOperandsSegments(segments);
414-
415-
// In addition to having to set the 'segments', wait also has a list of
416-
// bool attributes whether it is annotated with 'devnum'. We can use
417-
// our knowledge of how much the 'segments' array grew to determine how
418-
// many we need to add.
419-
llvm::SmallVector<bool> hasDevNums;
420-
if (operation.getHasWaitDevnumAttr())
421-
for (mlir::Attribute A : operation.getHasWaitDevnumAttr())
422-
hasDevNums.push_back(cast<mlir::BoolAttr>(A).getValue());
423-
424-
hasDevNums.insert(hasDevNums.end(), segments.size() - beforeSegmentSize,
425-
clause.hasDevNumExpr());
426-
427-
operation.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasDevNums));
289+
operation.addWaitOperands(builder.getContext(), clause.hasDevNumExpr(),
290+
values, lastDeviceTypeValues);
428291
}
429292
} else {
430293
// TODO: When we've implemented this for everything, switch this to an
@@ -589,7 +452,7 @@ CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {
589452
if (s.hasDevNumExpr())
590453
waitOp.getWaitDevnumMutable().append(createIntExpr(s.getDevNumExpr()));
591454

592-
for (Expr *QueueExpr : s.getQueueIdExprs())
455+
for (Expr *QueueExpr : s.getQueueIdExprs())
593456
waitOp.getWaitOperandsMutable().append(createIntExpr(QueueExpr));
594457
}
595458

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,6 +1408,31 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
14081408
static mlir::acc::Construct getConstructId() {
14091409
return mlir::acc::Construct::acc_construct_parallel;
14101410
}
1411+
/// Add a value to 'num_workers' with the current list of device types.
1412+
void addNumWorkersOperand(MLIRContext *, mlir::Value,
1413+
llvm::ArrayRef<DeviceType>);
1414+
/// Add a value to 'vector_length' with the current list of device types.
1415+
void addVectorLengthOperand(MLIRContext *, mlir::Value,
1416+
llvm::ArrayRef<DeviceType>);
1417+
/// Add an entry to the 'async-only' attribute (clause spelled without
1418+
/// arguments)for each of the additional device types (or a none if it is
1419+
/// empty).
1420+
void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
1421+
/// Add a value to the 'async' with the current list of device types.
1422+
void addAsyncOperand(MLIRContext *, mlir::Value,
1423+
llvm::ArrayRef<DeviceType>);
1424+
/// Add an array-like entry to the 'num_gangs' with the current list of
1425+
/// device types.
1426+
void addNumGangsOperands(MLIRContext *, mlir::ValueRange,
1427+
llvm::ArrayRef<DeviceType>);
1428+
/// Add an entry to the 'wait-only' attribute (clause spelled without
1429+
/// arguments)for each of the additional device types (or a none if it is
1430+
/// empty).
1431+
void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
1432+
/// Add an array-like entry to the 'wait' with the current list of device
1433+
/// types.
1434+
void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
1435+
llvm::ArrayRef<DeviceType>);
14111436
}];
14121437

14131438
let assemblyFormat = [{
@@ -1535,6 +1560,21 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
15351560
static mlir::acc::Construct getConstructId() {
15361561
return mlir::acc::Construct::acc_construct_serial;
15371562
}
1563+
/// Add an entry to the 'async-only' attribute (clause spelled without
1564+
/// arguments) for each of the additional device types (or a none if it is
1565+
/// empty).
1566+
void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
1567+
/// Add a value to the 'async' with the current list of device types.
1568+
void addAsyncOperand(MLIRContext *, mlir::Value,
1569+
llvm::ArrayRef<DeviceType>);
1570+
/// Add an entry to the 'wait-only' attribute (clause spelled without
1571+
/// arguments) for each of the additional device types (or a none if it is
1572+
/// empty).
1573+
void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
1574+
/// Add an array-like entry to the 'wait' with the current list of device
1575+
/// types.
1576+
void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
1577+
llvm::ArrayRef<DeviceType>);
15381578
}];
15391579

15401580
let assemblyFormat = [{
@@ -1679,6 +1719,31 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
16791719
static mlir::acc::Construct getConstructId() {
16801720
return mlir::acc::Construct::acc_construct_kernels;
16811721
}
1722+
/// Add a value to 'num_workers' with the current list of device types.
1723+
void addNumWorkersOperand(MLIRContext *, mlir::Value,
1724+
llvm::ArrayRef<DeviceType>);
1725+
/// Add a value to 'vector_length' with the current list of device types.
1726+
void addVectorLengthOperand(MLIRContext *, mlir::Value,
1727+
llvm::ArrayRef<DeviceType>);
1728+
/// Add an entry to the 'async-only' attribute (clause spelled without
1729+
/// arguments) for each of the additional device types (or a none if it is
1730+
/// empty).
1731+
void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
1732+
/// Add a value to the 'async' with the current list of device types.
1733+
void addAsyncOperand(MLIRContext *, mlir::Value,
1734+
llvm::ArrayRef<DeviceType>);
1735+
/// Add an array-like entry to the 'num_gangs' with the current list of
1736+
/// device types.
1737+
void addNumGangsOperands(MLIRContext *, mlir::ValueRange,
1738+
llvm::ArrayRef<DeviceType>);
1739+
/// Add an entry to the 'wait-only' attribute (clause spelled without
1740+
/// arguments) for each of the additional device types (or a none if it is
1741+
/// empty).
1742+
void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
1743+
/// Add an array-like entry to the 'wait' with the current list of device
1744+
/// types.
1745+
void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
1746+
llvm::ArrayRef<DeviceType>);
16821747
}];
16831748

16841749
let assemblyFormat = [{
@@ -1785,6 +1850,21 @@ def OpenACC_DataOp : OpenACC_Op<"data",
17851850
/// Return the wait devnum value clause for the given device_type if
17861851
/// present.
17871852
mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
1853+
/// Add an entry to the 'async-only' attribute (clause spelled without
1854+
/// arguments) for each of the additional device types (or a none if it is
1855+
/// empty).
1856+
void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
1857+
/// Add a value to the 'async' with the current list of device types.
1858+
void addAsyncOperand(MLIRContext *, mlir::Value,
1859+
llvm::ArrayRef<DeviceType>);
1860+
/// Add an entry to the 'wait-only' attribute (clause spelled without
1861+
/// arguments) for each of the additional device types (or a none if it is
1862+
/// empty).
1863+
void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
1864+
/// Add an array-like entry to the 'wait' with the current list of device
1865+
/// types.
1866+
void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
1867+
llvm::ArrayRef<DeviceType>);
17881868
}];
17891869

17901870
let assemblyFormat = [{

0 commit comments

Comments
 (0)