Skip to content

Commit 39ed304

Browse files
committed
Allow fixed vector operand for LLVM_AtomicRMWOp
This PR fixes `LLVM_AtomicRMWOp` allowed semantics and verifier logic to enable building of `LLVM_AtomicRMWOp` with fixed vectors of compatible fp values as operands for fp rmw operation. See also: https://llvm.org/docs/LangRef.html#id231 Signed-off-by: Ilya Veselov <[email protected]>
1 parent 19c6958 commit 39ed304

File tree

5 files changed

+35
-7
lines changed

5 files changed

+35
-7
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1737,7 +1737,8 @@ def LLVM_ConstantOp
17371737
// Atomic operations.
17381738
//
17391739

1740-
def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger]>;
1740+
def LLVM_AtomicRMWType
1741+
: AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger, LLVM_AnyFixedVector]>;
17411742

17421743
def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [
17431744
TypesMatchWith<"result #0 and operand #1 have the same type",

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3010,8 +3010,16 @@ LogicalResult AtomicRMWOp::verify() {
30103010
auto valType = getVal().getType();
30113011
if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
30123012
getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) {
3013-
if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
3013+
if (isCompatibleVectorType(valType)) {
3014+
if (isScalableVectorType(valType))
3015+
return emitOpError("expected LLVM IR fixed vector type");
3016+
Type elemType = getVectorElementType(valType);
3017+
if (!isCompatibleFloatingPointType(elemType))
3018+
return emitOpError(
3019+
"expected LLVM IR floating point type for vector element");
3020+
} else if (!isCompatibleFloatingPointType(valType)) {
30143021
return emitOpError("expected LLVM IR floating point type");
3022+
}
30153023
} else if (getBinOp() == AtomicBinOp::xchg) {
30163024
DataLayout dataLayout = DataLayout::closest(*this);
30173025
if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,14 @@ func.func @atomicrmw_expected_float(%i32_ptr : !llvm.ptr, %i32 : i32) {
643643

644644
// -----
645645

646+
func.func @atomicrmw_unexpected_vector_element(%ptr : !llvm.ptr, %i32_vec : vector<3xi32>) {
647+
// expected-error@+1 {{expected LLVM IR floating point type for vector element}}
648+
%0 = llvm.atomicrmw fadd %ptr, %i32_vec unordered : !llvm.ptr, vector<3xi32>
649+
llvm.return
650+
}
651+
652+
// -----
653+
646654
func.func @atomicrmw_unexpected_xchg_type(%i1_ptr : !llvm.ptr, %i1 : i1) {
647655
// expected-error@+1 {{unexpected LLVM IR type for 'xchg' bin_op}}
648656
%0 = llvm.atomicrmw xchg %i1_ptr, %i1 unordered : !llvm.ptr, i1

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -420,11 +420,13 @@ func.func @atomic_store(%val : f32, %large_val : i256, %ptr : !llvm.ptr) {
420420
}
421421

422422
// CHECK-LABEL: @atomicrmw
423-
func.func @atomicrmw(%ptr : !llvm.ptr, %val : f32) {
423+
func.func @atomicrmw(%f32_ptr : !llvm.ptr, %f32 : f32, %f32_vec_ptr : !llvm.ptr, %f16_vec : vector<2xf16>) {
424424
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} monotonic : !llvm.ptr, f32
425-
%0 = llvm.atomicrmw fadd %ptr, %val monotonic : !llvm.ptr, f32
425+
%0 = llvm.atomicrmw fadd %f32_ptr, %f32 monotonic : !llvm.ptr, f32
426426
// CHECK: llvm.atomicrmw volatile fsub %{{.*}}, %{{.*}} syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
427-
%1 = llvm.atomicrmw volatile fsub %ptr, %val syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
427+
%1 = llvm.atomicrmw volatile fsub %f32_ptr, %f32 syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
428+
// CHECK: llvm.atomicrmw fmin %{{.*}}, %{{.*}} monotonic : !llvm.ptr, vector<2xf16>
429+
%2 = llvm.atomicrmw fmin %f32_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
428430
llvm.return
429431
}
430432

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,7 +1496,8 @@ llvm.func @elements_constant_3d_array() -> !llvm.array<2 x array<2 x array<2 x i
14961496
// CHECK-LABEL: @atomicrmw
14971497
llvm.func @atomicrmw(
14981498
%f32_ptr : !llvm.ptr, %f32 : f32,
1499-
%i32_ptr : !llvm.ptr, %i32 : i32) {
1499+
%i32_ptr : !llvm.ptr, %i32 : i32,
1500+
%f16_vec_ptr : !llvm.ptr, %f16_vec : vector<2xf16>) {
15001501
// CHECK: atomicrmw fadd ptr %{{.*}}, float %{{.*}} monotonic
15011502
%0 = llvm.atomicrmw fadd %f32_ptr, %f32 monotonic : !llvm.ptr, f32
15021503
// CHECK: atomicrmw fsub ptr %{{.*}}, float %{{.*}} monotonic
@@ -1535,11 +1536,19 @@ llvm.func @atomicrmw(
15351536
%17 = llvm.atomicrmw usub_cond %i32_ptr, %i32 monotonic : !llvm.ptr, i32
15361537
// CHECK: atomicrmw usub_sat ptr %{{.*}}, i32 %{{.*}} monotonic
15371538
%18 = llvm.atomicrmw usub_sat %i32_ptr, %i32 monotonic : !llvm.ptr, i32
1539+
// CHECK: atomicrmw fadd ptr %{{.*}}, <2 x half> %{{.*}} monotonic
1540+
%19 = llvm.atomicrmw fadd %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
1541+
// CHECK: atomicrmw fsub ptr %{{.*}}, <2 x half> %{{.*}} monotonic
1542+
%20 = llvm.atomicrmw fsub %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
1543+
// CHECK: atomicrmw fmax ptr %{{.*}}, <2 x half> %{{.*}} monotonic
1544+
%21 = llvm.atomicrmw fmax %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
1545+
// CHECK: atomicrmw fmin ptr %{{.*}}, <2 x half> %{{.*}} monotonic
1546+
%22 = llvm.atomicrmw fmin %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
15381547

15391548
// CHECK: atomicrmw volatile
15401549
// CHECK-SAME: syncscope("singlethread")
15411550
// CHECK-SAME: align 8
1542-
%19 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32
1551+
%23 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32
15431552
llvm.return
15441553
}
15451554

0 commit comments

Comments
 (0)