Skip to content

Commit 27158ed

Browse files
authored
[MLIR][SPIRV] Update cast from IntN to Bool (#113329)
This PR updates the cast to bool from IntN to treat any non-zero value as TRUE. This makes the cast more resilient to non-generic (i.e. "non 1") TRUE values. Signed-off-by: Dmitriy Smirnov <[email protected]>
1 parent 684c26c commit 27158ed

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
165165
if (srcInt.getType().isInteger(1))
166166
return srcInt;
167167

168-
auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
169-
return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
168+
auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
169+
return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
170170
}
171171

172172
//===----------------------------------------------------------------------===//

mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
7676
// CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
7777
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
7878
// CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] : i8
79-
// CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
80-
// CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
79+
// CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
80+
// CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
8181
%0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<StorageBuffer>>
8282
// CHECK: return %[[BOOL]]
8383
return %0: i1
@@ -234,8 +234,8 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i
234234
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
235235
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[IDX_CAST]]]
236236
// CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
237-
// CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
238-
// CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
237+
// CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
238+
// CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
239239
%0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>>
240240
// CHECK: return %[[BOOL]]
241241
return %0: i1

0 commit comments

Comments
 (0)