Skip to content

Commit 864b6f4

Browse files
committed
[NVPTX] Add TMA bulk tensor copy intrinsics
This patch adds NVVM intrinsics and NVPTX CodeGen for: * cp.async.bulk.tensor.S2G.1D -> 5D variants, with optional support for cache_hints. * cp.async.bulk.tensor.G2S.1D -> 5D variants, with optional support for multicast and cache_hints. Moreover, the 3D->5D variants also have support for an 'im2col' mode, with its own set of offsets. * The first argument of these intrinsics is an immediate i32-flag. The bit-fields of the flag control enabling optional features like multicast, cache_hints and im2col offsets when applicable. The backend looks through these flag-bits and lowers to the appropriate PTX instruction. * Lit tests are added for all combinations of these intrinsics in cp-async-bulk-tensor-g2s/s2g.ll. * The generated PTX is verified with a 12.3 ptxas executable. TODO: Update documentation for these intrinsics in the NVPTX guide. Signed-off-by: Durgadoss R <[email protected]>
1 parent 021b9d0 commit 864b6f4

File tree

7 files changed

+896
-0
lines changed

7 files changed

+896
-0
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,22 @@ class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
552552
[OpType, llvm_i32_ty, llvm_i32_ty]);
553553
}
554554

555+
// Converts a dotted intrinsic name into the TableGen record name of its
// definition, e.g. "llvm.nvvm.foo.bar" -> "int_nvvm_foo_bar".
class NVVM_INTRINSIC_RECORD<string intr> {
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
558+
559+
// Builds both the intrinsic name and its TableGen record name for the
// global -> shared (G2S) bulk tensor copy of dimensionality `dim`,
// e.g. dim=3 -> "llvm.nvvm.cp.async.bulk.tensor.gmem.to.smem.3d".
class NVVM_CP_ASYNC_BULK_TENSOR_GMEM_TO_SMEM_NAME<int dim> {
  string intr = "llvm.nvvm.cp.async.bulk.tensor.gmem.to.smem"
                # "." # dim # "d";
  string record = NVVM_INTRINSIC_RECORD<intr>.record;
}
564+
565+
// Builds both the intrinsic name and its TableGen record name for the
// shared -> global (S2G) bulk tensor copy of dimensionality `dim`,
// e.g. dim=3 -> "llvm.nvvm.cp.async.bulk.tensor.smem.to.gmem.3d".
class NVVM_CP_ASYNC_BULK_TENSOR_SMEM_TO_GMEM_NAME<int dim> {
  string intr = "llvm.nvvm.cp.async.bulk.tensor.smem.to.gmem"
                # "." # dim # "d";
  string record = NVVM_INTRINSIC_RECORD<intr>.record;
}
570+
555571
let TargetPrefix = "nvvm" in {
556572
def int_nvvm_prmt : ClangBuiltin<"__nvvm_prmt">,
557573
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
@@ -4822,4 +4838,42 @@ def int_nvvm_setmaxnreg_dec_sync_aligned_u32
48224838
def int_nvvm_exit : ClangBuiltin<"__nvvm_exit">,
48234839
Intrinsic<[], [], [IntrConvergent, IntrInaccessibleMemOnly, IntrNoReturn]>;
48244840

4841+
// -------- llvm.nvvm.cp.async.bulk.tensor.gmem.to.smem
// Global -> shared (G2S) bulk tensor copy, one variant per tensor
// dimensionality (1D-5D). Argument layout:
//   flags          : i32 immediate (ImmArg); its bit-fields enable the
//                    optional multicast / cache-hint operands and select
//                    the load mode (tile vs. im2col).
//   dst_smem_ptr   : destination in shared memory (WriteOnly, NoCapture)
//   barrier_ptr    : mbarrier in shared memory (NoCapture)
//   tensor_map_ptr : tensor-map descriptor (ReadOnly, NoCapture)
//   tensor_dims    : `dim` i32 coordinates
//   im2col offsets : (dim - 2) i16 values, present only for 3D-5D variants
//   cta_mask       : i16 multicast mask
//   cache_policy   : i64 cache hint
class NVVM_CP_ASYNC_BULK_TENSOR_GMEM_TO_SMEM_INTR<int dim> :
  DefaultAttrsIntrinsic<[],
      !listconcat(
       // flags, dst_smem_ptr, barrier_ptr, tensor_map_ptr
       [llvm_i32_ty, llvm_shared_ptr_ty, llvm_shared_ptr_ty, llvm_ptr_ty],
       !listsplat(llvm_i32_ty, dim), // tensor_dims
       !if(!ge(dim, 3), !listsplat(llvm_i16_ty, !add(dim, -2)), []), // im2col
       [llvm_i16_ty, llvm_i64_ty]), // cta_mask, cache_policy
      [IntrConvergent, IntrArgMemOnly, ImmArg<ArgIndex<0>>,
       WriteOnly<ArgIndex<1>>, ReadOnly<ArgIndex<3>>,
       NoCapture<ArgIndex<1>>, NoCapture<ArgIndex<2>>,
       NoCapture<ArgIndex<3>>],
      NVVM_CP_ASYNC_BULK_TENSOR_GMEM_TO_SMEM_NAME<dim>.intr>;

// Instantiate the 1D-5D records, e.g.
// int_nvvm_cp_async_bulk_tensor_gmem_to_smem_3d.
foreach dim = [1, 2, 3, 4, 5] in {
  def NVVM_CP_ASYNC_BULK_TENSOR_GMEM_TO_SMEM_NAME<dim>.record :
    NVVM_CP_ASYNC_BULK_TENSOR_GMEM_TO_SMEM_INTR<dim>;
}
4860+
4861+
// -------- llvm.nvvm.cp.async.bulk.tensor.smem.to.gmem
// Shared -> global (S2G) bulk tensor copy, one variant per tensor
// dimensionality (1D-5D). Argument layout:
//   flags          : i32 immediate (ImmArg); its bit-fields enable the
//                    optional cache-hint operand and select the load mode.
//   src_smem_ptr   : source in shared memory (ReadOnly, NoCapture)
//   tensor_map_ptr : tensor-map descriptor addressing the global-memory
//                    destination (marked WriteOnly, NoCapture)
//   tensor_dims    : `dim` i32 coordinates
//   cache_policy   : i64 cache hint
class NVVM_CP_ASYNC_BULK_TENSOR_SMEM_TO_GMEM_INTR<int dim> :
  DefaultAttrsIntrinsic<[],
      !listconcat(
       // flags, src_smem_ptr, tensor_map_ptr
       [llvm_i32_ty, llvm_shared_ptr_ty, llvm_ptr_ty],
       !listsplat(llvm_i32_ty, dim), // tensor_dims
       [llvm_i64_ty]), // cache_policy
      [IntrConvergent, IntrArgMemOnly, ImmArg<ArgIndex<0>>,
       ReadOnly<ArgIndex<1>>, WriteOnly<ArgIndex<2>>,
       NoCapture<ArgIndex<1>>, NoCapture<ArgIndex<2>>],
      NVVM_CP_ASYNC_BULK_TENSOR_SMEM_TO_GMEM_NAME<dim>.intr>;

// Instantiate the 1D-5D records, e.g.
// int_nvvm_cp_async_bulk_tensor_smem_to_gmem_3d.
foreach dim = [1, 2, 3, 4, 5] in {
  def NVVM_CP_ASYNC_BULK_TENSOR_SMEM_TO_GMEM_NAME<dim>.record :
    NVVM_CP_ASYNC_BULK_TENSOR_SMEM_TO_GMEM_INTR<dim>;
}
4878+
48254879
} // let TargetPrefix = "nvvm"
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===--- NVVMIntrinsicFlags.h -----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This file contains the definitions of the enumerations and flags
/// associated with NVVM Intrinsics.
//
//===----------------------------------------------------------------------===//

// Guard renamed from LLVM_SUPPORT_* to LLVM_IR_* to match the header's
// actual location (included as "llvm/IR/NVVMIntrinsicFlags.h").
#ifndef LLVM_IR_NVVMINTRINSICFLAGS_H
#define LLVM_IR_NVVMINTRINSICFLAGS_H

#include <cstdint>

namespace llvm {
namespace nvvm {

/// Load modes for the cp.async.bulk.tensor.* intrinsics, carried in the
/// LoadMode bit-field of the immediate flags operand.
enum class CpAsyncBulkTensorLoadMode {
  TILE = 0,
  IM2COL = 1,
};

/// Decoder for the i32 immediate flags operand of the
/// cp.async.bulk.tensor.* intrinsics: write the raw immediate to \c V and
/// read the individual features through \c U.
/// Bit layout: bit 0 = cache-hint present, bit 1 = multicast present,
/// bits 2-4 = load mode (CpAsyncBulkTensorLoadMode).
typedef union {
  // uint32_t (rather than int) so the raw value matches the unsigned
  // bit-fields below and zero-extended intrinsic immediates exactly.
  uint32_t V;
  struct {
    unsigned CacheHint : 1;
    unsigned MultiCast : 1;
    unsigned LoadMode : 3; // CpAsyncBulkTensorLoadMode
    unsigned reserved : 27;
  } U;
} CpAsyncBulkTensorFlags;

} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICFLAGS_H

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/IR/GlobalValue.h"
1919
#include "llvm/IR/Instructions.h"
2020
#include "llvm/IR/IntrinsicsNVPTX.h"
21+
#include "llvm/IR/NVVMIntrinsicFlags.h"
2122
#include "llvm/Support/AtomicOrdering.h"
2223
#include "llvm/Support/CommandLine.h"
2324
#include "llvm/Support/Debug.h"
@@ -160,6 +161,10 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
160161
if (tryIntrinsicChain(N))
161162
return;
162163
break;
164+
case ISD::INTRINSIC_VOID:
165+
if (tryIntrinsicVoid(N))
166+
return;
167+
break;
163168
case NVPTXISD::Tex1DFloatS32:
164169
case NVPTXISD::Tex1DFloatFloat:
165170
case NVPTXISD::Tex1DFloatFloatLevel:
@@ -3861,3 +3866,246 @@ unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
38613866
}
38623867
}
38633868
}
3869+
3870+
// Recovers the tensor dimensionality (1-5) encoded in the name of a
// cp.async.bulk.tensor.{gmem.to.smem,smem.to.gmem} intrinsic ID.
static size_t GetCpAsyncBulkTensorDimFromIntrinsic(unsigned IID) {
  switch (IID) {
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_1d:
    return 1;
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_2d:
    return 2;
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_3d:
    return 3;
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_4d:
    return 4;
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_5d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_5d:
    return 5;
  }
  llvm_unreachable("Invalid Tensor dim in nvvm_cp_async_bulk_tensor intrinsic");
}
3892+
3893+
// Expands to a `return` of the NVPTX opcode for the given copy direction
// (dir), dimensionality (dim), load mode (TILE/IM2COL) and optional-feature
// suffix. Picks the SHARED32 opcode flavor when the local `IsShared32` is
// set, i.e. when shared-memory pointers are 32 bits wide.
#define CP_ASYNC_BULK_TENSOR_OPCODE(dir, dim, mode, suffix)                    \
  if (IsShared32) {                                                            \
    return NVPTX::                                                             \
        CP_ASYNC_BULK_TENSOR_##dir##_##dim##_SHARED32_##mode##suffix;          \
  } else {                                                                     \
    return NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_##mode##suffix;         \
  }

