Skip to content

[X86][AMX] Fix handling of AMX-FP8 internal intrinsics #123540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 21, 2025
Merged

Conversation

fzou1
Copy link
Contributor

@fzou1 fzou1 commented Jan 20, 2025

This is to fix #123410.

@llvmbot
Copy link
Member

llvmbot commented Jan 20, 2025

@llvm/pr-subscribers-backend-x86

Author: Feng Zou (fzou1)

Changes

This is to fix #123410.


Full diff: https://github.com/llvm/llvm-project/pull/123540.diff

2 Files Affected:

  • (modified) llvm/lib/Target/X86/X86LowerAMXType.cpp (+5-1)
  • (modified) llvm/test/CodeGen/X86/amx-fp8-internal.ll (+95)
diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index cd5813a5338eaf..41cf0fc2cef4fc 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -248,7 +248,11 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,
   case Intrinsic::x86_tdpbuud_internal:
   case Intrinsic::x86_tdpbf16ps_internal:
   case Intrinsic::x86_tdpfp16ps_internal:
-  case Intrinsic::x86_tmmultf32ps_internal: {
+  case Intrinsic::x86_tmmultf32ps_internal:
+  case Intrinsic::x86_tdpbf8ps_internal:
+  case Intrinsic::x86_tdpbhf8ps_internal:
+  case Intrinsic::x86_tdphbf8ps_internal:
+  case Intrinsic::x86_tdphf8ps_internal: {
     switch (OpNo) {
     case 3:
       Row = II->getArgOperand(0);
diff --git a/llvm/test/CodeGen/X86/amx-fp8-internal.ll b/llvm/test/CodeGen/X86/amx-fp8-internal.ll
index ade71499b361a8..c0bdcdbadaa335 100644
--- a/llvm/test/CodeGen/X86/amx-fp8-internal.ll
+++ b/llvm/test/CodeGen/X86/amx-fp8-internal.ll
@@ -59,6 +59,101 @@ define void @test_amx(i8* %pointer, i8* %base, i64 %stride) {
   ret void
 }
 
+%struct.__tile1024i_str = type <{ i16, i16, [60 x i8], <256 x i32> }>
+
+define dso_local void @__tile_dpbf8ps(ptr nocapture noundef %dst, ptr nocapture noundef readonly byval(%struct.__tile1024i_str) align 64 %src1, ptr nocapture noundef readonly byval(%struct.__tile1024i_str) align 64 %src2) {
+; CHECK-LABEL: __tile_dpbf8ps:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    pushq %rbp
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    .cfi_offset %rbp, -16
+; CHECK-NEXT:    movq %rsp, %rbp
+; CHECK-NEXT:    .cfi_def_cfa_register %rbp
+; CHECK-NEXT:    pushq %rbx
+; CHECK-NEXT:    andq $-1024, %rsp # imm = 0xFC00
+; CHECK-NEXT:    subq $5120, %rsp # imm = 0x1400
+; CHECK-NEXT:    .cfi_offset %rbx, -24
+; CHECK-NEXT:    vxorps %xmm0, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %zmm0, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb $1, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movzwl 16(%rbp), %eax
+; CHECK-NEXT:    movb %al, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %al, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %al, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movswq 1106(%rbp), %rcx
+; CHECK-NEXT:    movw %cx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %cx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %cx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movswq 18(%rbp), %rdx
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movzwl %dx, %esi
+; CHECK-NEXT:    movb %sil, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    shrl $2, %esi
+; CHECK-NEXT:    movb %sil, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    ldtilecfg {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    addq $64, %rdi
+; CHECK-NEXT:    tileloadd (%rdi,%rcx), %tmm0
+; CHECK-NEXT:    leaq 80(%rbp), %r8
+; CHECK-NEXT:    tileloadd (%r8,%rdx), %tmm1
+; CHECK-NEXT:    leaq 1168(%rbp), %r8
+; CHECK-NEXT:    tileloadd (%r8,%rcx), %tmm2
+; CHECK-NEXT:    movabsq $64, %rbx
+; CHECK-NEXT:    tilestored %tmm0, 1024(%rsp,%rbx) # 1024-byte Folded Spill
+; CHECK-NEXT:    tileloadd 1024(%rsp,%rbx), %tmm3 # 1024-byte Folded Reload
+; CHECK-NEXT:    tdpbf8ps %tmm2, %tmm1, %tmm3
+; CHECK-NEXT:    tilestored %tmm3, (%rdi,%rcx)
+; CHECK-NEXT:    tilestored %tmm0, 2048(%rsp,%rbx) # 1024-byte Folded Spill
+; CHECK-NEXT:    tileloadd 2048(%rsp,%rbx), %tmm3 # 1024-byte Folded Reload
+; CHECK-NEXT:    tdpbhf8ps %tmm2, %tmm1, %tmm3
+; CHECK-NEXT:    tilestored %tmm3, (%rdi,%rcx)
+; CHECK-NEXT:    tilestored %tmm0, 3072(%rsp,%rbx) # 1024-byte Folded Spill
+; CHECK-NEXT:    tileloadd 3072(%rsp,%rbx), %tmm3 # 1024-byte Folded Reload
+; CHECK-NEXT:    tdphbf8ps %tmm2, %tmm1, %tmm3
+; CHECK-NEXT:    tilestored %tmm3, (%rdi,%rcx)
+; CHECK-NEXT:    tdphf8ps %tmm2, %tmm1, %tmm0
+; CHECK-NEXT:    tilestored %tmm0, (%rdi,%rcx)
+; CHECK-NEXT:    leaq -8(%rbp), %rsp
+; CHECK-NEXT:    popq %rbx
+; CHECK-NEXT:    popq %rbp
+; CHECK-NEXT:    .cfi_def_cfa %rsp, 8
+; CHECK-NEXT:    tilerelease
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = load i16, ptr %src1, align 64
+  %col = getelementptr inbounds nuw i8, ptr %src2, i64 2
+  %1 = load i16, ptr %col, align 2
+  %col1 = getelementptr inbounds nuw i8, ptr %src1, i64 2
+  %2 = load i16, ptr %col1, align 2
+  %tile = getelementptr inbounds nuw i8, ptr %dst, i64 64
+  %3 = load <256 x i32>, ptr %tile, align 64
+  %tile2 = getelementptr inbounds nuw i8, ptr %src1, i64 64
+  %4 = load <256 x i32>, ptr %tile2, align 64
+  %tile3 = getelementptr inbounds nuw i8, ptr %src2, i64 64
+  %5 = load <256 x i32>, ptr %tile3, align 64
+  %6 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %3)
+  %7 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %4)
+  %8 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %5)
+
+  %9 = tail call x86_amx @llvm.x86.tdpbf8ps.internal(i16 %0, i16 %1, i16 %2, x86_amx %6, x86_amx %7, x86_amx %8)
+  %10 = tail call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %9)
+  store <256 x i32> %10, ptr %tile, align 64
+
+  %11 = tail call x86_amx @llvm.x86.tdpbhf8ps.internal(i16 %0, i16 %1, i16 %2, x86_amx %6, x86_amx %7, x86_amx %8)
+  %12 = tail call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %11)
+  store <256 x i32> %12, ptr %tile, align 64
+
+  %13 = tail call x86_amx @llvm.x86.tdphbf8ps.internal(i16 %0, i16 %1, i16 %2, x86_amx %6, x86_amx %7, x86_amx %8)
+  %14 = tail call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %13)
+  store <256 x i32> %14, ptr %tile, align 64
+
+  %15 = tail call x86_amx @llvm.x86.tdphf8ps.internal(i16 %0, i16 %1, i16 %2, x86_amx %6, x86_amx %7, x86_amx %8)
+  %16 = tail call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %15)
+  store <256 x i32> %16, ptr %tile, align 64
+
+  ret void
+}
+
 declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
 declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
 declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)

Copy link
Contributor

@phoebewang phoebewang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@@ -59,6 +59,101 @@ define void @test_amx(i8* %pointer, i8* %base, i64 %stride) {
ret void
}

%struct.__tile1024i_str = type <{ i16, i16, [60 x i8], <256 x i32> }>

define dso_local void @__tile_dpbf8ps(ptr nocapture noundef %dst, ptr nocapture noundef readonly byval(%struct.__tile1024i_str) align 64 %src1, ptr nocapture noundef readonly byval(%struct.__tile1024i_str) align 64 %src2) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove nocapture, noundef, readonly, byval(%struct.__tile1024i_str)
Add nounwind

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@fzou1 fzou1 merged commit abbfed9 into llvm:main Jan 21, 2025
8 checks passed
@fzou1 fzou1 deleted the amx-fp8-fix branch January 21, 2025 02:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

UNREACHABLE executed at /root/llvm-project/llvm/lib/Target/X86/X86LowerAMXType.cpp:223!
3 participants