Skip to content

Commit aecefec

Browse files
author
Simon Camphausen
committed
[WIP][mlir][EmitC] Model lvalues as a type in EmitC
1 parent 2c703ed commit aecefec

File tree

17 files changed

+391
-162
lines changed

17 files changed

+391
-162
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,22 @@ def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", [CExpression]> {
836836
let assemblyFormat = "operands attr-dict `:` type(operands)";
837837
}
838838

839+
def EmitC_LValueToRValueOp : EmitC_Op<"lvalue_to_rvalue", [
840+
TypesMatchWith<"result type matches value type of 'operand'",
841+
"operand", "result",
842+
"::llvm::cast<LValueType>($_self).getValue()">
843+
]> {
844+
let summary = "lvalue to rvalue conversion operation";
845+
let description = [{}];
846+
847+
let arguments = (ins EmitC_LValueType:$operand);
848+
let results = (outs AnyType:$result);
849+
850+
let assemblyFormat = "$operand attr-dict `:` type($operand)";
851+
852+
let hasVerifier = 1;
853+
}
854+
839855
def EmitC_MulOp : EmitC_BinaryOp<"mul", [CExpression]> {
840856
let summary = "Multiplication operation";
841857
let description = [{
@@ -1011,7 +1027,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
10111027
}];
10121028

10131029
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
1014-
let results = (outs AnyType);
1030+
let results = (outs EmitC_LValueType);
10151031

10161032
let hasVerifier = 1;
10171033
}
@@ -1070,7 +1086,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
10701086
```
10711087
}];
10721088

1073-
let arguments = (ins AnyType:$var, AnyType:$value);
1089+
let arguments = (ins EmitC_LValueType:$var, AnyType:$value);
10741090
let results = (outs);
10751091

10761092
let hasVerifier = 1;
@@ -1158,7 +1174,7 @@ def EmitC_IfOp : EmitC_Op<"if",
11581174
def EmitC_SubscriptOp : EmitC_Op<"subscript",
11591175
[TypesMatchWith<"result type matches element type of 'array'",
11601176
"array", "result",
1161-
"::llvm::cast<ArrayType>($_self).getElementType()">]> {
1177+
"LValueType::get(::llvm::cast<ArrayType>($_self).getElementType())">]> {
11621178
let summary = "Array subscript operation";
11631179
let description = [{
11641180
With the `subscript` operation the subscript operator `[]` can be applied
@@ -1174,7 +1190,7 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript",
11741190
}];
11751191
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
11761192
Variadic<IntegerIndexOrOpaqueType>:$indices);
1177-
let results = (outs AnyType:$result);
1193+
let results = (outs EmitC_LValueType:$result);
11781194

11791195
let builders = [
11801196
OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,23 @@ def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
7474
let hasCustomAssemblyFormat = 1;
7575
}
7676

77+
def EmitC_LValueType : EmitC_Type<"LValue", "lvalue"> {
78+
let summary = "EmitC lvalue type";
79+
80+
let description = [{
81+
Values of this type can be assigned to and their address can be taken.
82+
}];
83+
84+
let parameters = (ins "Type":$value);
85+
let builders = [
86+
TypeBuilderWithInferredContext<(ins "Type":$value), [{
87+
return $_get(value.getContext(), value);
88+
}]>
89+
];
90+
let assemblyFormat = "`<` qualified($value) `>`";
91+
let genVerifyDecl = 1;
92+
}
93+
7794
def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
7895
let summary = "EmitC opaque type";
7996

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 75 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ static SmallVector<Value> createVariablesForResults(T op,
6363

6464
for (OpResult result : op.getResults()) {
6565
Type resultType = result.getType();
66+
Type varType = emitc::LValueType::get(resultType);
6667
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
6768
emitc::VariableOp var =
68-
rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
69+
rewriter.create<emitc::VariableOp>(loc, varType, noInit);
6970
resultVariables.push_back(var);
7071
}
7172

@@ -76,57 +77,98 @@ static SmallVector<Value> createVariablesForResults(T op,
7677
// the current insertion point of given rewriter.
7778
static void assignValues(ValueRange values, SmallVector<Value> &variables,
7879
PatternRewriter &rewriter, Location loc) {
79-
for (auto [value, var] : llvm::zip(values, variables))
80-
rewriter.create<emitc::AssignOp>(loc, var, value);
80+
for (auto [value, var] : llvm::zip(values, variables)) {
81+
assert(isa<emitc::LValueType>(var.getType()) &&
82+
"expected var to be an lvalue type");
83+
assert(!isa<emitc::LValueType>(value.getType()) &&
84+
"expected value to not be an lvalue type");
85+
auto assign = rewriter.create<emitc::AssignOp>(loc, var, value);
86+
87+
// TODO: Make sure this is safe, as this moves operations with memory
88+
// effects.
89+
if (auto op = dyn_cast_if_present<emitc::LValueToRValueOp>(
90+
value.getDefiningOp())) {
91+
rewriter.moveOpBefore(op, assign);
92+
}
93+
}
8194
}
8295

83-
static void lowerYield(SmallVector<Value> &resultVariables,
84-
PatternRewriter &rewriter, scf::YieldOp yield) {
96+
static void lowerYield(SmallVector<Value> &variables, PatternRewriter &rewriter,
97+
scf::YieldOp yield) {
8598
Location loc = yield.getLoc();
8699
ValueRange operands = yield.getOperands();
87100

88101
OpBuilder::InsertionGuard guard(rewriter);
89102
rewriter.setInsertionPoint(yield);
90103

91-
assignValues(operands, resultVariables, rewriter, loc);
104+
assignValues(operands, variables, rewriter, loc);
92105

93106
rewriter.create<emitc::YieldOp>(loc);
94107
rewriter.eraseOp(yield);
95108
}
96109

110+
static void replaceUsers(PatternRewriter &rewriter,
111+
SmallVector<Value> fromValues,
112+
SmallVector<Value> toValues) {
113+
OpBuilder::InsertionGuard guard(rewriter);
114+
for (auto [from, to] : llvm::zip(fromValues, toValues)) {
115+
assert(from.getType() == cast<emitc::LValueType>(to.getType()).getValue() &&
116+
"expected types to match");
117+
118+
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
119+
Operation *op = operand.getOwner();
120+
// Skip yield ops, as these get rewritten anyways.
121+
if (isa<scf::YieldOp>(op)) {
122+
continue;
123+
}
124+
Location loc = op->getLoc();
125+
126+
rewriter.setInsertionPoint(op);
127+
Value rValue =
128+
rewriter.create<emitc::LValueToRValueOp>(loc, from.getType(), to);
129+
operand.set(rValue);
130+
}
131+
}
132+
}
133+
97134
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
98135
PatternRewriter &rewriter) const {
99136
Location loc = forOp.getLoc();
100137

101-
// Create an emitc::variable op for each result. These variables will be
102-
// assigned to by emitc::assign ops within the loop body.
103-
SmallVector<Value> resultVariables =
104-
createVariablesForResults(forOp, rewriter);
105-
SmallVector<Value> iterArgsVariables =
106-
createVariablesForResults(forOp, rewriter);
138+
// Create an emitc::variable op for each result. These variables will be used
139+
// for the results of the operations as well as the iter_args. They are
140+
// assigned to by emitc::assign ops before the loop and at the end of the loop
141+
// body.
142+
SmallVector<Value> variables = createVariablesForResults(forOp, rewriter);
107143

108-
assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc);
144+
// Assign initial values to the iter arg variables.
145+
assignValues(forOp.getInits(), variables, rewriter, loc);
109146

110-
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
111-
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
147+
// Replace users of the iter args with variables.
148+
SmallVector<Value> iterArgs;
149+
for (BlockArgument arg : forOp.getRegionIterArgs()) {
150+
iterArgs.push_back(arg);
151+
}
112152

113-
Block *loweredBody = loweredFor.getBody();
153+
replaceUsers(rewriter, iterArgs, variables);
114154

115-
// Erase the auto-generated terminator for the lowered for op.
116-
rewriter.eraseOp(loweredBody->getTerminator());
155+
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
156+
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
157+
rewriter.eraseBlock(loweredFor.getBody());
117158

118-
SmallVector<Value> replacingValues;
119-
replacingValues.push_back(loweredFor.getInductionVar());
120-
replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());
159+
rewriter.inlineRegionBefore(forOp.getRegion(), loweredFor.getRegion(),
160+
loweredFor.getRegion().end());
161+
Operation *terminator = loweredFor.getRegion().back().getTerminator();
162+
lowerYield(variables, rewriter, cast<scf::YieldOp>(terminator));
121163

122-
rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
123-
lowerYield(iterArgsVariables, rewriter,
124-
cast<scf::YieldOp>(loweredBody->getTerminator()));
164+
// Erase block arguments for iter_args.
165+
loweredFor.getRegion().back().eraseArguments(1, variables.size());
125166

126-
// Copy iterArgs into results after the for loop.
127-
assignValues(iterArgsVariables, resultVariables, rewriter, loc);
167+
// Replace all users of the results with lazily created lvalue-to-rvalue
168+
// ops.
169+
replaceUsers(rewriter, forOp.getResults(), variables);
128170

129-
rewriter.replaceOp(forOp, resultVariables);
171+
rewriter.eraseOp(forOp);
130172
return success();
131173
}
132174

@@ -167,7 +209,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
167209

168210
bool hasElseBlock = !elseRegion.empty();
169211

170-
auto loweredIf =
212+
emitc::IfOp loweredIf =
171213
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
172214

173215
Region &loweredThenRegion = loweredIf.getThenRegion();
@@ -178,7 +220,11 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
178220
lowerRegion(elseRegion, loweredElseRegion);
179221
}
180222

181-
rewriter.replaceOp(ifOp, resultVariables);
223+
// Replace all users of the results with lazily created lvalue-to-rvalue
224+
// ops.
225+
replaceUsers(rewriter, ifOp.getResults(), resultVariables);
226+
227+
rewriter.eraseOp(ifOp);
182228
return success();
183229
}
184230

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ static LogicalResult verifyInitializationAttribute(Operation *op,
6868
<< "string attributes are not supported, use #emitc.opaque instead";
6969

7070
Type resultType = op->getResult(0).getType();
71+
if (auto lType = dyn_cast<LValueType>(resultType))
72+
resultType = lType.getValue();
7173
Type attrType = cast<TypedAttr>(value).getType();
7274

7375
if (resultType != attrType)
@@ -131,18 +133,21 @@ LogicalResult ApplyOp::verify() {
131133
/// assigned-to variable type.
132134
LogicalResult emitc::AssignOp::verify() {
133135
Value variable = getVar();
134-
Operation *variableDef = variable.getDefiningOp();
135-
if (!variableDef ||
136-
!llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
136+
137+
if (!variable.getDefiningOp())
138+
return emitOpError() << "cannot assign to block argument";
139+
if (!llvm::isa<emitc::LValueType>(variable.getType()))
137140
return emitOpError() << "requires first operand (" << variable
138-
<< ") to be a Variable or subscript";
139-
140-
Value value = getValue();
141-
if (variable.getType() != value.getType())
142-
return emitOpError() << "requires value's type (" << value.getType()
143-
<< ") to match variable's type (" << variable.getType()
144-
<< ")";
145-
if (isa<ArrayType>(variable.getType()))
141+
<< ") to be an lvalue";
142+
143+
Type valueType = getValue().getType();
144+
Type variableType = variable.getType().cast<emitc::LValueType>().getValue();
145+
if (variableType != valueType)
146+
return emitOpError() << "requires value's type (" << valueType
147+
<< ") to match variable's type (" << variableType
148+
<< ")\n variable: " << variable
149+
<< "\n value: " << getValue() << "\n";
150+
if (isa<ArrayType>(variableType))
146151
return emitOpError() << "cannot assign to array type";
147152
return success();
148153
}
@@ -698,6 +703,47 @@ LogicalResult emitc::LiteralOp::verify() {
698703
return emitOpError() << "value must not be empty";
699704
return success();
700705
}
706+
707+
//===----------------------------------------------------------------------===//
708+
// LValueToRValueOp
709+
//===----------------------------------------------------------------------===//
710+
711+
LogicalResult emitc::LValueToRValueOp::verify() {
712+
Type operandType = getOperand().getType();
713+
Type resultType = getResult().getType();
714+
if (!llvm::isa<emitc::LValueType>(operandType))
715+
return emitOpError("operand must be a lvalue");
716+
if (llvm::cast<emitc::LValueType>(operandType).getValue() != resultType)
717+
return emitOpError("types must match");
718+
719+
Value result = getResult();
720+
if (!result.hasOneUse()) {
721+
int numUses = std::distance(result.use_begin(), result.use_end());
722+
return emitOpError("must have exactly one use, but got ") << numUses;
723+
}
724+
725+
Block *block = result.getParentBlock();
726+
727+
Operation *user = *result.getUsers().begin();
728+
Block *userBlock = user->getBlock();
729+
730+
if (block != userBlock) {
731+
return emitOpError("user must be in the same block");
732+
}
733+
734+
// for (auto it = block.begin(), e = std::prev(block.end()); it != e; it++) {
735+
// if (*it == this)
736+
// }
737+
738+
// TODO: To model this op correctly as a memory read of the lvalue, we
739+
// should additionally ensure that the single use of the op follows immediatly
740+
// on this definition. Alternativly we could alter emitc ops to implicitly
741+
// support lvalues. This would make it harder to do partial conversions and
742+
// mix dialects though.
743+
744+
return success();
745+
}
746+
701747
//===----------------------------------------------------------------------===//
702748
// SubOp
703749
//===----------------------------------------------------------------------===//
@@ -851,6 +897,20 @@ emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
851897
return emitc::ArrayType::get(*shape, elementType);
852898
}
853899

900+
//===----------------------------------------------------------------------===//
901+
// LValueType
902+
//===----------------------------------------------------------------------===//
903+
904+
LogicalResult mlir::emitc::LValueType::verify(
905+
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
906+
mlir::Type value) {
907+
if (llvm::isa<emitc::LValueType>(value)) {
908+
return emitError()
909+
<< "!emitc.lvalue type cannot be nested inside another type";
910+
}
911+
return success();
912+
}
913+
854914
//===----------------------------------------------------------------------===//
855915
// OpaqueType
856916
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)