// Returns the shared-to-global (S2G) copy opcode, appending the _CH suffix
// when a cache-hint operand is present (reads the local `IsCacheHint`).
// Both branches return, so invocations never fall through.
#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(dim, mode)                         \
  do {                                                                         \
    if (IsCacheHint) {                                                         \
      CP_ASYNC_BULK_TENSOR_OPCODE(SMEM_TO_GMEM, dim, mode, _CH);               \
    } else {                                                                   \
      CP_ASYNC_BULK_TENSOR_OPCODE(SMEM_TO_GMEM, dim, mode, );                  \
    }                                                                          \
  } while (0)

// Returns the global-to-shared (G2S) copy opcode, appending _MC and/or _CH
// for the multicast / cache-hint variants (reads the locals `IsMultiCast`
// and `IsCacheHint`). Every branch returns, so invocations never fall
// through.
#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode)                         \
  do {                                                                         \
    if (IsMultiCast && IsCacheHint) {                                          \
      CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, _MC_CH);            \
    } else if (IsCacheHint) {                                                  \
      CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, _CH);               \
    } else if (IsMultiCast) {                                                  \
      CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, _MC);               \
    } else {                                                                   \
      CP_ASYNC_BULK_TENSOR_OPCODE(GMEM_TO_SMEM, dim, mode, );                  \
    }                                                                          \
  } while (0)
