[AArch64][SME] Change output class of FORM_TRANSPOSED_REG_TUPLE pseudos #123755

Merged: kmclaughlin-arm merged 3 commits into llvm:main on Jan 24, 2025

Conversation

kmclaughlin-arm (Contributor) commented on Jan 21, 2025

The FORM_TRANSPOSED_REG_TUPLE pseudo nodes currently use either the ZPR2Mul2
or ZPR4Mul4 register class for their output. This patch changes the output
class to the more generic ZPR2/ZPR4, so that the pseudos can also be used by
multi-vector intrinsics which instead create a ZPR2/ZPR4 register sequence.
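
For reference, the MIR shape these pseudos are matched against looks roughly like the example in the AArch64ISelLowering.cpp comment further down (register numbers are illustrative, and %5 is defined by an earlier strided load that is not shown); with this change the pseudo's destination uses the generic ZPR2 class:

  %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO ..
  %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous
  %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous
  %9:zpr2 = FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO %5:zpr, %8:zpr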

[AArch64][SME] Create separate FORM_TRANSPOSE pseudos for ZPR & ZPRMul classes

The FORM_TRANSPOSED_REG_TUPLE pseudo nodes use either the ZPR2Mul2 or ZPR4Mul4
register classes for output. To extend these to other multi-vector intrinsics
which instead create a ZPR2/ZPR4 REG_SEQUENCE, a new pseudo has been added and
the existing one renamed.
llvmbot (Member) commented on Jan 21, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

Changes

The FORM_TRANSPOSED_REG_TUPLE pseudo nodes use either the ZPR2Mul2 or
ZPR4Mul4 register classes for output. To extend these to other multi-vector
intrinsics which instead create a ZPR2/ZPR4 register sequence, a new
pseudo has been added and the existing one renamed.


Patch is 46.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123755.diff

6 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp (+2)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+6-4)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.h (+12)
  • (modified) llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp (+7-7)
  • (modified) llvm/lib/Target/AArch64/SMEInstrFormats.td (+16-10)
  • (modified) llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll (+459)
diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index b44c48afe705ba..239070925aa3aa 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -1755,8 +1755,10 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
          MBB, MBBI, AArch64::ZPR4RegClass, AArch64::ZPR4StridedRegClass,
          AArch64::LDNT1D_4Z, AArch64::LDNT1D_4Z_STRIDED);
    case AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO:
+   case AArch64::FORM_TRANSPOSED_REG_TUPLE_MULX2_PSEUDO:
      return expandFormTuplePseudo(MBB, MBBI, NextMBBI, 2);
    case AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO:
+   case AArch64::FORM_TRANSPOSED_REG_TUPLE_MULX4_PSEUDO:
      return expandFormTuplePseudo(MBB, MBBI, NextMBBI, 4);
   }
   return false;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d4a114c275fb76..289ada6b4f86f3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8759,7 +8759,7 @@ static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) {
 //   %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO ..
 //   %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous
 //   %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous
-//   %9:zpr2mul2 = FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO %5:zpr, %8:zpr
+//   %9:zpr2mul2 = FORM_TRANSPOSED_REG_TUPLE_MULX2_PSEUDO %5:zpr, %8:zpr
 //
 bool shouldUseFormStridedPseudo(MachineInstr &MI) {
   MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
@@ -8767,9 +8767,11 @@ bool shouldUseFormStridedPseudo(MachineInstr &MI) {
   const TargetRegisterClass *RegClass = nullptr;
   switch (MI.getOpcode()) {
   case AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO:
+  case AArch64::FORM_TRANSPOSED_REG_TUPLE_MULX2_PSEUDO:
     RegClass = &AArch64::ZPR2StridedOrContiguousRegClass;
     break;
   case AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO:
+  case AArch64::FORM_TRANSPOSED_REG_TUPLE_MULX4_PSEUDO:
     RegClass = &AArch64::ZPR4StridedOrContiguousRegClass;
     break;
   default:
@@ -8824,14 +8826,14 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
     }
   }
 
-  if (MI.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
-      MI.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO) {
+  const AArch64InstrInfo *TII =
+      MI.getMF()->getSubtarget<AArch64Subtarget>().getInstrInfo();
+  if (TII->isFormTransposedOpcode(MI.getOpcode())) {
     // If input values to the FORM_TRANSPOSED_REG_TUPLE pseudo aren't copies
     // from a StridedOrContiguous class, fall back on REG_SEQUENCE node.
     if (shouldUseFormStridedPseudo(MI))
       return;
 
-    const TargetInstrInfo *TII = Subtarget->getInstrInfo();
     MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
                                       TII->get(TargetOpcode::REG_SEQUENCE),
                                       MI.getOperand(0).getReg());
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index e37f70f7d985de..f45d868546e7f4 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -548,6 +548,18 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
                                                Register TargetReg,
                                                bool FrameSetup) const;
 
+  bool isFormTransposedOpcode(unsigned Opc) const {
+    switch (Opc) {
+    case AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO:
+    case AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO:
+    case AArch64::FORM_TRANSPOSED_REG_TUPLE_MULX2_PSEUDO:
+    case AArch64::FORM_TRANSPOSED_REG_TUPLE_MULX4_PSEUDO:
+      return true;
+    default:
+      return false;
+    }
+  }
+
 #define GET_INSTRINFO_HELPER_DECLS
 #include "AArch64GenInstrInfo.inc"
 
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
index 5973b63b5a8024..773d9946c192f1 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -1109,13 +1109,14 @@ bool AArch64RegisterInfo::getRegAllocationHints(
   // so we add the strided registers as a hint.
   unsigned RegID = MRI.getRegClass(VirtReg)->getID();
   // Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
+  const AArch64InstrInfo *TII =
+      MF.getSubtarget<AArch64Subtarget>().getInstrInfo();
   if ((RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
        RegID == AArch64::ZPR4StridedOrContiguousRegClassID) &&
-      any_of(MRI.use_nodbg_instructions(VirtReg), [](const MachineInstr &Use) {
-        return Use.getOpcode() ==
-                   AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
-               Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
-      })) {
+      any_of(MRI.use_nodbg_instructions(VirtReg),
+             [&TII](const MachineInstr &Use) {
+               return TII->isFormTransposedOpcode(Use.getOpcode());
+             })) {
     const TargetRegisterClass *StridedRC =
         RegID == AArch64::ZPR2StridedOrContiguousRegClassID
             ? &AArch64::ZPR2StridedRegClass
@@ -1130,8 +1131,7 @@ bool AArch64RegisterInfo::getRegAllocationHints(
   }
 
   for (MachineInstr &MI : MRI.def_instructions(VirtReg)) {
-    if (MI.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO &&
-        MI.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO)
+    if (!TII->isFormTransposedOpcode(MI.getOpcode()))
       return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints,
                                                        MF, VRM);
 
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index 81004e70dc179b..918637a9c6d3cb 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -45,20 +45,26 @@ def am_sme_indexed_b4 : ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0, 15>
 // If the operands do not match this pattern, the pseudos are expanded
 // to a REG_SEQUENCE using the post-isel hook.
 
-def FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO :
-  Pseudo<(outs ZPR2Mul2:$tup),
-         (ins ZPR:$zn0, ZPR:$zn1), []>, Sched<[]>{
+class sme_form_transpose_x2_pseudo<RegisterClass multi_vector_class>
+    : Pseudo<(outs multi_vector_class:$tup), (ins ZPR:$zn0, ZPR:$zn1), []>,
+      Sched<[]> {
   let hasSideEffects = 0;
   let hasPostISelHook = 1;
 }
 
-def FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO :
-  Pseudo<(outs ZPR4Mul4:$tup),
-         (ins ZPR:$zn0, ZPR:$zn1, ZPR:$zn2, ZPR:$zn3), []>, Sched<[]>{
+def FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO    : sme_form_transpose_x2_pseudo<ZPR2>;
+def FORM_TRANSPOSED_REG_TUPLE_MULX2_PSEUDO : sme_form_transpose_x2_pseudo<ZPR2Mul2>;
+
+class sme_form_transpose_x4_pseudo<RegisterClass multi_vector_class>
+    : Pseudo<(outs multi_vector_class:$tup), (ins ZPR:$zn0, ZPR:$zn1, ZPR:$zn2, ZPR:$zn3), []>,
+      Sched<[]> {
   let hasSideEffects = 0;
   let hasPostISelHook = 1;
 }
 
+def FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO    : sme_form_transpose_x4_pseudo<ZPR4>;
+def FORM_TRANSPOSED_REG_TUPLE_MULX4_PSEUDO : sme_form_transpose_x4_pseudo<ZPR4Mul4>;
+
 def SDTZALoadStore : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>]>;
 def AArch64SMELdr : SDNode<"AArch64ISD::SME_ZA_LDR", SDTZALoadStore,
                              [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
@@ -164,14 +170,14 @@ class SME2_ZA_TwoOp_Multi_Single_Pat<string name, SDPatternOperator intrinsic, O
 class SME2_ZA_TwoOp_VG2_Multi_Single_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ZPRRegOp zpr_ty,
                                          ValueType vt, ComplexPattern tileslice>
     : Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)), vt:$Zn1, vt:$Zn2, vt:$Zm),
-          (!cast<Instruction>(name # _PSEUDO) $base, $offset, (REG_SEQUENCE ZPR2, vt:$Zn1, zsub0, vt:$Zn2, zsub1),
+          (!cast<Instruction>(name # _PSEUDO) $base, $offset, (FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO vt:$Zn1, vt:$Zn2),
                                               zpr_ty:$Zm)>;
 class SME2_ZA_TwoOp_VG4_Multi_Single_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ZPRRegOp zpr_ty,
                                          ValueType vt, ComplexPattern tileslice>
     : Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)),
                      vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4, vt:$Zm),
           (!cast<Instruction>(name # _PSEUDO) $base, $offset,
-                                              (REG_SEQUENCE ZPR4, vt:$Zn1, zsub0, vt:$Zn2, zsub1, vt:$Zn3, zsub2, vt:$Zn4, zsub3),
+                                              (FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4),
                                               zpr_ty:$Zm)>;
 
 class SME2_ZA_TwoOp_VG2_Multi_Multi_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ValueType vt, ComplexPattern tileslice>
@@ -197,14 +203,14 @@ class SME2_ZA_TwoOp_VG2_Multi_Index_Pat<string name, SDPatternOperator intrinsic
                                         Operand imm_ty, ComplexPattern tileslice>
     : Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)), vt:$Zn1, vt:$Zn2, vt:$Zm, (i32 imm_ty:$i)),
           (!cast<Instruction>(name # _PSEUDO) $base, $offset,
-                                              (FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO vt:$Zn1,vt:$Zn2), zpr_ty:$Zm, imm_ty:$i)>;
+                                              (FORM_TRANSPOSED_REG_TUPLE_MULX2_PSEUDO vt:$Zn1,vt:$Zn2), zpr_ty:$Zm, imm_ty:$i)>;
 
 class SME2_ZA_TwoOp_VG4_Multi_Index_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ZPRRegOp zpr_ty, ValueType vt,
                                         Operand imm_ty, ComplexPattern tileslice>
     : Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)),
                      vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4, vt:$Zm, (i32 imm_ty:$i)),
           (!cast<Instruction>(name # _PSEUDO) $base, $offset,
-                                              (FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4),
+                                              (FORM_TRANSPOSED_REG_TUPLE_MULX4_PSEUDO vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4),
                                               zpr_ty:$Zm, imm_ty:$i)>;
 
 class SME2_Sat_Shift_VG2_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt, Operand imm_ty>
diff --git a/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll b/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll
index 86ed63d743713c..967d168593a400 100644
--- a/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll
+++ b/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll
@@ -319,6 +319,41 @@ define void @udot_single_za32_u16_vg1x2(i32 %slice, <vscale x 16 x i8> %unused,
   ret void
 }
 
+define void @udot_single_za32_u16_vg1x2_tuple(ptr %ptr, i64 %stride, <vscale x 8 x i16> %zn) #0 {
+; CHECK-LABEL: udot_single_za32_u16_vg1x2_tuple:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-3
+; CHECK-NEXT:    str p8, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    ptrue pn8.b
+; CHECK-NEXT:    add x9, x0, x1
+; CHECK-NEXT:    str z10, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    mov w8, wzr
+; CHECK-NEXT:    str z9, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ld1h { z1.h, z9.h }, pn8/z, [x0]
+; CHECK-NEXT:    ld1h { z2.h, z10.h }, pn8/z, [x9]
+; CHECK-NEXT:    udot za.s[w8, 0, vgx2], { z1.h, z2.h }, z0.h
+; CHECK-NEXT:    udot za.s[w8, 0, vgx2], { z9.h, z10.h }, z0.h
+; CHECK-NEXT:    ldr z10, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z9, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    addvl sp, sp, #3
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
+  %1 = tail call { <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.aarch64.sve.ld1.pn.x2.nxv8i16(target("aarch64.svcount") %0, ptr %ptr)
+  %2 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16> } %1, 0
+  %3 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16> } %1, 1
+  %arrayidx2 = getelementptr inbounds i8, ptr %ptr, i64 %stride
+  %4 = tail call { <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.aarch64.sve.ld1.pn.x2.nxv8i16(target("aarch64.svcount") %0, ptr %arrayidx2)
+  %5 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16> } %4, 0
+  %6 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16> } %4, 1
+  call void @llvm.aarch64.sme.udot.single.za32.vg1x2.nxv8i16(i32 0, <vscale x 8 x i16> %2, <vscale x 8 x i16> %5, <vscale x 8 x i16> %zn)
+  call void @llvm.aarch64.sme.udot.single.za32.vg1x2.nxv8i16(i32 0, <vscale x 8 x i16> %3, <vscale x 8 x i16> %6, <vscale x 8 x i16> %zn)
+  ret void
+}
+
 define void @udot_single_za32_u16_vg1x4(i32 %slice, <vscale x 16 x i8> %unused, <vscale x 8 x i16> %zn0, <vscale x 8 x i16> %zn1, <vscale x 8 x i16> %zn2, <vscale x 8 x i16> %zn3, <vscale x 8 x i16> %zn4) #0 {
 ; CHECK-LABEL: udot_single_za32_u16_vg1x4:
 ; CHECK:       // %bb.0:
@@ -332,6 +367,87 @@ define void @udot_single_za32_u16_vg1x4(i32 %slice, <vscale x 16 x i8> %unused,
   ret void
 }
 
+define void @udot_single_za32_u16_vg1x4_tuple(ptr %ptr, i64 %stride, <vscale x 8 x i16> %zn) #0 {
+; CHECK-LABEL: udot_single_za32_u16_vg1x4_tuple:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-11
+; CHECK-NEXT:    add x9, x1, x1, lsl #1
+; CHECK-NEXT:    str p8, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    ptrue pn8.b
+; CHECK-NEXT:    str z20, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    add x10, x0, x1
+; CHECK-NEXT:    mov w8, wzr
+; CHECK-NEXT:    str z16, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    add x9, x0, x9
+; CHECK-NEXT:    str z15, [sp, #3, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z14, [sp, #4, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z13, [sp, #5, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z12, [sp, #6, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z11, [sp, #7, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z10, [sp, #8, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z9, [sp, #9, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z8, [sp, #10, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ld1h { z1.h, z5.h, z9.h, z13.h }, pn8/z, [x0]
+; CHECK-NEXT:    ld1h { z2.h, z6.h, z10.h, z14.h }, pn8/z, [x10]
+; CHECK-NEXT:    ld1h { z3.h, z7.h, z11.h, z15.h }, pn8/z, [x0, x1, lsl #1]
+; CHECK-NEXT:    ld1h { z16.h, z20.h, z24.h, z28.h }, pn8/z, [x9]
+; CHECK-NEXT:    mov z4.d, z16.d
+; CHECK-NEXT:    mov z8.d, z20.d
+; CHECK-NEXT:    mov z12.d, z24.d
+; CHECK-NEXT:    mov z16.d, z28.d
+; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z1.h - z4.h }, z0.h
+; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z5.h - z8.h }, z0.h
+; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z9.h - z12.h }, z0.h
+; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z13.h - z16.h }, z0.h
+; CHECK-NEXT:    ldr z20, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z16, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z15, [sp, #3, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z14, [sp, #4, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z13, [sp, #5, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z12, [sp, #6, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z11, [sp, #7, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z10, [sp, #8, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z9, [sp, #9, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z8, [sp, #10, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    addvl sp, sp, #11
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
+  %1 = tail call { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.aarch64.sve.ld1.pn.x4.nxv8i16(target("aarch64.svcount") %0, ptr %ptr)
+  %2 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %1, 0
+  %3 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %1, 1
+  %4 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %1, 2
+  %5 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %1, 3
+  %arrayidx2 = getelementptr inbounds i8, ptr %ptr, i64 %stride
+  %6 = tail call { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.aarch64.sve.ld1.pn.x4.nxv8i16(target("aarch64.svcount") %0, ptr %arrayidx2)
+  %7 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %6, 0
+  %8 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %6, 1
+  %9 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %6, 2
+  %10 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %6, 3
+  %mul3 = shl i64 %stride, 1
+  %arrayidx4 = getelementptr inbounds i8, ptr %ptr, i64 %mul3
+  %11 = tail call { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.aarch64.sve.ld1.pn.x4.nxv8i16(target("aarch64.svcount") %0, ptr %arrayidx4)
+  %12 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %11, 0
+  %13 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %11, 1
+  %14 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %11, 2
+  %15 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %11, 3
+  %mul5 = mul i64 %stride, 3
+  %arrayidx6 = getelementptr inbounds i8, ptr %ptr, i64 %mul5
+  %16 = tail call { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.aarch64.sve.ld1.pn.x4.nxv8i16(target("aarch64.svcount") %0, ptr %arrayidx6)
+  %17 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %16, 0
+  %18 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %16, 1
+  %19 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %16, 2
+  %20 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %16, 3
+  call void @llvm.aarch64.sme.udot.single.za32.vg1x4.nxv8i16(i32 0, <vscale x 8 x i16> %2, <vscale x 8 x i16> %7, <vscale x 8 x i16> %12, <vscale x 8 x i16> %17, <vscale x 8 x i16> %zn)
+  call void @llvm.aarch64.sme.udot.single.za32.vg1x4.nxv8i16(i32 0, <vscale x 8 x i16> %3, <vscale x 8 x i16> %8, <vscale x 8 x i16> %13, <vscale x 8 x i16> %18, <vscale x 8 x i16> %zn)
+  call void @llvm.aarch64.sme.udot.single.za32.vg1x4.nxv8i16(i32 0, <vscale x 8 x i16> %4, <vscale x 8 x i16> %9, <vscale x 8 x i16> %14, <vscale x 8 x i16> %19, <vscale x 8 x i16> %zn)
+  call void @llvm.aarch64.sme.udot.single.za32.vg1x4.nxv8i16(i32 0, <vscale x 8 x i16> %5, <vscale x 8 x i16> %10, <vscale x 8 x i16> %15, <vscale x 8 x i16> %20, <vscale x 8 x i16> %zn)
+  ret void
+}
+
 define void @udot_single_za32_u8_vg1x2(i32 %slice, <vscale x 16 x i8> %unused, <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zn2) #0 {
 ; CHECK-LABEL: udot_single_za32_u8_vg1x2:
 ; CHECK:       // %bb.0:
@@ -397,6 +513,40 @@ define void @usdot_single_za32_u8_vg1x2(i32 %slice, <vscale x 16 x i8> %unused,
   ret void
 }
 
+define void @usdot_single_za32_u16_vg1x2_tuple(ptr %ptr, i64 %stride, <vscale x 16 x i8> %zn) #0 {
+; CHECK-LABEL: usdot_sin...
[truncated]

SamTebbs33 (Collaborator) left a comment:

Looks good to me.

@@ -45,20 +45,26 @@ def am_sme_indexed_b4 : ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0, 15>
// If the operands do not match this pattern, the pseudos are expanded
// to a REG_SEQUENCE using the post-isel hook.

def FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO :
Pseudo<(outs ZPR2Mul2:$tup),
sdesmalen-arm (Collaborator) commented:

It seems simpler to just change the regclass of the existing pseudo to use the more generic ZPR2 (and ZPR4 for the _X4 pseudo). For instructions that require a ZPR2Mul2, a COPY will be introduced to move the ZPR2 to a ZPR2Mul2. Later on, the register allocator will coalesce this COPY, such that the resulting pseudo will have a ZPR2Mul2 regclass for the destination register.

This avoids the need to add new pseudos.
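
As a rough sketch (register numbers are illustrative), with a generic ZPR2 result the selected MIR would contain a COPY wherever a ZPR2Mul2 operand is needed:

  %9:zpr2 = FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO %5:zpr, %8:zpr
  %10:zpr2mul2 = COPY %9

and once the register coalescer folds the COPY away, the pseudo defines the ZPR2Mul2 register directly:

  %9:zpr2mul2 = FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO %5:zpr, %8:zpr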

kmclaughlin-arm (Contributor, Author) replied:

Thanks @sdesmalen-arm, that is much simpler!

kmclaughlin-arm changed the title from [AArch64][SME] Create separate FORM_TRANSPOSE pseudos for ZPR & ZPRMul classes to [AArch64][SME] Change output class of FORM_TRANSPOSED_REG_TUPLE pseudos on Jan 23, 2025
sdesmalen-arm (Collaborator) left a comment:

LGTM!

kmclaughlin-arm merged commit 865104a into llvm:main on Jan 24, 2025
8 checks passed
kmclaughlin-arm deleted the form-transpose-zprmul branch on February 6, 2025