@@ -63,9 +63,10 @@ static SmallVector<Value> createVariablesForResults(T op,
63
63
64
64
for (OpResult result : op.getResults ()) {
65
65
Type resultType = result.getType ();
66
+ Type varType = emitc::LValueType::get (resultType);
66
67
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get (context, " " );
67
68
emitc::VariableOp var =
68
- rewriter.create <emitc::VariableOp>(loc, resultType , noInit);
69
+ rewriter.create <emitc::VariableOp>(loc, varType , noInit);
69
70
resultVariables.push_back (var);
70
71
}
71
72
@@ -76,57 +77,98 @@ static SmallVector<Value> createVariablesForResults(T op,
76
77
// the current insertion point of given rewriter.
77
78
static void assignValues (ValueRange values, SmallVector<Value> &variables,
78
79
PatternRewriter &rewriter, Location loc) {
79
- for (auto [value, var] : llvm::zip (values, variables))
80
- rewriter.create <emitc::AssignOp>(loc, var, value);
80
+ for (auto [value, var] : llvm::zip (values, variables)) {
81
+ assert (isa<emitc::LValueType>(var.getType ()) &&
82
+ " expected var to be an lvalue type" );
83
+ assert (!isa<emitc::LValueType>(value.getType ()) &&
84
+ " expected value to not be an lvalue type" );
85
+ auto assign = rewriter.create <emitc::AssignOp>(loc, var, value);
86
+
87
+ // TODO: Make sure this is safe, as this moves operations with memory
88
+ // effects.
89
+ if (auto op = dyn_cast_if_present<emitc::LValueToRValueOp>(
90
+ value.getDefiningOp ())) {
91
+ rewriter.moveOpBefore (op, assign);
92
+ }
93
+ }
81
94
}
82
95
83
- static void lowerYield (SmallVector<Value> &resultVariables ,
84
- PatternRewriter &rewriter, scf::YieldOp yield) {
96
+ static void lowerYield (SmallVector<Value> &variables, PatternRewriter &rewriter ,
97
+ scf::YieldOp yield) {
85
98
Location loc = yield.getLoc ();
86
99
ValueRange operands = yield.getOperands ();
87
100
88
101
OpBuilder::InsertionGuard guard (rewriter);
89
102
rewriter.setInsertionPoint (yield);
90
103
91
- assignValues (operands, resultVariables , rewriter, loc);
104
+ assignValues (operands, variables , rewriter, loc);
92
105
93
106
rewriter.create <emitc::YieldOp>(loc);
94
107
rewriter.eraseOp (yield);
95
108
}
96
109
110
+ static void replaceUsers (PatternRewriter &rewriter,
111
+ SmallVector<Value> fromValues,
112
+ SmallVector<Value> toValues) {
113
+ OpBuilder::InsertionGuard guard (rewriter);
114
+ for (auto [from, to] : llvm::zip (fromValues, toValues)) {
115
+ assert (from.getType () == cast<emitc::LValueType>(to.getType ()).getValue () &&
116
+ " expected types to match" );
117
+
118
+ for (OpOperand &operand : llvm::make_early_inc_range (from.getUses ())) {
119
+ Operation *op = operand.getOwner ();
120
+ // Skip yield ops, as these get rewritten anyways.
121
+ if (isa<scf::YieldOp>(op)) {
122
+ continue ;
123
+ }
124
+ Location loc = op->getLoc ();
125
+
126
+ rewriter.setInsertionPoint (op);
127
+ Value rValue =
128
+ rewriter.create <emitc::LValueToRValueOp>(loc, from.getType (), to);
129
+ operand.set (rValue);
130
+ }
131
+ }
132
+ }
133
+
97
134
LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
98
135
PatternRewriter &rewriter) const {
99
136
Location loc = forOp.getLoc ();
100
137
101
- // Create an emitc::variable op for each result. These variables will be
102
- // assigned to by emitc::assign ops within the loop body.
103
- SmallVector<Value> resultVariables =
104
- createVariablesForResults (forOp, rewriter);
105
- SmallVector<Value> iterArgsVariables =
106
- createVariablesForResults (forOp, rewriter);
138
+ // Create an emitc::variable op for each result. These variables will be used
139
+ // for the results of the operations as well as the iter_args. They are
140
+ // assigned to by emitc::assign ops before the loop and at the end of the loop
141
+ // body.
142
+ SmallVector<Value> variables = createVariablesForResults (forOp, rewriter);
107
143
108
- assignValues (forOp.getInits (), iterArgsVariables, rewriter, loc);
144
+ // Assign initial values to the iter arg variables.
145
+ assignValues (forOp.getInits (), variables, rewriter, loc);
109
146
110
- emitc::ForOp loweredFor = rewriter.create <emitc::ForOp>(
111
- loc, forOp.getLowerBound (), forOp.getUpperBound (), forOp.getStep ());
147
+ // Replace users of the iter args with variables.
148
+ SmallVector<Value> iterArgs;
149
+ for (BlockArgument arg : forOp.getRegionIterArgs ()) {
150
+ iterArgs.push_back (arg);
151
+ }
112
152
113
- Block *loweredBody = loweredFor. getBody ( );
153
+ replaceUsers (rewriter, iterArgs, variables );
114
154
115
- // Erase the auto-generated terminator for the lowered for op.
116
- rewriter.eraseOp (loweredBody->getTerminator ());
155
+ emitc::ForOp loweredFor = rewriter.create <emitc::ForOp>(
156
+ loc, forOp.getLowerBound (), forOp.getUpperBound (), forOp.getStep ());
157
+ rewriter.eraseBlock (loweredFor.getBody ());
117
158
118
- SmallVector<Value> replacingValues;
119
- replacingValues.push_back (loweredFor.getInductionVar ());
120
- replacingValues.append (iterArgsVariables.begin (), iterArgsVariables.end ());
159
+ rewriter.inlineRegionBefore (forOp.getRegion (), loweredFor.getRegion (),
160
+ loweredFor.getRegion ().end ());
161
+ Operation *terminator = loweredFor.getRegion ().back ().getTerminator ();
162
+ lowerYield (variables, rewriter, cast<scf::YieldOp>(terminator));
121
163
122
- rewriter.mergeBlocks (forOp.getBody (), loweredBody, replacingValues);
123
- lowerYield (iterArgsVariables, rewriter,
124
- cast<scf::YieldOp>(loweredBody->getTerminator ()));
164
+ // Erase block arguments for iter_args.
165
+ loweredFor.getRegion ().back ().eraseArguments (1 , variables.size ());
125
166
126
- // Copy iterArgs into results after the for loop.
127
- assignValues (iterArgsVariables, resultVariables, rewriter, loc);
167
+ // Replace all users of the results with lazily created lvalue-to-rvalue
168
+ // ops.
169
+ replaceUsers (rewriter, forOp.getResults (), variables);
128
170
129
- rewriter.replaceOp (forOp, resultVariables );
171
+ rewriter.eraseOp (forOp);
130
172
return success ();
131
173
}
132
174
@@ -167,7 +209,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
167
209
168
210
bool hasElseBlock = !elseRegion.empty ();
169
211
170
- auto loweredIf =
212
+ emitc::IfOp loweredIf =
171
213
rewriter.create <emitc::IfOp>(loc, ifOp.getCondition (), false , false );
172
214
173
215
Region &loweredThenRegion = loweredIf.getThenRegion ();
@@ -178,7 +220,11 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
178
220
lowerRegion (elseRegion, loweredElseRegion);
179
221
}
180
222
181
- rewriter.replaceOp (ifOp, resultVariables);
223
+ // Replace all users of the results with lazily created lvalue-to-rvalue
224
+ // ops.
225
+ replaceUsers (rewriter, ifOp.getResults (), resultVariables);
226
+
227
+ rewriter.eraseOp (ifOp);
182
228
return success ();
183
229
}
184
230
0 commit comments