[RISCV] Implement RISCVISD::SHL_ADD and move patterns into combine #89263

Merged
merged 4 commits into from
Apr 22, 2024

Conversation

preames
Collaborator

@preames preames commented Apr 18, 2024

This implements a RISCV specific version of the SHL_ADD node proposed in #88791.

If that lands, the infrastructure from this patch should seamlessly switch over to the generic DAG node. I'm posting this separately because I've run out of useful multiply strength-reduction work to do without having a way to represent MUL X, 3/5/9 as a single instruction.

The majority of this change is moving two sets of patterns out of tablegen and into the post-legalize combine. The major reason for this is that I have an upcoming change which needs to reuse the expansion logic, but it also helps common up some code between the Zba and XTheadBa variants.

On the test changes, there are a couple of major categories:

  • We chose a different lowering for mul x, 25. The new lowering involves one fewer register and the same critical path, so this seems like a win.
  • The order of the two multiplies changes in (3,5,9)*(3,5,9) in some cases. I don't believe this matters.
  • I'm removing the one use restriction on the multiply. This restriction doesn't really make sense to me, and the test changes appear positive.
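The 3/5/9-times-3/5/9 decomposition behind the new lowerings can be sanity-checked with a minimal sketch (illustrative only, not part of the patch; the helper names are made up — the shift amount is `Log2(Divisor - 1)`, as in the combine):

```python
def shl_add(x, shamt, addend):
    # Models RISCVISD::SHL_ADD / sh{1,2,3}add: (x << shamt) + addend.
    return (x << shamt) + addend

def log2(v):
    # Log2_64 equivalent for the small powers of two used here.
    return v.bit_length() - 1

def mul359_squared(x, d1, d2):
    # 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X),
    # using shift amounts Log2(d - 1), mirroring the combine.
    inner = shl_add(x, log2(d1 - 1), x)
    return shl_add(inner, log2(d2 - 1), inner)

# The new mul x, 25 lowering: two sh2add instructions, one fewer
# register than the old sh1add + sh3add sequence.
assert all(mul359_squared(x, 5, 5) == 25 * x for x in range(64))
```

The same helper covers the `mul x, 27` (9*3) and `mul x, 45` (9*5) test diffs, where only the order of the two shXadd instructions changes.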

@llvmbot
Member

llvmbot commented Apr 18, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/89263.diff

10 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+36-21)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+6)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td (+2-24)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoZb.td (+10-25)
  • (modified) llvm/test/CodeGen/RISCV/addimm-mulimm.ll (+12-8)
  • (modified) llvm/test/CodeGen/RISCV/rv32zba.ll (+4-4)
  • (modified) llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll (+4-4)
  • (modified) llvm/test/CodeGen/RISCV/rv64-legal-i32/xaluo.ll (+10-8)
  • (modified) llvm/test/CodeGen/RISCV/rv64zba.ll (+4-4)
  • (modified) llvm/test/CodeGen/RISCV/xaluo.ll (+35-26)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b0deb1d2669952..075f8c227f3dc8 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13416,12 +13416,26 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
     return SDValue();
   uint64_t MulAmt = CNode->getZExtValue();
 
-  // 3/5/9 * 2^N -> shXadd (sll X, C), (sll X, C)
-  // Matched in tablegen, avoid perturbing patterns.
-  for (uint64_t Divisor : {3, 5, 9})
-    if (MulAmt % Divisor == 0 && isPowerOf2_64(MulAmt / Divisor))
+  for (uint64_t Divisor : {3, 5, 9}) {
+    if (MulAmt % Divisor != 0)
+      continue;
+    uint64_t MulAmt2 = MulAmt / Divisor;
+    // 3/5/9 * 2^N -> shXadd (sll X, C), (sll X, C)
+    // Matched in tablegen, avoid perturbing patterns.
+    if (isPowerOf2_64(MulAmt2))
       return SDValue();
 
+    // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
+    if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
+      SDLoc DL(N);
+      SDValue Mul359 = DAG.getNode(
+          RISCVISD::SHL_ADD, DL, VT, N->getOperand(0),
+          DAG.getConstant(Log2_64(Divisor - 1), DL, VT), N->getOperand(0));
+      return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
+                         DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT), Mul359);
+    }
+  }
+
   // If this is a power 2 + 2/4/8, we can use a shift followed by a single
   // shXadd. First check if this a sum of two power of 2s because that's
   // easy. Then count how many zeros are up to the first bit.
@@ -13439,23 +13453,23 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
   }
 
   // 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x)
-  // Matched in tablegen, avoid perturbing patterns.
-  switch (MulAmt) {
-  case 11:
-  case 13:
-  case 19:
-  case 21:
-  case 25:
-  case 27:
-  case 29:
-  case 37:
-  case 41:
-  case 45:
-  case 73:
-  case 91:
-    return SDValue();
-  default:
-    break;
+  // This is the two instruction form, there are also three instruction
+  // variants we could implement.  e.g.
+  //   (2^(1,2,3) * 3,5,9 + 1) << C2
+  //   2^(C1>3) * 3,5,9 +/- 1
+  for (uint64_t Divisor : {3, 5, 9}) {
+    uint64_t C = MulAmt - 1;
+    if (C <= Divisor)
+      continue;
+    unsigned TZ = llvm::countr_zero(C);
+    if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
+      SDLoc DL(N);
+      SDValue Mul359 = DAG.getNode(
+          RISCVISD::SHL_ADD, DL, VT, N->getOperand(0),
+          DAG.getConstant(Log2_64(Divisor - 1), DL, VT), N->getOperand(0));
+      return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
+                         DAG.getConstant(TZ, DL, VT), N->getOperand(0));
+    }
   }
 
   // 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X))
@@ -19668,6 +19682,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(LLA)
   NODE_NAME_CASE(ADD_TPREL)
   NODE_NAME_CASE(MULHSU)
+  NODE_NAME_CASE(SHL_ADD)
   NODE_NAME_CASE(SLLW)
   NODE_NAME_CASE(SRAW)
   NODE_NAME_CASE(SRLW)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index b10da3d40befb7..ed14fd4539438a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -59,6 +59,12 @@ enum NodeType : unsigned {
 
   // Multiply high for signedxunsigned.
   MULHSU,
+
+  // Represents (ADD (SHL a, b), c) with the arguments appearing in the order
+  // a, b, c.  'b' must be a constant.  Maps to sh1add/sh2add/sh3add with zba
+  // or addsl with XTheadBa.
+  SHL_ADD,
+
   // RV64I shifts, directly matching the semantics of the named RISC-V
   // instructions.
   SLLW,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
index 79ced3864363b9..05dd09a267819f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
@@ -538,6 +538,8 @@ multiclass VPatTernaryVMAQA_VV_VX<string intrinsic, string instruction,
 let Predicates = [HasVendorXTHeadBa] in {
 def : Pat<(add (XLenVT GPR:$rs1), (shl GPR:$rs2, uimm2:$uimm2)),
           (TH_ADDSL GPR:$rs1, GPR:$rs2, uimm2:$uimm2)>;
+def : Pat<(XLenVT (riscv_shl_add GPR:$rs1, uimm2:$uimm2, GPR:$rs2)),
+          (TH_ADDSL GPR:$rs1, GPR:$rs2, uimm2:$uimm2)>;
 
 // Reuse complex patterns from StdExtZba
 def : Pat<(add_non_imm12 sh1add_op:$rs1, (XLenVT GPR:$rs2)),
@@ -581,30 +583,6 @@ def : Pat<(mul (XLenVT GPR:$r), C9LeftShift:$i),
           (SLLI (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)),
                 (TrailingZeros C9LeftShift:$i))>;
 
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 11)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 19)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 13)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 1)), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 21)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 37)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 25)),
-          (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)),
-                    (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 41)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 3)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 73)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 3)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 27)),
-          (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 45)),
-          (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 81)),
-          (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 3)>;
-
 def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 200)),
           (SLLI (XLenVT (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)),
                                   (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 2)), 3)>;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
index 434b071e628a0e..3d04415522919c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -26,6 +26,12 @@
 // Operand and SDNode transformation definitions.
 //===----------------------------------------------------------------------===//
 
+def SDTIntShiftAddOp : SDTypeProfile<1, 3, [   // shl_add
+  SDTCisSameAs<0, 1>, SDTCisSameAs<0, 3>, SDTCisInt<0>, SDTCisInt<2>,
+  SDTCisInt<3>
+]>;
+
+def riscv_shl_add    : SDNode<"RISCVISD::SHL_ADD"   , SDTIntShiftAddOp>;
 def riscv_clzw   : SDNode<"RISCVISD::CLZW",   SDT_RISCVIntUnaryOpW>;
 def riscv_ctzw   : SDNode<"RISCVISD::CTZW",   SDT_RISCVIntUnaryOpW>;
 def riscv_rolw   : SDNode<"RISCVISD::ROLW",   SDT_RISCVIntBinOpW>;
@@ -678,6 +684,8 @@ foreach i = {1,2,3} in {
   defvar shxadd = !cast<Instruction>("SH"#i#"ADD");
   def : Pat<(XLenVT (add_like_non_imm12 (shl GPR:$rs1, (XLenVT i)), GPR:$rs2)),
             (shxadd GPR:$rs1, GPR:$rs2)>;
+  def : Pat<(XLenVT (riscv_shl_add GPR:$rs1, (XLenVT i), GPR:$rs2)),
+            (shxadd GPR:$rs1, GPR:$rs2)>;
 
   defvar pat = !cast<ComplexPattern>("sh"#i#"add_op");
   // More complex cases use a ComplexPattern.
@@ -721,31 +729,6 @@ def : Pat<(mul (XLenVT GPR:$r), C9LeftShift:$i),
           (SLLI (XLenVT (SH3ADD GPR:$r, GPR:$r)),
                 (TrailingZeros C9LeftShift:$i))>;
 
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 11)),
-          (SH1ADD (XLenVT (SH2ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 19)),
-          (SH1ADD (XLenVT (SH3ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 13)),
-          (SH2ADD (XLenVT (SH1ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 21)),
-          (SH2ADD (XLenVT (SH2ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 37)),
-          (SH2ADD (XLenVT (SH3ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 25)),
-          (SH3ADD (XLenVT (SH1ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 41)),
-          (SH3ADD (XLenVT (SH2ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 73)),
-          (SH3ADD (XLenVT (SH3ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 27)),
-          (SH1ADD (XLenVT (SH3ADD GPR:$r, GPR:$r)),
-                  (XLenVT (SH3ADD GPR:$r, GPR:$r)))>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 45)),
-          (SH2ADD (XLenVT (SH3ADD GPR:$r, GPR:$r)),
-                  (XLenVT (SH3ADD GPR:$r, GPR:$r)))>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 81)),
-          (SH3ADD (XLenVT (SH3ADD GPR:$r, GPR:$r)),
-                  (XLenVT (SH3ADD GPR:$r, GPR:$r)))>;
 } // Predicates = [HasStdExtZba]
 
 let Predicates = [HasStdExtZba, IsRV64] in {
@@ -881,6 +864,8 @@ foreach i = {1,2,3} in {
   defvar shxadd = !cast<Instruction>("SH"#i#"ADD");
   def : Pat<(i32 (add_like_non_imm12 (shl GPR:$rs1, (i64 i)), GPR:$rs2)),
             (shxadd GPR:$rs1, GPR:$rs2)>;
+  def : Pat<(i32 (riscv_shl_add GPR:$rs1, (i32 i), GPR:$rs2)),
+            (shxadd GPR:$rs1, GPR:$rs2)>;
 }
 }
 
diff --git a/llvm/test/CodeGen/RISCV/addimm-mulimm.ll b/llvm/test/CodeGen/RISCV/addimm-mulimm.ll
index 48fa69e1045656..736c8e7d55c75b 100644
--- a/llvm/test/CodeGen/RISCV/addimm-mulimm.ll
+++ b/llvm/test/CodeGen/RISCV/addimm-mulimm.ll
@@ -251,10 +251,12 @@ define i64 @add_mul_combine_reject_c3(i64 %x) {
 ; RV32IMB-LABEL: add_mul_combine_reject_c3:
 ; RV32IMB:       # %bb.0:
 ; RV32IMB-NEXT:    li a2, 73
-; RV32IMB-NEXT:    mul a1, a1, a2
-; RV32IMB-NEXT:    mulhu a3, a0, a2
-; RV32IMB-NEXT:    add a1, a3, a1
-; RV32IMB-NEXT:    mul a2, a0, a2
+; RV32IMB-NEXT:    mulhu a2, a0, a2
+; RV32IMB-NEXT:    sh3add a3, a1, a1
+; RV32IMB-NEXT:    sh3add a1, a3, a1
+; RV32IMB-NEXT:    add a1, a2, a1
+; RV32IMB-NEXT:    sh3add a2, a0, a0
+; RV32IMB-NEXT:    sh3add a2, a2, a0
 ; RV32IMB-NEXT:    lui a0, 18
 ; RV32IMB-NEXT:    addi a0, a0, -728
 ; RV32IMB-NEXT:    add a0, a2, a0
@@ -518,10 +520,12 @@ define i64 @add_mul_combine_reject_g3(i64 %x) {
 ; RV32IMB-LABEL: add_mul_combine_reject_g3:
 ; RV32IMB:       # %bb.0:
 ; RV32IMB-NEXT:    li a2, 73
-; RV32IMB-NEXT:    mul a1, a1, a2
-; RV32IMB-NEXT:    mulhu a3, a0, a2
-; RV32IMB-NEXT:    add a1, a3, a1
-; RV32IMB-NEXT:    mul a2, a0, a2
+; RV32IMB-NEXT:    mulhu a2, a0, a2
+; RV32IMB-NEXT:    sh3add a3, a1, a1
+; RV32IMB-NEXT:    sh3add a1, a3, a1
+; RV32IMB-NEXT:    add a1, a2, a1
+; RV32IMB-NEXT:    sh3add a2, a0, a0
+; RV32IMB-NEXT:    sh3add a2, a2, a0
 ; RV32IMB-NEXT:    lui a0, 2
 ; RV32IMB-NEXT:    addi a0, a0, -882
 ; RV32IMB-NEXT:    add a0, a2, a0
diff --git a/llvm/test/CodeGen/RISCV/rv32zba.ll b/llvm/test/CodeGen/RISCV/rv32zba.ll
index a78f823d318418..2a72c1288f65cc 100644
--- a/llvm/test/CodeGen/RISCV/rv32zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv32zba.ll
@@ -407,8 +407,8 @@ define i32 @mul25(i32 %a) {
 ;
 ; RV32ZBA-LABEL: mul25:
 ; RV32ZBA:       # %bb.0:
-; RV32ZBA-NEXT:    sh1add a1, a0, a0
-; RV32ZBA-NEXT:    sh3add a0, a1, a0
+; RV32ZBA-NEXT:    sh2add a0, a0, a0
+; RV32ZBA-NEXT:    sh2add a0, a0, a0
 ; RV32ZBA-NEXT:    ret
   %c = mul i32 %a, 25
   ret i32 %c
@@ -455,8 +455,8 @@ define i32 @mul27(i32 %a) {
 ;
 ; RV32ZBA-LABEL: mul27:
 ; RV32ZBA:       # %bb.0:
-; RV32ZBA-NEXT:    sh3add a0, a0, a0
 ; RV32ZBA-NEXT:    sh1add a0, a0, a0
+; RV32ZBA-NEXT:    sh3add a0, a0, a0
 ; RV32ZBA-NEXT:    ret
   %c = mul i32 %a, 27
   ret i32 %c
@@ -471,8 +471,8 @@ define i32 @mul45(i32 %a) {
 ;
 ; RV32ZBA-LABEL: mul45:
 ; RV32ZBA:       # %bb.0:
-; RV32ZBA-NEXT:    sh3add a0, a0, a0
 ; RV32ZBA-NEXT:    sh2add a0, a0, a0
+; RV32ZBA-NEXT:    sh3add a0, a0, a0
 ; RV32ZBA-NEXT:    ret
   %c = mul i32 %a, 45
   ret i32 %c
diff --git a/llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll
index ee9b73ca82f213..9f06a9dd124cef 100644
--- a/llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll
@@ -963,8 +963,8 @@ define i64 @mul25(i64 %a) {
 ;
 ; RV64ZBA-LABEL: mul25:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sh1add a1, a0, a0
-; RV64ZBA-NEXT:    sh3add a0, a1, a0
+; RV64ZBA-NEXT:    sh2add a0, a0, a0
+; RV64ZBA-NEXT:    sh2add a0, a0, a0
 ; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 25
   ret i64 %c
@@ -1011,8 +1011,8 @@ define i64 @mul27(i64 %a) {
 ;
 ; RV64ZBA-LABEL: mul27:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    sh1add a0, a0, a0
+; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 27
   ret i64 %c
@@ -1027,8 +1027,8 @@ define i64 @mul45(i64 %a) {
 ;
 ; RV64ZBA-LABEL: mul45:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    sh2add a0, a0, a0
+; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 45
   ret i64 %c
diff --git a/llvm/test/CodeGen/RISCV/rv64-legal-i32/xaluo.ll b/llvm/test/CodeGen/RISCV/rv64-legal-i32/xaluo.ll
index 3c1b76818781a1..a1de326d16b536 100644
--- a/llvm/test/CodeGen/RISCV/rv64-legal-i32/xaluo.ll
+++ b/llvm/test/CodeGen/RISCV/rv64-legal-i32/xaluo.ll
@@ -731,12 +731,13 @@ define zeroext i1 @smulo2.i64(i64 %v1, ptr %res) {
 ; RV64ZBA-LABEL: smulo2.i64:
 ; RV64ZBA:       # %bb.0: # %entry
 ; RV64ZBA-NEXT:    li a2, 13
-; RV64ZBA-NEXT:    mulh a3, a0, a2
-; RV64ZBA-NEXT:    mul a2, a0, a2
-; RV64ZBA-NEXT:    srai a0, a2, 63
-; RV64ZBA-NEXT:    xor a0, a3, a0
+; RV64ZBA-NEXT:    mulh a2, a0, a2
+; RV64ZBA-NEXT:    sh1add a3, a0, a0
+; RV64ZBA-NEXT:    sh2add a3, a3, a0
+; RV64ZBA-NEXT:    srai a0, a3, 63
+; RV64ZBA-NEXT:    xor a0, a2, a0
 ; RV64ZBA-NEXT:    snez a0, a0
-; RV64ZBA-NEXT:    sd a2, 0(a1)
+; RV64ZBA-NEXT:    sd a3, 0(a1)
 ; RV64ZBA-NEXT:    ret
 ;
 ; RV64ZICOND-LABEL: smulo2.i64:
@@ -925,10 +926,11 @@ define zeroext i1 @umulo2.i64(i64 %v1, ptr %res) {
 ;
 ; RV64ZBA-LABEL: umulo2.i64:
 ; RV64ZBA:       # %bb.0: # %entry
-; RV64ZBA-NEXT:    li a3, 13
-; RV64ZBA-NEXT:    mulhu a2, a0, a3
+; RV64ZBA-NEXT:    li a2, 13
+; RV64ZBA-NEXT:    mulhu a2, a0, a2
 ; RV64ZBA-NEXT:    snez a2, a2
-; RV64ZBA-NEXT:    mul a0, a0, a3
+; RV64ZBA-NEXT:    sh1add a3, a0, a0
+; RV64ZBA-NEXT:    sh2add a0, a3, a0
 ; RV64ZBA-NEXT:    sd a0, 0(a1)
 ; RV64ZBA-NEXT:    mv a0, a2
 ; RV64ZBA-NEXT:    ret
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
index ccb23fc2bbfa36..f31de84b8b047c 100644
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -1140,8 +1140,8 @@ define i64 @mul25(i64 %a) {
 ;
 ; RV64ZBA-LABEL: mul25:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sh1add a1, a0, a0
-; RV64ZBA-NEXT:    sh3add a0, a1, a0
+; RV64ZBA-NEXT:    sh2add a0, a0, a0
+; RV64ZBA-NEXT:    sh2add a0, a0, a0
 ; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 25
   ret i64 %c
@@ -1188,8 +1188,8 @@ define i64 @mul27(i64 %a) {
 ;
 ; RV64ZBA-LABEL: mul27:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    sh1add a0, a0, a0
+; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 27
   ret i64 %c
@@ -1204,8 +1204,8 @@ define i64 @mul45(i64 %a) {
 ;
 ; RV64ZBA-LABEL: mul45:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    sh2add a0, a0, a0
+; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 45
   ret i64 %c
diff --git a/llvm/test/CodeGen/RISCV/xaluo.ll b/llvm/test/CodeGen/RISCV/xaluo.ll
index ac67c0769f7056..1a88563c0ea2ed 100644
--- a/llvm/test/CodeGen/RISCV/xaluo.ll
+++ b/llvm/test/CodeGen/RISCV/xaluo.ll
@@ -1268,12 +1268,13 @@ define zeroext i1 @smulo2.i32(i32 signext %v1, ptr %res) {
 ; RV32ZBA-LABEL: smulo2.i32:
 ; RV32ZBA:       # %bb.0: # %entry
 ; RV32ZBA-NEXT:    li a2, 13
-; RV32ZBA-NEXT:    mulh a3, a0, a2
-; RV32ZBA-NEXT:    mul a2, a0, a2
-; RV32ZBA-NEXT:    srai a0, a2, 31
-; RV32ZBA-NEXT:    xor a0, a3, a0
+; RV32ZBA-NEXT:    mulh a2, a0, a2
+; RV32ZBA-NEXT:    sh1add a3, a0, a0
+; RV32ZBA-NEXT:    sh2add a3, a3, a0
+; RV32ZBA-NEXT:    srai a0, a3, 31
+; RV32ZBA-NEXT:    xor a0, a2, a0
 ; RV32ZBA-NEXT:    snez a0, a0
-; RV32ZBA-NEXT:    sw a2, 0(a1)
+; RV32ZBA-NEXT:    sw a3, 0(a1)
 ; RV32ZBA-NEXT:    ret
 ;
 ; RV64ZBA-LABEL: smulo2.i32:
@@ -1577,13 +1578,15 @@ define zeroext i1 @smulo2.i64(i64 %v1, ptr %res) {
 ; RV32ZBA:       # %bb.0: # %entry
 ; RV32ZBA-NEXT:    li a3, 13
 ; RV32ZBA-NEXT:    mulhu a4, a0, a3
-; RV32ZBA-NEXT:    mul a5, a1, a3
+; RV32ZBA-NEXT:    sh1add a5, a1, a1
+; RV32ZBA-NEXT:    sh2add a5, a5, a1
 ; RV32ZBA-NEXT:    add a4, a5, a4
 ; RV32ZBA-NEXT:    sltu a5, a4, a5
 ; RV32ZBA-NEXT:    mulhu a6, a1, a3
 ; RV32ZBA-NEXT:    add a5, a6, a5
 ; RV32ZBA-NEXT:    srai a1, a1, 31
-; RV32ZBA-NEXT:    mul a6, a1, a3
+; RV32ZBA-NEXT:    sh1add a6, a1, a1
+; RV32ZBA-NEXT:    sh2add a6, a6, a1
 ; RV32ZBA-NEXT:    add a6, a5, a6
 ; RV32ZBA-NEXT:    srai a7, a4, 31
 ; RV32ZBA-NEXT:    xor t0, a6, a7
@@ -1593,7 +1596,8 @@ define zeroext i1 @smulo2.i64(i64 %v1, ptr %res) {
 ; RV32ZBA-NEXT:    xor a1, a1, a7
 ; RV32ZBA-NEXT:    or a1, t0, a1
 ; RV32ZBA-NEXT:    snez a1, a1
-; RV32ZBA-NEXT:    mul a0, a0, a3
+; RV32ZBA-NEXT:    sh1add a3, a0, a0
+; RV32ZBA-NEXT:    sh2add a0, a3, a0
 ; RV32ZBA-NEXT:    sw a0, 0(a2)
 ; RV32ZBA-NEXT:    sw a4, 4(a2)
 ; RV32ZBA-NEXT:    mv a0, a1
@@ -1602,12 +1606,13 @@ define zeroext i1 @smulo2.i64(i64 %v1, ptr %res) {
 ; RV64ZBA-LABEL: smulo2.i64:
 ; RV64ZBA:       # %bb.0: # %entry
 ; RV64ZBA-NEXT:    li a2, 13
-; RV64ZBA-NEXT:    mulh a3, a0, a2
-; RV64ZBA-NEXT:    mul a2, a0, a2
-; RV64ZBA-NEXT:    srai a0, a2, 63
-; RV64ZBA-NEXT:    xor a0, a3, a0
+; RV64ZBA-NEXT:    mulh a2, a0, a2
+; RV64ZBA-NEXT:    sh1add a3, a0, a0
+; RV64ZBA-NEXT:    sh2add a3, a3, a0
+; RV64ZBA-NEXT:    srai a0, a3, 63
+; RV64ZBA-NEXT:    xor a0, a2, a0
 ; RV64ZBA-NEXT:    snez a0, a0
-; RV64ZBA-NEXT:    sd a2, 0(a1)
+; RV64ZBA-NEXT:    sd a3, 0(a1)
 ; RV64ZBA-NEXT:    ret
 ;
 ; RV32ZICOND-LABEL: smulo2.i64:
@@ -1743,10 +1748,11 @@ define zeroext i1 @umulo2.i32(i32 signext %v1, ptr %res) {
 ;
 ; RV32ZBA-LABEL: umulo2.i32:
 ; RV32ZBA:       # %bb.0: # %entry
-; RV32ZBA-NEXT:    li a3, 13
-; RV32ZBA-NEXT:    mulhu a2, a0, a3
+; RV32ZBA-NEXT:    li a2, 13
+; RV32ZBA-NEXT:    mulhu a2, a0, a2
 ; RV32ZBA-NEXT:    snez a2, a2
-; RV32ZBA-NEXT:    mul a0, a0, a3
+; RV32ZBA-NEXT:    sh1add a3, a0, a0
+; RV32ZBA-NEXT:    sh2add a0, a3, a0
 ; RV32ZBA-NEXT:    sw a0, 0(a1)
 ; RV32ZBA-NEXT:    mv a0, a2
 ; RV32ZBA-NEXT:    ret
@@ -1995,25 +2001,28 @@ define zeroext i1 @umulo2.i64(i64 %v1, ptr %res) {
 ; RV32ZBA-LABEL: umulo2.i64:
 ; RV32ZBA:       # %bb.0: # %entry
 ; RV32ZBA-NEXT:    li a3, 13
-; RV32ZBA-NEXT:    mul a4, a1, a3
-; RV32ZBA-NEXT:    mulhu a5, a0, a3
-; RV32ZBA-NEXT:    add a4, a5, a4
-; RV32ZBA-NEXT:    sltu a5, a4, a5
+; RV32ZBA-NEXT:    mulhu a4, a0, a3
+; RV32ZBA-NEXT:    sh1add a5, a1, a1
+; RV32ZBA-NEXT:    sh2add a5, a5, a1
+; RV32ZBA-NEXT:    add a5, a4, a5
+; RV32ZBA-NEXT:    sltu a4, a5, a4
 ; RV32ZBA-NEXT:    mulhu a1, a1, a3
 ; RV32ZBA-NEXT:    snez a1, a1
-; RV32ZBA-NEXT:    or a1, a1, a5
-; RV32ZBA-NEXT:    mul a0, a0, a3
+; RV32ZBA-NEXT:    or a1, a1, a4
+; RV32ZBA-NEXT:    sh1add a3, a0, a0
+; RV32ZBA-NEXT:    sh2add a0, a3, a0
 ; RV32ZBA-NEXT:    sw a0, 0(a2)
-; RV32ZBA-NEXT:    sw a4, 4(a2)
+; RV32ZBA-NEXT:    sw a5, 4(a2)
 ; RV32ZBA-NEXT:    mv a0, a1
 ; RV32ZBA-NEXT:    ret
 ;
 ; RV64ZBA-LABEL: umulo2.i64:
 ; RV64ZBA:       # %bb.0: # %entry
-; RV64ZBA-NEXT:    li a3, 13
-; RV64ZBA-NEXT:    mulhu a2, a0, a3
+; RV64ZBA-NEXT:    li a2, 13
+; RV64ZBA-NEXT:    mulhu a2, a0, a2
 ; RV64ZBA-NEXT:    snez a2, a2
-; RV64ZBA-NEXT:    mul a0, a0, a3
+; RV64ZBA-NEXT:    sh1add a3, a0, a0
+; RV64ZBA-NEXT:    sh2add a0, a3, a0
 ; RV64ZBA-NEXT:    sd a0, 0(a1)
 ; RV64ZBA-NEXT:    mv a0, a2
 ; RV64ZBA-NEXT:    ret

@wangpc-pp wangpc-pp requested review from wangpc-pp and removed request for pcwang-thead April 19, 2024 04:15
@wangpc-pp
Contributor

Sorry for bothering, but I have to say that this account (@pcwang-thead) is not used and I can't access it now. Hopefully its access can be revoked after six months (https://discourse.llvm.org/t/rfc2-new-criteria-for-commit-access/77110), or @tstellar can just remove it.

dtcxzyw added a commit to dtcxzyw/llvm-codegen-benchmark that referenced this pull request Apr 19, 2024
(TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 1)>;
def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 45)),
(TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 2)>;
def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 81)),
Member


81 is not in the switch/case in the code above; was it removed accidentally?

Collaborator Author


No. The switch has "case 91:" which appears to have been a typo.
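A quick enumeration (illustrative only, not from the patch) supports this: the two-instruction form `(shXadd (shYadd x, x), x)` matches constants of the form `(3|5|9) << TZ + 1` with `TZ` in `{1, 2, 3}`, and 91 is not among them, while 81 = 9*9 is instead covered by the squared-3/5/9 case.

```python
# Constants matched by (shXadd (shYadd x, x), x):
# MulAmt = (Divisor << TZ) + 1, Divisor in {3,5,9}, TZ in {1,2,3}.
matched = sorted((d << tz) + 1 for d in (3, 5, 9) for tz in (1, 2, 3))
print(matched)  # [7, 11, 13, 19, 21, 25, 37, 41, 73]
print(91 in matched, 81 in matched)  # False False
```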

@preames
Collaborator Author

preames commented Apr 22, 2024

ping?

SDTCisInt<3>
]>;

def riscv_shl_add : SDNode<"RISCVISD::SHL_ADD" , SDTIntShiftAddOp>;
Collaborator


Formatting is weird on this line. I assume it was copied from below but the spacing wasn't updated.

Collaborator

@topperc topperc left a comment


LGTM

@preames preames merged commit 5a7c80c into llvm:main Apr 22, 2024
3 of 4 checks passed
@preames preames deleted the pr-riscv-shl_add branch April 22, 2024 20:41
preames added a commit that referenced this pull request Apr 23, 2024
…mbine (#89263)"

This reverts commit 5a7c80c.  Noticed failures
with the following command:
$ llc -mtriple=riscv64 -mattr=+m,+xtheadba -verify-machineinstrs < test/CodeGen/RISCV/rv64zba.ll

I think I know the cause and will likely reland with a fix tomorrow.
preames added a commit that referenced this pull request Apr 23, 2024
…ombine (#89263)"

Changes since original commit:
* Rebase over improved test coverage for theadba
* Revert change to use TargetConstant as it appears to prevent the uimm2
  clause from matching in the XTheadBa patterns.
* Fix an order of operands bug in the THeadBa pattern visible in the new
  test coverage.

Original commit message follows:

This implements a RISCV specific version of the SHL_ADD node proposed in
#88791.

If that lands, the infrastructure from this patch should seamlessly
switch over to the generic DAG node. I'm posting this separately because
I've run out of useful multiply strength reduction work to do without
having a way to represent MUL X, 3/5/9 as a single instruction.

The majority of this change is moving two sets of patterns out of
tablegen and into the post-legalize combine. The major reason for this is
that I have an upcoming change which needs to reuse the expansion logic,
but it also helps common up some code between the Zba and XTheadBa
variants.

On the test changes, there are a couple of major categories:
* We chose a different lowering for mul x, 25. The new lowering involves
  one fewer register and the same critical path, so this seems like a win.
* The order of the two multiplies changes in (3,5,9)*(3,5,9) in some
  cases. I don't believe this matters.
* I'm removing the one use restriction on the multiply. This restriction
  doesn't really make sense to me, and the test changes appear positive.