Skip to content

Commit 28b7e49

Browse files
AMDGPU/GFX12: Add new dot4 fp8/bf8 instructions (#77892)
Endoding is VOP3P. Tagged as deep/machine learning instructions. i32 type (v4fp8 or v4bf8 packed in i32) is used for src0 and src1. src0 and src1 have no src_modifiers. src2 is f32 and has src_modifiers: f32 fneg(neg_lo[2]) and f32 fabs(neg_hi[2]). --------- Co-authored-by: Petar Avramovic <[email protected]>
1 parent 18d0a7e commit 28b7e49

19 files changed

+938
-12
lines changed

clang/include/clang/Basic/BuiltinsAMDGPU.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,10 @@ TARGET_BUILTIN(__builtin_amdgcn_sudot4, "iIbiIbiiIb", "nc", "dot8-insts")
255255
TARGET_BUILTIN(__builtin_amdgcn_sdot8, "SiSiSiSiIb", "nc", "dot1-insts")
256256
TARGET_BUILTIN(__builtin_amdgcn_udot8, "UiUiUiUiIb", "nc", "dot7-insts")
257257
TARGET_BUILTIN(__builtin_amdgcn_sudot8, "iIbiIbiiIb", "nc", "dot8-insts")
258+
TARGET_BUILTIN(__builtin_amdgcn_dot4_f32_fp8_bf8, "fUiUif", "nc", "gfx12-insts")
259+
TARGET_BUILTIN(__builtin_amdgcn_dot4_f32_bf8_fp8, "fUiUif", "nc", "gfx12-insts")
260+
TARGET_BUILTIN(__builtin_amdgcn_dot4_f32_fp8_fp8, "fUiUif", "nc", "gfx12-insts")
261+
TARGET_BUILTIN(__builtin_amdgcn_dot4_f32_bf8_bf8, "fUiUif", "nc", "gfx12-insts")
258262

259263
//===----------------------------------------------------------------------===//
260264
// GFX10+ only builtins.

clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-err.cl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,9 @@ kernel void builtins_amdgcn_dl_insts_err(
4949

5050
iOut[3] = __builtin_amdgcn_sudot8(false, A, true, B, C, false); // expected-error {{'__builtin_amdgcn_sudot8' needs target feature dot8-insts}}
5151
iOut[4] = __builtin_amdgcn_sudot8(true, A, false, B, C, true); // expected-error {{'__builtin_amdgcn_sudot8' needs target feature dot8-insts}}
52+
53+
fOut[5] = __builtin_amdgcn_dot4_f32_fp8_bf8(uiA, uiB, fC); // expected-error {{'__builtin_amdgcn_dot4_f32_fp8_bf8' needs target feature gfx12-insts}}
54+
fOut[6] = __builtin_amdgcn_dot4_f32_bf8_fp8(uiA, uiB, fC); // expected-error {{'__builtin_amdgcn_dot4_f32_bf8_fp8' needs target feature gfx12-insts}}
55+
fOut[7] = __builtin_amdgcn_dot4_f32_fp8_fp8(uiA, uiB, fC); // expected-error {{'__builtin_amdgcn_dot4_f32_fp8_fp8' needs target feature gfx12-insts}}
56+
fOut[8] = __builtin_amdgcn_dot4_f32_bf8_bf8(uiA, uiB, fC); // expected-error {{'__builtin_amdgcn_dot4_f32_bf8_bf8' needs target feature gfx12-insts}}
5257
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// REQUIRES: amdgpu-registered-target
2+
3+
// RUN: %clang_cc1 -triple amdgcn-unknown-unknown -target-cpu gfx1200 -S -emit-llvm -o - %s | FileCheck %s
4+
5+
typedef unsigned int uint;
6+
7+
// CHECK-LABEL: @builtins_amdgcn_dl_insts
8+
// CHECK: call float @llvm.amdgcn.dot4.f32.fp8.bf8(i32 %uiA, i32 %uiB, float %fC)
9+
// CHECK: call float @llvm.amdgcn.dot4.f32.bf8.fp8(i32 %uiA, i32 %uiB, float %fC)
10+
// CHECK: call float @llvm.amdgcn.dot4.f32.fp8.fp8(i32 %uiA, i32 %uiB, float %fC)
11+
// CHECK: call float @llvm.amdgcn.dot4.f32.bf8.bf8(i32 %uiA, i32 %uiB, float %fC)
12+
13+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
14+
kernel void builtins_amdgcn_dl_insts_err(global float *fOut,
15+
uint uiA, uint uiB, float fC) {
16+
fOut[0] = __builtin_amdgcn_dot4_f32_fp8_bf8(uiA, uiB, fC);
17+
fOut[1] = __builtin_amdgcn_dot4_f32_bf8_fp8(uiA, uiB, fC);
18+
fOut[2] = __builtin_amdgcn_dot4_f32_fp8_fp8(uiA, uiB, fC);
19+
fOut[3] = __builtin_amdgcn_dot4_f32_bf8_bf8(uiA, uiB, fC);
20+
}

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2729,6 +2729,25 @@ def int_amdgcn_udot8 :
27292729
ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<5>>]
27302730
>;
27312731

2732+
// f32 %r = llvm.amdgcn.dot4.f32.type_a.type_b (v4type_a (as i32) %a, v4type_b (as i32) %b, f32 %c)
2733+
// %r = %a[0] * %b[0] + %a[1] * %b[1] + %a[2] * %b[2] + %a[3] * %b[3] + %c
2734+
class AMDGPU8bitFloatDot4Intrinsic :
2735+
ClangBuiltin<!subst("int", "__builtin", NAME)>,
2736+
DefaultAttrsIntrinsic<
2737+
[llvm_float_ty], // %r
2738+
[
2739+
llvm_i32_ty, // %a
2740+
llvm_i32_ty, // %b
2741+
llvm_float_ty, // %c
2742+
],
2743+
[IntrNoMem, IntrSpeculatable]
2744+
>;
2745+
2746+
def int_amdgcn_dot4_f32_fp8_bf8 : AMDGPU8bitFloatDot4Intrinsic;
2747+
def int_amdgcn_dot4_f32_bf8_fp8 : AMDGPU8bitFloatDot4Intrinsic;
2748+
def int_amdgcn_dot4_f32_fp8_fp8 : AMDGPU8bitFloatDot4Intrinsic;
2749+
def int_amdgcn_dot4_f32_bf8_bf8 : AMDGPU8bitFloatDot4Intrinsic;
2750+
27322751
//===----------------------------------------------------------------------===//
27332752
// gfx908 intrinsics
27342753
// ===----------------------------------------------------------------------===//

llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4491,6 +4491,10 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
44914491
case Intrinsic::amdgcn_fdot2_f32_bf16:
44924492
case Intrinsic::amdgcn_sudot4:
44934493
case Intrinsic::amdgcn_sudot8:
4494+
case Intrinsic::amdgcn_dot4_f32_fp8_bf8:
4495+
case Intrinsic::amdgcn_dot4_f32_bf8_fp8:
4496+
case Intrinsic::amdgcn_dot4_f32_fp8_fp8:
4497+
case Intrinsic::amdgcn_dot4_f32_bf8_bf8:
44944498
case Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16:
44954499
case Intrinsic::amdgcn_wmma_f16_16x16x16_f16:
44964500
case Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied:

llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,6 +1688,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
16881688
bool validateMIMGD16(const MCInst &Inst);
16891689
bool validateMIMGMSAA(const MCInst &Inst);
16901690
bool validateOpSel(const MCInst &Inst);
1691+
bool validateNeg(const MCInst &Inst, int OpName);
16911692
bool validateDPP(const MCInst &Inst, const OperandVector &Operands);
16921693
bool validateVccOperand(unsigned Reg) const;
16931694
bool validateVOPLiteral(const MCInst &Inst, const OperandVector &Operands);
@@ -4357,6 +4358,41 @@ bool AMDGPUAsmParser::validateOpSel(const MCInst &Inst) {
43574358
return true;
43584359
}
43594360

4361+
bool AMDGPUAsmParser::validateNeg(const MCInst &Inst, int OpName) {
4362+
assert(OpName == AMDGPU::OpName::neg_lo || OpName == AMDGPU::OpName::neg_hi);
4363+
4364+
const unsigned Opc = Inst.getOpcode();
4365+
uint64_t TSFlags = MII.get(Opc).TSFlags;
4366+
4367+
// v_dot4 fp8/bf8 neg_lo/neg_hi not allowed on src0 and src1 (allowed on src2)
4368+
if (!(TSFlags & SIInstrFlags::IsDOT))
4369+
return true;
4370+
4371+
int NegIdx = AMDGPU::getNamedOperandIdx(Opc, OpName);
4372+
if (NegIdx == -1)
4373+
return true;
4374+
4375+
unsigned Neg = Inst.getOperand(NegIdx).getImm();
4376+
4377+
// Instructions that have neg_lo or neg_hi operand but neg modifier is allowed
4378+
// on some src operands but not allowed on other.
4379+
// It is convenient that such instructions don't have src_modifiers operand
4380+
// for src operands that don't allow neg because they also don't allow opsel.
4381+
4382+
int SrcMods[3] = {AMDGPU::OpName::src0_modifiers,
4383+
AMDGPU::OpName::src1_modifiers,
4384+
AMDGPU::OpName::src2_modifiers};
4385+
4386+
for (unsigned i = 0; i < 3; ++i) {
4387+
if (!AMDGPU::hasNamedOperand(Opc, SrcMods[i])) {
4388+
if (Neg & (1 << i))
4389+
return false;
4390+
}
4391+
}
4392+
4393+
return true;
4394+
}
4395+
43604396
bool AMDGPUAsmParser::validateDPP(const MCInst &Inst,
43614397
const OperandVector &Operands) {
43624398
const unsigned Opc = Inst.getOpcode();
@@ -4834,6 +4870,16 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
48344870
"invalid op_sel operand");
48354871
return false;
48364872
}
4873+
if (!validateNeg(Inst, AMDGPU::OpName::neg_lo)) {
4874+
Error(getImmLoc(AMDGPUOperand::ImmTyNegLo, Operands),
4875+
"invalid neg_lo operand");
4876+
return false;
4877+
}
4878+
if (!validateNeg(Inst, AMDGPU::OpName::neg_hi)) {
4879+
Error(getImmLoc(AMDGPUOperand::ImmTyNegHi, Operands),
4880+
"invalid neg_hi operand");
4881+
return false;
4882+
}
48374883
if (!validateDPP(Inst, Operands)) {
48384884
return false;
48394885
}

llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,14 +1260,19 @@ void AMDGPUInstPrinter::printPackedModifier(const MCInst *MI,
12601260
int NumOps = 0;
12611261
int Ops[3];
12621262

1263-
for (int OpName : { AMDGPU::OpName::src0_modifiers,
1264-
AMDGPU::OpName::src1_modifiers,
1265-
AMDGPU::OpName::src2_modifiers }) {
1266-
int Idx = AMDGPU::getNamedOperandIdx(Opc, OpName);
1267-
if (Idx == -1)
1263+
std::pair<int, int> MOps[] = {
1264+
{AMDGPU::OpName::src0_modifiers, AMDGPU::OpName::src0},
1265+
{AMDGPU::OpName::src1_modifiers, AMDGPU::OpName::src1},
1266+
{AMDGPU::OpName::src2_modifiers, AMDGPU::OpName::src2}};
1267+
int DefaultValue = (Mod == SISrcMods::OP_SEL_1);
1268+
1269+
for (auto [SrcMod, Src] : MOps) {
1270+
if (!AMDGPU::hasNamedOperand(Opc, Src))
12681271
break;
12691272

1270-
Ops[NumOps++] = MI->getOperand(Idx).getImm();
1273+
int ModIdx = AMDGPU::getNamedOperandIdx(Opc, SrcMod);
1274+
Ops[NumOps++] =
1275+
(ModIdx != -1) ? MI->getOperand(ModIdx).getImm() : DefaultValue;
12711276
}
12721277

12731278
const bool HasDstSel =

llvm/lib/Target/AMDGPU/VOP3PInstructions.td

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,48 @@ def : GCNPat < (int_amdgcn_sdot4 i32:$src0,
443443
>;
444444
} // End SubtargetPredicate = HasDot8Insts
445445

446+
// Does not use opsel, no src_modifiers on src0 and src1.
447+
// src_modifiers on src2(f32) are f32 fneg(neg_lo[2]) and f32 fabs(neg_hi[2]).
448+
def VOP3P_DOTF8_Profile : VOP3P_Profile<VOPProfile <[f32, i32, i32, f32]>,
449+
VOP3_PACKED, 1> {
450+
let HasClamp = 0;
451+
let HasOpSel = 0;
452+
let HasOMod = 0;
453+
let IsDOT = 1;
454+
let HasSrc0Mods = 0;
455+
let HasSrc1Mods = 0;
456+
let HasSrc2Mods = 1;
457+
458+
let InsVOP3P = (ins VSrc_b32:$src0, VSrc_b32:$src1,
459+
PackedF16InputMods:$src2_modifiers, VSrc_f32:$src2,
460+
neg_lo0:$neg_lo, neg_hi0:$neg_hi);
461+
462+
let InsVOP3DPP8 = (ins DstRC:$old, VGPR_32:$src0, VRegSrc_32:$src1,
463+
PackedF16InputMods:$src2_modifiers, VRegSrc_32:$src2,
464+
neg_lo0:$neg_lo, neg_hi0:$neg_hi, dpp8:$dpp8, FI:$fi);
465+
466+
let InsVOP3DPP16 = (ins DstRC:$old, VGPR_32:$src0, VRegSrc_32:$src1,
467+
PackedF16InputMods:$src2_modifiers, VRegSrc_32:$src2,
468+
neg_lo0:$neg_lo, neg_hi0:$neg_hi, dpp_ctrl:$dpp_ctrl,
469+
row_mask:$row_mask, bank_mask:$bank_mask,
470+
bound_ctrl:$bound_ctrl, FI:$fi);
471+
}
472+
473+
multiclass VOP3PDOTF8Inst <string OpName, SDPatternOperator intrinsic_node> {
474+
defm NAME : VOP3PInst<OpName, VOP3P_DOTF8_Profile, null_frag, 1>;
475+
476+
let SubtargetPredicate = isGFX12Plus in
477+
def : GCNPat <(intrinsic_node i32:$src0, i32:$src1,
478+
(VOP3Mods f32:$src2, i32:$src2_modifiers)),
479+
(!cast<Instruction>(NAME) i32:$src0, i32:$src1,
480+
i32:$src2_modifiers, f32:$src2)>;
481+
}
482+
483+
defm V_DOT4_F32_FP8_BF8 : VOP3PDOTF8Inst<"v_dot4_f32_fp8_bf8", int_amdgcn_dot4_f32_fp8_bf8>;
484+
defm V_DOT4_F32_BF8_FP8 : VOP3PDOTF8Inst<"v_dot4_f32_bf8_fp8", int_amdgcn_dot4_f32_bf8_fp8>;
485+
defm V_DOT4_F32_FP8_FP8 : VOP3PDOTF8Inst<"v_dot4_f32_fp8_fp8", int_amdgcn_dot4_f32_fp8_fp8>;
486+
defm V_DOT4_F32_BF8_BF8 : VOP3PDOTF8Inst<"v_dot4_f32_bf8_bf8", int_amdgcn_dot4_f32_bf8_bf8>;
487+
446488
def : UDot2Pat<V_DOT2_U32_U16>;
447489
def : SDot2Pat<V_DOT2_I32_I16>;
448490

@@ -1019,6 +1061,11 @@ defm V_PK_MAX_NUM_F16 : VOP3P_Real_with_name_gfx12<0x1c, "V_PK_MAX_F16", "v_pk_m
10191061
defm V_PK_MINIMUM_F16 : VOP3P_Real_gfx12<0x1d>;
10201062
defm V_PK_MAXIMUM_F16 : VOP3P_Real_gfx12<0x1e>;
10211063

1064+
defm V_DOT4_F32_FP8_BF8 : VOP3P_Realtriple<GFX12Gen, 0x24>;
1065+
defm V_DOT4_F32_BF8_FP8 : VOP3P_Realtriple<GFX12Gen, 0x25>;
1066+
defm V_DOT4_F32_FP8_FP8 : VOP3P_Realtriple<GFX12Gen, 0x26>;
1067+
defm V_DOT4_F32_BF8_BF8 : VOP3P_Realtriple<GFX12Gen, 0x27>;
1068+
10221069
//===----------------------------------------------------------------------===//
10231070
// GFX11
10241071
//===----------------------------------------------------------------------===//

llvm/lib/Target/AMDGPU/VOPInstructions.td

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ class VOP3_Pseudo <string opName, VOPProfile P, list<dag> pattern = [],
169169
class VOP3P_Pseudo <string opName, VOPProfile P, list<dag> pattern = []> :
170170
VOP3_Pseudo<opName, P, pattern, 1> {
171171
let VOP3P = 1;
172+
let IsDOT = P.IsDOT;
172173
}
173174

174175
class VOP_Real<VOP_Pseudo ps> {
@@ -387,7 +388,7 @@ class VOP3Pe <bits<7> op, VOPProfile P> : Enc64 {
387388
let Inst{12} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{2}, 0); // op_sel(1)
388389
let Inst{13} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{2}, 0); // op_sel(2)
389390

390-
let Inst{14} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{3}, ?); // op_sel_hi(2)
391+
let Inst{14} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(2)
391392

392393
let Inst{15} = !if(P.HasClamp, clamp{0}, 0);
393394

@@ -396,8 +397,8 @@ class VOP3Pe <bits<7> op, VOPProfile P> : Enc64 {
396397
let Inst{40-32} = !if(P.HasSrc0, src0, 0);
397398
let Inst{49-41} = !if(P.HasSrc1, src1, 0);
398399
let Inst{58-50} = !if(P.HasSrc2, src2, 0);
399-
let Inst{59} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{3}, ?); // op_sel_hi(0)
400-
let Inst{60} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{3}, ?); // op_sel_hi(1)
400+
let Inst{59} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(0)
401+
let Inst{60} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(1)
401402
let Inst{61} = !if(P.HasSrc0Mods, src0_modifiers{0}, 0); // neg (lo)
402403
let Inst{62} = !if(P.HasSrc1Mods, src1_modifiers{0}, 0); // neg (lo)
403404
let Inst{63} = !if(P.HasSrc2Mods, src2_modifiers{0}, 0); // neg (lo)
@@ -772,12 +773,12 @@ class VOP3P_DPPe_Common_Base<bits<7> op, VOPProfile P> : Enc96 {
772773
let Inst{11} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{2}, 0); // op_sel(0)
773774
let Inst{12} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{2}, 0); // op_sel(1)
774775
let Inst{13} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{2}, 0); // op_sel(2)
775-
let Inst{14} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{3}, ?); // op_sel_hi(2)
776+
let Inst{14} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(2)
776777
let Inst{15} = !if(P.HasClamp, clamp{0}, 0);
777778
let Inst{22-16} = op;
778779
let Inst{31-23} = 0x198; // encoding
779-
let Inst{59} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{3}, ?); // op_sel_hi(0)
780-
let Inst{60} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{3}, ?); // op_sel_hi(1)
780+
let Inst{59} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(0)
781+
let Inst{60} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(1)
781782
let Inst{61} = !if(P.HasSrc0Mods, src0_modifiers{0}, 0); // neg (lo)
782783
let Inst{62} = !if(P.HasSrc1Mods, src1_modifiers{0}, 0); // neg (lo)
783784
let Inst{63} = !if(P.HasSrc2Mods, src2_modifiers{0}, 0); // neg (lo)

0 commit comments

Comments
 (0)