Skip to content

Commit 0451fe7

Browse files
author
Frank Laub
authored
Fix roundtrip of affine.if (#16)
1 parent d26bfee commit 0451fe7

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,8 @@ def AffineIfOp : Affine_Op<"if",
322322
}
323323
```
324324
}];
325-
let results = (outs Variadic<AnyMemRef>:$results);
326325
let arguments = (ins Variadic<AnyType>);
326+
let results = (outs Variadic<AnyType>:$results);
327327
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
328328

329329
let skipDefaultBuilders = 1;
@@ -755,8 +755,7 @@ def AffineStoreOp : AffineStoreOpBase<"store", []> {
755755
let hasFolder = 1;
756756
}
757757

758-
def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator]>,
759-
Arguments<(ins Variadic<AnyMemRef>:$values)> {
758+
def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator]> {
760759
let summary = "Yield values to parent operation";
761760
let description = [{
762761
"affine.yield" yields zero or more SSA values from an affine op region and
@@ -770,7 +769,7 @@ def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator]>,
770769
yielded.
771770
```
772771
}];
773-
let arguments = (ins Variadic<AnyMemRef>:$results);
772+
let arguments = (ins Variadic<AnyType>:$results);
774773
let builders = [
775774
OpBuilder<"OpBuilder &builder, OperationState &result",
776775
[{ /* nothing to do */ }]>

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,9 @@ static ParseResult parseAffineIfOp(OpAsmParser &parser,
17491749
parser.getNameLoc(),
17501750
"symbol operand count and integer set symbol count must match");
17511751

1752+
if (parser.parseOptionalArrowTypeList(result.types))
1753+
return failure();
1754+
17521755
// Create the regions for 'then' and 'else'. The latter must be created even
17531756
// if it remains empty for the validity of the operation.
17541757
result.regions.reserve(2);
@@ -1782,6 +1785,7 @@ static void print(OpAsmPrinter &p, AffineIfOp op) {
17821785
p << "affine.if " << conditionAttr;
17831786
printDimAndSymbolList(op.operand_begin(), op.operand_end(),
17841787
conditionAttr.getValue().getNumDims(), p);
1788+
p.printOptionalArrowTypeList(op.getResultTypes());
17851789
p.printRegion(op.thenRegion(),
17861790
/*printEntryBlockArgs=*/false,
17871791
/*printBlockTerminators=*/op.getNumResults());

mlir/test/Dialect/Affine/ops.mlir

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func @valid_symbol_affine_scope(%n : index, %A : memref<?xf32>) {
153153

154154
// -----
155155

156-
// CHECK-LABEL: @parallel
156+
// CHECK-LABEL: func @parallel
157157
// CHECK-SAME: (%[[N:.*]]: index)
158158
func @parallel(%N : index) {
159159
// CHECK: affine.parallel (%[[I0:.*]], %[[J0:.*]]) = (0, 0) to (symbol(%[[N]]), 100) step (10, 10)
@@ -166,3 +166,21 @@ func @parallel(%N : index) {
166166
}
167167
return
168168
}
169+
170+
// -----
171+
172+
// CHECK-LABEL: func @affine_if
173+
func @affine_if() -> f32 {
174+
// CHECK: %[[ZERO:.*]] = constant {{.*}} : f32
175+
%zero = constant 0.0 : f32
176+
// CHECK: %[[OUT:.*]] = affine.if {{.*}}() -> f32 {
177+
%0 = affine.if affine_set<() : ()> () -> f32 {
178+
// CHECK: affine.yield %[[ZERO]] : f32
179+
affine.yield %zero : f32
180+
} else {
181+
// CHECK: affine.yield %[[ZERO]] : f32
182+
affine.yield %zero : f32
183+
}
184+
// CHECK: return %[[OUT]] : f32
185+
return %0 : f32
186+
}

0 commit comments

Comments
 (0)