-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[X86][AMX] Fix handling of AMX-FP8 internal intrinsics #123540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-x86 Author: Feng Zou (fzou1) ChangesThis is to fix #123410. Full diff: https://github.com/llvm/llvm-project/pull/123540.diff 2 Files Affected:
diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index cd5813a5338eaf..41cf0fc2cef4fc 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -248,7 +248,11 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,
case Intrinsic::x86_tdpbuud_internal:
case Intrinsic::x86_tdpbf16ps_internal:
case Intrinsic::x86_tdpfp16ps_internal:
- case Intrinsic::x86_tmmultf32ps_internal: {
+ case Intrinsic::x86_tmmultf32ps_internal:
+ case Intrinsic::x86_tdpbf8ps_internal:
+ case Intrinsic::x86_tdpbhf8ps_internal:
+ case Intrinsic::x86_tdphbf8ps_internal:
+ case Intrinsic::x86_tdphf8ps_internal: {
switch (OpNo) {
case 3:
Row = II->getArgOperand(0);
diff --git a/llvm/test/CodeGen/X86/amx-fp8-internal.ll b/llvm/test/CodeGen/X86/amx-fp8-internal.ll
index ade71499b361a8..c0bdcdbadaa335 100644
--- a/llvm/test/CodeGen/X86/amx-fp8-internal.ll
+++ b/llvm/test/CodeGen/X86/amx-fp8-internal.ll
@@ -59,6 +59,101 @@ define void @test_amx(i8* %pointer, i8* %base, i64 %stride) {
ret void
}
+%struct.__tile1024i_str = type <{ i16, i16, [60 x i8], <256 x i32> }>
+
+define dso_local void @__tile_dpbf8ps(ptr nocapture noundef %dst, ptr nocapture noundef readonly byval(%struct.__tile1024i_str) align 64 %src1, ptr nocapture noundef readonly byval(%struct.__tile1024i_str) align 64 %src2) {
+; CHECK-LABEL: __tile_dpbf8ps:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: pushq %rbp
+; CHECK-NEXT: .cfi_def_cfa_offset 16
+; CHECK-NEXT: .cfi_offset %rbp, -16
+; CHECK-NEXT: movq %rsp, %rbp
+; CHECK-NEXT: .cfi_def_cfa_register %rbp
+; CHECK-NEXT: pushq %rbx
+; CHECK-NEXT: andq $-1024, %rsp # imm = 0xFC00
+; CHECK-NEXT: subq $5120, %rsp # imm = 0x1400
+; CHECK-NEXT: .cfi_offset %rbx, -24
+; CHECK-NEXT: vxorps %xmm0, %xmm0, %xmm0
+; CHECK-NEXT: vmovups %zmm0, {{[0-9]+}}(%rsp)
+; CHECK-NEXT: movb $1, {{[0-9]+}}(%rsp)
+; CHECK-NEXT: movzwl 16(%rbp), %eax
+; CHECK-NEXT: movb %al, {{[0-9]+}}(%rsp)
+; CHECK-NEXT: movb %al, {{[0-9]+}}(%rsp)
+; CHECK-NEXT: movb %al, {{[0-9]+}}(%rsp)
+; CHECK-NEXT: movswq 1106(%rbp), %rcx
+; CHECK-NEXT: movw %cx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT: movw %cx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT: movw %cx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT: movswq 18(%rbp), %rdx
+; CHECK-NEXT: movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT: movzwl %dx, %esi
+; CHECK-NEXT: movb %sil, {{[0-9]+}}(%rsp)
+; CHECK-NEXT: shrl $2, %esi
+; CHECK-NEXT: movb %sil, {{[0-9]+}}(%rsp)
+; CHECK-NEXT: ldtilecfg {{[0-9]+}}(%rsp)
+; CHECK-NEXT: addq $64, %rdi
+; CHECK-NEXT: tileloadd (%rdi,%rcx), %tmm0
+; CHECK-NEXT: leaq 80(%rbp), %r8
+; CHECK-NEXT: tileloadd (%r8,%rdx), %tmm1
+; CHECK-NEXT: leaq 1168(%rbp), %r8
+; CHECK-NEXT: tileloadd (%r8,%rcx), %tmm2
+; CHECK-NEXT: movabsq $64, %rbx
+; CHECK-NEXT: tilestored %tmm0, 1024(%rsp,%rbx) # 1024-byte Folded Spill
+; CHECK-NEXT: tileloadd 1024(%rsp,%rbx), %tmm3 # 1024-byte Folded Reload
+; CHECK-NEXT: tdpbf8ps %tmm2, %tmm1, %tmm3
+; CHECK-NEXT: tilestored %tmm3, (%rdi,%rcx)
+; CHECK-NEXT: tilestored %tmm0, 2048(%rsp,%rbx) # 1024-byte Folded Spill
+; CHECK-NEXT: tileloadd 2048(%rsp,%rbx), %tmm3 # 1024-byte Folded Reload
+; CHECK-NEXT: tdpbhf8ps %tmm2, %tmm1, %tmm3
+; CHECK-NEXT: tilestored %tmm3, (%rdi,%rcx)
+; CHECK-NEXT: tilestored %tmm0, 3072(%rsp,%rbx) # 1024-byte Folded Spill
+; CHECK-NEXT: tileloadd 3072(%rsp,%rbx), %tmm3 # 1024-byte Folded Reload
+; CHECK-NEXT: tdphbf8ps %tmm2, %tmm1, %tmm3
+; CHECK-NEXT: tilestored %tmm3, (%rdi,%rcx)
+; CHECK-NEXT: tdphf8ps %tmm2, %tmm1, %tmm0
+; CHECK-NEXT: tilestored %tmm0, (%rdi,%rcx)
+; CHECK-NEXT: leaq -8(%rbp), %rsp
+; CHECK-NEXT: popq %rbx
+; CHECK-NEXT: popq %rbp
+; CHECK-NEXT: .cfi_def_cfa %rsp, 8
+; CHECK-NEXT: tilerelease
+; CHECK-NEXT: vzeroupper
+; CHECK-NEXT: retq
+entry:
+ %0 = load i16, ptr %src1, align 64
+ %col = getelementptr inbounds nuw i8, ptr %src2, i64 2
+ %1 = load i16, ptr %col, align 2
+ %col1 = getelementptr inbounds nuw i8, ptr %src1, i64 2
+ %2 = load i16, ptr %col1, align 2
+ %tile = getelementptr inbounds nuw i8, ptr %dst, i64 64
+ %3 = load <256 x i32>, ptr %tile, align 64
+ %tile2 = getelementptr inbounds nuw i8, ptr %src1, i64 64
+ %4 = load <256 x i32>, ptr %tile2, align 64
+ %tile3 = getelementptr inbounds nuw i8, ptr %src2, i64 64
+ %5 = load <256 x i32>, ptr %tile3, align 64
+ %6 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %3)
+ %7 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %4)
+ %8 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %5)
+
+ %9 = tail call x86_amx @llvm.x86.tdpbf8ps.internal(i16 %0, i16 %1, i16 %2, x86_amx %6, x86_amx %7, x86_amx %8)
+ %10 = tail call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %9)
+ store <256 x i32> %10, ptr %tile, align 64
+
+ %11 = tail call x86_amx @llvm.x86.tdpbhf8ps.internal(i16 %0, i16 %1, i16 %2, x86_amx %6, x86_amx %7, x86_amx %8)
+ %12 = tail call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %11)
+ store <256 x i32> %12, ptr %tile, align 64
+
+ %13 = tail call x86_amx @llvm.x86.tdphbf8ps.internal(i16 %0, i16 %1, i16 %2, x86_amx %6, x86_amx %7, x86_amx %8)
+ %14 = tail call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %13)
+ store <256 x i32> %14, ptr %tile, align 64
+
+ %15 = tail call x86_amx @llvm.x86.tdphf8ps.internal(i16 %0, i16 %1, i16 %2, x86_amx %6, x86_amx %7, x86_amx %8)
+ %16 = tail call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %15)
+ store <256 x i32> %16, ptr %tile, align 64
+
+ ret void
+}
+
declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
@@ -59,6 +59,101 @@ define void @test_amx(i8* %pointer, i8* %base, i64 %stride) { | |||
ret void | |||
} | |||
|
|||
%struct.__tile1024i_str = type <{ i16, i16, [60 x i8], <256 x i32> }> | |||
|
|||
define dso_local void @__tile_dpbf8ps(ptr nocapture noundef %dst, ptr nocapture noundef readonly byval(%struct.__tile1024i_str) align 64 %src1, ptr nocapture noundef readonly byval(%struct.__tile1024i_str) align 64 %src2) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove nocapture
, noundef
, readonly
, byval(%struct.__tile1024i_str)
Add nounwind
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
This is to fix #123410.