3922+
3923+
// Returns the NVPTX machine opcode for a shared-to-global (S2G) bulk tensor
// copy of dimensionality \p Dim. \p IsShared32 selects the SHARED32 flavor
// (32-bit shared-memory pointers), \p IsCacheHint the _CH variant, and
// \p IsIm2Col the im2col-mode instruction, which only exists for 3D-5D.
// Note: the GET_* macros expand to `return` statements, so the switch cases
// below do not fall through.
static unsigned GetCpAsyncBulkTensorS2GOpcode(size_t Dim, bool IsShared32,
                                              bool IsCacheHint, bool IsIm2Col) {
  if (IsIm2Col) {
    switch (Dim) {
    case 3:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, IM2COL);
    case 4:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, IM2COL);
    case 5:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, IM2COL);
    default:
      llvm_unreachable("Invalid Dimension in im2col mode for "
                       "GetCpAsyncBulkTensorS2GOpcode.");
    }
  } else {
    switch (Dim) {
    case 1:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(1D, TILE);
    case 2:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(2D, TILE);
    case 3:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, TILE);
    case 4:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, TILE);
    case 5:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, TILE);
    default:
      llvm_unreachable(
          "Invalid Dimension in tile mode for GetCpAsyncBulkTensorS2GOpcode.");
    }
  }
}
3955+
3956+
// Returns the NVPTX machine opcode for a global-to-shared (G2S) bulk tensor
// copy of dimensionality \p Dim. \p IsShared32 selects the SHARED32 flavor
// (32-bit shared-memory pointers), \p IsMultiCast / \p IsCacheHint the
// _MC / _CH variants, and \p IsIm2Col the im2col-mode instruction, which
// only exists for 3D-5D.
// Note: the GET_* macros expand to `return` statements, so the switch cases
// below do not fall through.
static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32,
                                              bool IsMultiCast,
                                              bool IsCacheHint, bool IsIm2Col) {
  if (IsIm2Col) {
    switch (Dim) {
    case 3:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, IM2COL);
    case 4:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, IM2COL);
    case 5:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, IM2COL);
    default:
      llvm_unreachable("Invalid Dimension in im2col mode for "
                       "GetCpAsyncBulkTensorG2SOpcode.");
    }
  } else {
    switch (Dim) {
    case 1:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(1D, TILE);
    case 2:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(2D, TILE);
    case 3:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, TILE);
    case 4:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, TILE);
    case 5:
      GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, TILE);
    default:
      llvm_unreachable(
          "Invalid Dimension in tile mode for GetCpAsyncBulkTensorG2SOpcode.");
    }
  }
}
3989+
3990+
// Lowers an nvvm.cp.async.bulk.tensor.smem.to.gmem.[1-5]d intrinsic node to
// the corresponding NVPTX S2G machine instruction.
//
// SDNode operand layout: 0 = chain, 1 = intrinsic ID, 2 = i32 flags (imm),
// 3 = src shared-memory pointer, 4 = tensor-map pointer,
// [5, 5+NumDims) = tensor dimensions, then the i64 cache-hint.
void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorS2G(SDNode *N) {
  // The shared-memory pointer width decides between the SHARED32 and the
  // default opcode flavors.
  unsigned int SharedPointerSize =
      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED);
  bool IsShared32 = (SharedPointerSize == 32);

  unsigned IID = N->getConstantOperandVal(1);
  size_t NumDims = GetCpAsyncBulkTensorDimFromIntrinsic(IID);

  // Decode the immediate flags operand: bit 0 = cache-hint present,
  // bits 2-4 = load mode (tile/im2col).
  nvvm::CpAsyncBulkTensorFlags Flags;
  Flags.V = static_cast<unsigned>(N->getConstantOperandVal(2));
  bool IsCacheHint = Flags.U.CacheHint == 1;
  bool IsIm2Col = Flags.U.LoadMode == 1;

  // Im2col-mode instructions only exist for 3D-5D tensors. Reject invalid
  // flag/dim combinations with a proper diagnostic (mirroring the check in
  // SelectCpAsyncBulkTensorG2S) instead of reaching llvm_unreachable during
  // opcode selection, which is undefined behavior in release builds.
  if (IsIm2Col && NumDims < 3)
    report_fatal_error("NumDims should be at least 3 for Im2Col mode");

  SDLoc DL(N);
  // List of operands that are common to both variants. Inline capacity of 8
  // covers the worst case: 2 pointers + 5 dims + cache-hint + chain.
  SmallVector<SDValue, 8> Ops{
      N->getOperand(3), // Src pointer in smem
      N->getOperand(4), // Dst tensor_map pointer in gmem
  };

  // Tensor Dims from [1-5] followed by the cache-hint operand
  size_t TensorDimsStartIndex = 5;
  size_t CacheHintIndex = TensorDimsStartIndex + NumDims;
  for (size_t i = 0; i < NumDims; i++)
    Ops.push_back(N->getOperand(TensorDimsStartIndex + i));

  // Push the cache-hint operand, if available
  if (IsCacheHint)
    Ops.push_back(N->getOperand(CacheHintIndex));

  // Finally, the chain operand
  Ops.push_back(N->getOperand(0));

  unsigned Opcode =
      GetCpAsyncBulkTensorS2GOpcode(NumDims, IsShared32, IsCacheHint, IsIm2Col);

  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
}
4029+
4030+
// Lowers an nvvm.cp.async.bulk.tensor.gmem.to.smem.[1-5]d intrinsic node to
// the corresponding NVPTX G2S machine instruction.
//
// SDNode operand layout: 0 = chain, 1 = intrinsic ID, 2 = i32 flags (imm),
// 3 = dst shared-memory pointer, 4 = mbarrier shared-memory pointer,
// 5 = tensor-map pointer, [6, 6+NumDims) = tensor dimensions, then
// (NumDims - 2) im2col offsets (3D-5D only), then the i16 cta_mask and the
// i64 cache-hint.
void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2S(SDNode *N) {
  // The shared-memory pointer width decides between the SHARED32 and the
  // default opcode flavors.
  unsigned int SharedPointerSize =
      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED);
  bool IsShared32 = (SharedPointerSize == 32);

  unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
  size_t NumDims = GetCpAsyncBulkTensorDimFromIntrinsic(IID);

  // Decode the immediate flags operand: bit 0 = cache-hint present,
  // bit 1 = multicast present, bits 2-4 = load mode (tile/im2col).
  ConstantSDNode *FlagsNode = cast<ConstantSDNode>(N->getOperand(2));
  nvvm::CpAsyncBulkTensorFlags Flags;
  Flags.V = static_cast<unsigned>(FlagsNode->getZExtValue());
  bool IsCacheHint = Flags.U.CacheHint == 1;
  bool IsMultiCast = Flags.U.MultiCast == 1;
  bool IsIm2Col = Flags.U.LoadMode == 1;

  // Im2col-mode instructions only exist for 3D-5D tensors.
  if (IsIm2Col && NumDims < 3)
    report_fatal_error("NumDims should be at least 3 for Im2Col mode");

  SDLoc DL(N);
  // List of operands that are common to both tile and im2col variants
  SmallVector<SDValue, 4> Ops{
      N->getOperand(3), // Dst pointer in smem
      N->getOperand(4), // Mbarrier pointer in smem
      N->getOperand(5), // Src pointer (i.e. tensor_map) in gmem
  };

  // Tensor Dims from [1-5]
  size_t TensorDimsStartIndex = 6;
  for (size_t i = 0; i < NumDims; i++)
    Ops.push_back(N->getOperand(TensorDimsStartIndex + i));

  // Im2Col co-ordinates:
  // These are always present in the input arguments for TensorDims{3,4,5}.
  // Number of values is (NumDims - 2).
  size_t Im2ColStartIndex = TensorDimsStartIndex + NumDims;
  size_t NumDimsIm2Col = (NumDims > 2) ? (NumDims - 2) : 0;
  size_t Im2ColEndIndex = Im2ColStartIndex + NumDimsIm2Col;
  // ...However, passed down to the actual NVPTX only when
  // this mode is enabled.
  if (IsIm2Col) {
    for (size_t i = 0; i < NumDimsIm2Col; i++)
      Ops.push_back(N->getOperand(Im2ColStartIndex + i));
  }

  // Push MultiCast operand, if available. Note that the cta_mask and
  // cache-hint operands always trail the (always-present) im2col offsets,
  // hence the indexing off Im2ColEndIndex regardless of mode.
  if (IsMultiCast)
    Ops.push_back(N->getOperand(Im2ColEndIndex));

  // Push CacheHint operand, if available
  if (IsCacheHint)
    Ops.push_back(N->getOperand(Im2ColEndIndex + 1));

  // Finally, the chain operand
  Ops.push_back(N->getOperand(0));

  unsigned Opcode = GetCpAsyncBulkTensorG2SOpcode(
      NumDims, IsShared32, IsMultiCast, IsCacheHint, IsIm2Col);

  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
}
4090+
4091+
// Custom selection for ISD::INTRINSIC_VOID nodes. Returns true when the
// intrinsic was handled here; false defers to the common selector.
bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
  switch (N->getConstantOperandVal(1)) {
  // Global -> shared bulk tensor copies, 1D through 5D.
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_gmem_to_smem_5d:
    SelectCpAsyncBulkTensorG2S(N);
    return true;
  // Shared -> global bulk tensor copies, 1D through 5D.
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_1d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_2d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_3d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_4d:
  case Intrinsic::nvvm_cp_async_bulk_tensor_smem_to_gmem_5d:
    SelectCpAsyncBulkTensorS2G(N);
    return true;
  default:
    return false;
  }
}

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
5757
void Select(SDNode *N) override;
5858
bool tryIntrinsicNoChain(SDNode *N);
5959
bool tryIntrinsicChain(SDNode *N);
60+
bool tryIntrinsicVoid(SDNode *N);
6061
void SelectTexSurfHandle(SDNode *N);
6162
bool tryLoad(SDNode *N);
6263
bool tryLoadVector(SDNode *N);
@@ -74,6 +75,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
7475
bool SelectSETP_F16X2(SDNode *N);
7576
bool SelectSETP_BF16X2(SDNode *N);
7677
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
78+
void SelectCpAsyncBulkTensorS2G(SDNode *N);
79+
void SelectCpAsyncBulkTensorG2S(SDNode *N);
7780

7881
inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
7982
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);

0 commit comments

Comments
 (0)