Skip to content

Commit fdb0108

Browse files
committed
Allow 16 bit floating point operand for LLVM_AtomicRMWOp
As far as AMDGPU target supports vectorization for atomic_rmw operation, allow construction of LLVM_AtomicRMWOp with 16 bit floating point values. This patch enables building of LLVM_AtomicRMWOp with fixed vectors of 16 bit fp values as operands. See also: #94845, #95393, #95394 Signed-off-by: Ilya Veselov <[email protected]>
1 parent e96f778 commit fdb0108

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1737,7 +1737,8 @@ def LLVM_ConstantOp
17371737
// Atomic operations.
17381738
//
17391739

1740-
def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger]>;
1740+
def LLVM_AtomicRMWType
1741+
: AnyTypeOf<[LLVM_AnyPointer, AnySignlessInteger, LLVM_ScalarOrVectorOf<LLVM_AnyFloat>]>;
17411742

17421743
def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [
17431744
TypesMatchWith<"result #0 and operand #1 have the same type",

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3008,9 +3008,20 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
30083008

30093009
LogicalResult AtomicRMWOp::verify() {
30103010
auto valType = getVal().getType();
3011-
if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
3012-
getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) {
3013-
if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
3011+
if (getBinOp() == AtomicBinOp::fadd && isCompatibleVectorType(valType)) {
3012+
// Currently, only fadd operation supports fixed vector operands.
3013+
if (isScalableVectorType(valType))
3014+
return emitOpError("expected LLVM IR fixed vector type");
3015+
Type elemType = getVectorElementType(valType);
3016+
// Only 16 bit floating point elements are supported for now.
3017+
if (!(isCompatibleFloatingPointType(elemType) &&
3018+
elemType.getIntOrFloatBitWidth() == 16))
3019+
return emitOpError("unexpected LLVM IR type for vector element");
3020+
} else if (getBinOp() == AtomicBinOp::fadd ||
3021+
getBinOp() == AtomicBinOp::fsub ||
3022+
getBinOp() == AtomicBinOp::fmin ||
3023+
getBinOp() == AtomicBinOp::fmax) {
3024+
if (!isCompatibleFloatingPointType(valType))
30143025
return emitOpError("expected LLVM IR floating point type");
30153026
} else if (getBinOp() == AtomicBinOp::xchg) {
30163027
DataLayout dataLayout = DataLayout::closest(*this);

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,22 @@ func.func @atomicrmw_expected_float(%i32_ptr : !llvm.ptr, %i32 : i32) {
643643

644644
// -----
645645

646+
func.func @atomicrmw_unexpected_scalable_vector(%i32_ptr : !llvm.ptr, %f16_vec : vector<[2]xf16>) {
647+
// expected-error@+1 {{expected LLVM IR fixed vector type}}
648+
%0 = llvm.atomicrmw fadd %i32_ptr, %f16_vec unordered : !llvm.ptr, vector<[2]xf16>
649+
llvm.return
650+
}
651+
652+
// -----
653+
654+
func.func @atomicrmw_unexpected_vector_element(%i32_ptr : !llvm.ptr, %f32_vec : vector<3xf32>) {
655+
// expected-error@+1 {{unexpected LLVM IR type for vector element}}
656+
%0 = llvm.atomicrmw fadd %i32_ptr, %f32_vec unordered : !llvm.ptr, vector<3xf32>
657+
llvm.return
658+
}
659+
660+
// -----
661+
646662
func.func @atomicrmw_unexpected_xchg_type(%i1_ptr : !llvm.ptr, %i1 : i1) {
647663
// expected-error@+1 {{unexpected LLVM IR type for 'xchg' bin_op}}
648664
%0 = llvm.atomicrmw xchg %i1_ptr, %i1 unordered : !llvm.ptr, i1

0 commit comments

Comments
 (0)