Skip to content

Commit 2e6cc79

Browse files
authored
[MLIR][NVVM] Migrate CpAsyncOp to intrinsics (llvm#123789)
Intrinsics are available for the 'cpSize' variants also. So, this patch migrates the Op to lower to the intrinsics for all cases. * Update the existing tests to check the lowering to intrinsics. * Add newer cp_async_zfill tests to verify the lowering for the 'cpSize' variants. * Tidy-up CHECK lines in cp_async() function in nvvmir.mlir (NFC) PTX spec link: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async Signed-off-by: Durgadoss R <[email protected]>
1 parent cad6bba commit 2e6cc79

File tree

5 files changed

+70
-52
lines changed

5 files changed

+70
-52
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/OpDefinition.h"
2222
#include "mlir/Interfaces/InferIntRangeInterface.h"
2323
#include "mlir/Interfaces/SideEffectInterfaces.h"
24+
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
2425
#include "llvm/IR/IntrinsicsNVPTX.h"
2526

2627
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc"

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 9 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -849,55 +849,24 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
849849

850850
def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
851851

852-
def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op<"cp.async.shared.global">,
852+
def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
853853
Arguments<(ins LLVM_PointerShared:$dst,
854854
LLVM_PointerGlobal:$src,
855855
I32Attr:$size,
856856
LoadCacheModifierAttr:$modifier,
857857
Optional<LLVM_Type>:$cpSize)> {
858-
string llvmBuilder = [{
859-
llvm::Intrinsic::ID id;
860-
switch ($size) {
861-
case 4:
862-
id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4;
863-
break;
864-
case 8:
865-
id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8;
866-
break;
867-
case 16:
868-
if($modifier == NVVM::LoadCacheModifierKind::CG)
869-
id = llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16;
870-
else if($modifier == NVVM::LoadCacheModifierKind::CA)
871-
id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
872-
else
873-
llvm_unreachable("unsupported cache modifier");
874-
break;
875-
default:
876-
llvm_unreachable("unsupported async copy size");
877-
}
878-
createIntrinsicCall(builder, id, {$dst, $src});
879-
}];
880858
let assemblyFormat = "$dst `,` $src `,` $size `,` `cache` `=` $modifier (`,` $cpSize^)? attr-dict `:` type(operands)";
881859
let hasVerifier = 1;
882860
let extraClassDeclaration = [{
883-
bool hasIntrinsic() { if(getCpSize()) return false; return true; }
884-
885-
void getAsmValues(RewriterBase &rewriter,
886-
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues) {
887-
asmValues.push_back({getDst(), PTXRegisterMod::Read});
888-
asmValues.push_back({getSrc(), PTXRegisterMod::Read});
889-
asmValues.push_back({makeConstantI32(rewriter, getSize()), PTXRegisterMod::Read});
890-
asmValues.push_back({getCpSize(), PTXRegisterMod::Read});
891-
}
861+
static llvm::Intrinsic::ID
862+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
863+
llvm::SmallVector<llvm::Value *> &args);
892864
}];
893-
let extraClassDefinition = [{
894-
std::string $cppClass::getPtx() {
895-
if(getModifier() == NVVM::LoadCacheModifierKind::CG)
896-
return std::string("cp.async.cg.shared.global [%0], [%1], %2, %3;\n");
897-
if(getModifier() == NVVM::LoadCacheModifierKind::CA)
898-
return std::string("cp.async.ca.shared.global [%0], [%1], %2, %3;\n");
899-
llvm_unreachable("unsupported cache modifier");
900-
}
865+
string llvmBuilder = [{
866+
llvm::SmallVector<llvm::Value *> translatedOperands;
867+
auto id = NVVM::CpAsyncOp::getIntrinsicIDAndArgs(
868+
*op, moduleTranslation, translatedOperands);
869+
createIntrinsicCall(builder, id, translatedOperands);
901870
}];
902871
}
903872

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,44 @@ LogicalResult NVVM::BarrierOp::verify() {
11101110
return success();
11111111
}
11121112

1113+
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
1114+
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
1115+
1116+
#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
1117+
has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
1118+
1119+
llvm::Intrinsic::ID
1120+
CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1121+
llvm::SmallVector<llvm::Value *> &args) {
1122+
llvm::Intrinsic::ID id;
1123+
1124+
auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
1125+
bool hasCpSize = cpAsyncOp.getCpSize() ? true : false;
1126+
switch (cpAsyncOp.getSize()) {
1127+
case 4:
1128+
id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
1129+
break;
1130+
case 8:
1131+
id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
1132+
break;
1133+
case 16:
1134+
id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
1135+
? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
1136+
: GET_CP_ASYNC_ID(ca, 16, hasCpSize);
1137+
break;
1138+
default:
1139+
llvm_unreachable("Invalid copy size in CpAsyncOp.");
1140+
}
1141+
1142+
// Fill the Intrinsic Args
1143+
args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
1144+
args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
1145+
if (hasCpSize)
1146+
args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
1147+
1148+
return id;
1149+
}
1150+
11131151
llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
11141152
bool isIm2Col) {
11151153
switch (tensorDims) {

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,9 @@ func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) {
7474

7575
// CHECK-LABEL: @async_cp_zfill
7676
func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
77-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
78-
// CHECK-SAME: "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A",
79-
// CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
77+
// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg, %{{.*}} : !llvm.ptr<3>, !llvm.ptr<1>, i32
8078
nvvm.cp.async.shared.global %dst, %src, 16, cache = cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
81-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
82-
// CHECK-SAME: "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A",
83-
// CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
79+
// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 4, cache = ca, %{{.*}} : !llvm.ptr<3>, !llvm.ptr<1>, i32
8480
nvvm.cp.async.shared.global %dst, %src, 4, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
8581
return
8682
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -488,21 +488,35 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
488488

489489
// CHECK-LABEL: @cp_async
490490
llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
491-
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
491+
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
492492
nvvm.cp.async.shared.global %arg0, %arg1, 4, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
493-
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
493+
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
494494
nvvm.cp.async.shared.global %arg0, %arg1, 8, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
495-
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
495+
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
496496
nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
497-
// CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
497+
// CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
498498
nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
499-
// CHECK: call void @llvm.nvvm.cp.async.commit.group()
499+
500+
// CHECK: call void @llvm.nvvm.cp.async.commit.group()
500501
nvvm.cp.async.commit.group
501-
// CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
502+
// CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
502503
nvvm.cp.async.wait.group 0
503504
llvm.return
504505
}
505506

507+
// CHECK-LABEL: @async_cp_zfill
508+
llvm.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
509+
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}})
510+
nvvm.cp.async.shared.global %dst, %src, 4, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
511+
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}})
512+
nvvm.cp.async.shared.global %dst, %src, 8, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
513+
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}})
514+
nvvm.cp.async.shared.global %dst, %src, 16, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
515+
// CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}})
516+
nvvm.cp.async.shared.global %dst, %src, 16, cache = cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
517+
llvm.return
518+
}
519+
506520
// CHECK-LABEL: @cp_async_mbarrier_arrive
507521
llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
508522
// CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %{{.*}})

0 commit comments

Comments
 (0)