[llvm] [RISCV] Implement RISCVISD::SHL_ADD and move patterns into combine (PR #89263)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 18 12:17:13 PDT 2024


https://github.com/preames updated https://github.com/llvm/llvm-project/pull/89263

From dceb62aa366e93d19fd555e6b4b248d7cfcea8d3 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 18 Apr 2024 07:50:50 -0700
Subject: [PATCH 1/2] [RISCV] Implement RISCVISD::SHL_ADD and move patterns
 into combine

This implements a RISCV-specific version of the SHL_ADD node proposed in
https://github.com/llvm/llvm-project/pull/88791.

If that lands, the infrastructure from this patch should seamlessly switch
over to the generic DAG node.  I'm posting this separately because I've run
out of useful multiply strength-reduction work to do without a way to
represent MUL X, 3/5/9 as a single instruction.

The majority of this change is moving two sets of patterns out of tablegen
and into the post-legalize combine.  The main reason for this is that I
have an upcoming change which needs to reuse the expansion logic, but it
also helps common up some code between the Zba and XTheadBa variants.

On the test changes, there are a couple of major categories:
* We chose a different lowering for mul x, 25.  The new lowering involves
  one fewer register and the same critical path, so this seems like a win
  (a worked sketch follows this list).
* The order of the two multiplies changes in (3,5,9)*(3,5,9) in some cases.
  I don't believe this matters.
* I'm removing the one-use restriction on the multiply.  This restriction
  doesn't really make sense to me, and the test changes appear positive.
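
For illustration, here is a minimal standalone sketch (not part of the
patch) of the decompositions involved.  shxadd below models the Zba
sh1add/sh2add/sh3add instructions, which compute (rs1 << N) + rs2:

  #include <cassert>
  #include <cstdint>

  // Models the Zba shNadd instructions: (rs1 << N) + rs2, N = 1, 2, 3.
  static uint64_t shxadd(uint64_t rs1, unsigned n, uint64_t rs2) {
    return (rs1 << n) + rs2;
  }

  int main() {
    uint64_t x = 7; // arbitrary input
    // mul x, 25: 25 = 5 * 5, the new two-sh2add lowering (one fewer
    // register than the old sh1add/sh3add form).
    uint64_t m5 = shxadd(x, 2, x);       // 5 * x
    assert(shxadd(m5, 2, m5) == 25 * x); // 5 * (5 * x)
    // mul x, 13: 13 = 3 * 4 + 1, i.e. sh2add (sh1add x, x), x.
    uint64_t m3 = shxadd(x, 1, x);       // 3 * x
    assert(shxadd(m3, 2, x) == 13 * x);  // ((3 * x) << 2) + x
    return 0;
  }
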
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 57 ++++++++++-------
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |  6 ++
 llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td | 26 +-------
 llvm/lib/Target/RISCV/RISCVInstrInfoZb.td     | 35 +++--------
 llvm/test/CodeGen/RISCV/addimm-mulimm.ll      | 20 +++---
 llvm/test/CodeGen/RISCV/rv32zba.ll            |  8 +--
 .../CodeGen/RISCV/rv64-legal-i32/rv64zba.ll   |  8 +--
 .../CodeGen/RISCV/rv64-legal-i32/xaluo.ll     | 18 +++---
 llvm/test/CodeGen/RISCV/rv64zba.ll            |  8 +--
 llvm/test/CodeGen/RISCV/xaluo.ll              | 61 +++++++++++--------
 10 files changed, 123 insertions(+), 124 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b0deb1d2669952..075f8c227f3dc8 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13416,12 +13416,26 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
     return SDValue();
   uint64_t MulAmt = CNode->getZExtValue();
 
-  // 3/5/9 * 2^N -> shXadd (sll X, C), (sll X, C)
-  // Matched in tablegen, avoid perturbing patterns.
-  for (uint64_t Divisor : {3, 5, 9})
-    if (MulAmt % Divisor == 0 && isPowerOf2_64(MulAmt / Divisor))
+  for (uint64_t Divisor : {3, 5, 9}) {
+    if (MulAmt % Divisor != 0)
+      continue;
+    uint64_t MulAmt2 = MulAmt / Divisor;
+    // 3/5/9 * 2^N -> shXadd (sll X, C), (sll X, C)
+    // Matched in tablegen, avoid perturbing patterns.
+    if (isPowerOf2_64(MulAmt2))
       return SDValue();
 
+    // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
+    if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
+      SDLoc DL(N);
+      SDValue Mul359 = DAG.getNode(
+          RISCVISD::SHL_ADD, DL, VT, N->getOperand(0),
+          DAG.getConstant(Log2_64(Divisor - 1), DL, VT), N->getOperand(0));
+      return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
+                         DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT), Mul359);
+    }
+  }
+
   // If this is a power 2 + 2/4/8, we can use a shift followed by a single
   // shXadd. First check if this a sum of two power of 2s because that's
   // easy. Then count how many zeros are up to the first bit.
@@ -13439,23 +13453,23 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
   }
 
   // 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x)
-  // Matched in tablegen, avoid perturbing patterns.
-  switch (MulAmt) {
-  case 11:
-  case 13:
-  case 19:
-  case 21:
-  case 25:
-  case 27:
-  case 29:
-  case 37:
-  case 41:
-  case 45:
-  case 73:
-  case 91:
-    return SDValue();
-  default:
-    break;
+  // This is the two instruction form, there are also three instruction
+  // variants we could implement.  e.g.
+  //   (2^(1,2,3) * 3,5,9 + 1) << C2
+  //   2^(C1>3) * 3,5,9 +/- 1
+  for (uint64_t Divisor : {3, 5, 9}) {
+    uint64_t C = MulAmt - 1;
+    if (C <= Divisor)
+      continue;
+    unsigned TZ = llvm::countr_zero(C);
+    if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
+      SDLoc DL(N);
+      SDValue Mul359 = DAG.getNode(
+          RISCVISD::SHL_ADD, DL, VT, N->getOperand(0),
+          DAG.getConstant(Log2_64(Divisor - 1), DL, VT), N->getOperand(0));
+      return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
+                         DAG.getConstant(TZ, DL, VT), N->getOperand(0));
+    }
   }
 
   // 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X))
@@ -19668,6 +19682,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(LLA)
   NODE_NAME_CASE(ADD_TPREL)
   NODE_NAME_CASE(MULHSU)
+  NODE_NAME_CASE(SHL_ADD)
   NODE_NAME_CASE(SLLW)
   NODE_NAME_CASE(SRAW)
   NODE_NAME_CASE(SRLW)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index b10da3d40befb7..ed14fd4539438a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -59,6 +59,12 @@ enum NodeType : unsigned {
 
   // Multiply high for signedxunsigned.
   MULHSU,
+
+  // Represents (ADD (SHL a, b), c) with the arguments appearing in the order
+  // a, b, c.  'b' must be a constant.  Maps to sh1add/sh2add/sh3add with zba
+  // or addsl with XTheadBa.
+  SHL_ADD,
+
   // RV64I shifts, directly matching the semantics of the named RISC-V
   // instructions.
   SLLW,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
index 79ced3864363b9..05dd09a267819f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
@@ -538,6 +538,8 @@ multiclass VPatTernaryVMAQA_VV_VX<string intrinsic, string instruction,
 let Predicates = [HasVendorXTHeadBa] in {
 def : Pat<(add (XLenVT GPR:$rs1), (shl GPR:$rs2, uimm2:$uimm2)),
           (TH_ADDSL GPR:$rs1, GPR:$rs2, uimm2:$uimm2)>;
+def : Pat<(XLenVT (riscv_shl_add GPR:$rs1, uimm2:$uimm2, GPR:$rs2)),
+          (TH_ADDSL GPR:$rs1, GPR:$rs2, uimm2:$uimm2)>;
 
 // Reuse complex patterns from StdExtZba
 def : Pat<(add_non_imm12 sh1add_op:$rs1, (XLenVT GPR:$rs2)),
@@ -581,30 +583,6 @@ def : Pat<(mul (XLenVT GPR:$r), C9LeftShift:$i),
           (SLLI (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)),
                 (TrailingZeros C9LeftShift:$i))>;
 
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 11)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 19)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 13)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 1)), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 21)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 37)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 25)),
-          (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)),
-                    (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 41)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 3)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 73)),
-          (TH_ADDSL GPR:$r, (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 3)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 27)),
-          (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 45)),
-          (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 81)),
-          (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)), 3)>;
-
 def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 200)),
           (SLLI (XLenVT (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)),
                                   (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 2)), 3)>;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
index 434b071e628a0e..3d04415522919c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -26,6 +26,12 @@
 // Operand and SDNode transformation definitions.
 //===----------------------------------------------------------------------===//
 
+def SDTIntShiftAddOp : SDTypeProfile<1, 3, [   // shl_add
+  SDTCisSameAs<0, 1>, SDTCisSameAs<0, 3>, SDTCisInt<0>, SDTCisInt<2>,
+  SDTCisInt<3>
+]>;
+
+def riscv_shl_add    : SDNode<"RISCVISD::SHL_ADD"   , SDTIntShiftAddOp>;
 def riscv_clzw   : SDNode<"RISCVISD::CLZW",   SDT_RISCVIntUnaryOpW>;
 def riscv_ctzw   : SDNode<"RISCVISD::CTZW",   SDT_RISCVIntUnaryOpW>;
 def riscv_rolw   : SDNode<"RISCVISD::ROLW",   SDT_RISCVIntBinOpW>;
@@ -678,6 +684,8 @@ foreach i = {1,2,3} in {
   defvar shxadd = !cast<Instruction>("SH"#i#"ADD");
   def : Pat<(XLenVT (add_like_non_imm12 (shl GPR:$rs1, (XLenVT i)), GPR:$rs2)),
             (shxadd GPR:$rs1, GPR:$rs2)>;
+  def : Pat<(XLenVT (riscv_shl_add GPR:$rs1, (XLenVT i), GPR:$rs2)),
+            (shxadd GPR:$rs1, GPR:$rs2)>;
 
   defvar pat = !cast<ComplexPattern>("sh"#i#"add_op");
   // More complex cases use a ComplexPattern.
@@ -721,31 +729,6 @@ def : Pat<(mul (XLenVT GPR:$r), C9LeftShift:$i),
           (SLLI (XLenVT (SH3ADD GPR:$r, GPR:$r)),
                 (TrailingZeros C9LeftShift:$i))>;
 
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 11)),
-          (SH1ADD (XLenVT (SH2ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 19)),
-          (SH1ADD (XLenVT (SH3ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 13)),
-          (SH2ADD (XLenVT (SH1ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 21)),
-          (SH2ADD (XLenVT (SH2ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 37)),
-          (SH2ADD (XLenVT (SH3ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 25)),
-          (SH3ADD (XLenVT (SH1ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 41)),
-          (SH3ADD (XLenVT (SH2ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 73)),
-          (SH3ADD (XLenVT (SH3ADD GPR:$r, GPR:$r)), GPR:$r)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 27)),
-          (SH1ADD (XLenVT (SH3ADD GPR:$r, GPR:$r)),
-                  (XLenVT (SH3ADD GPR:$r, GPR:$r)))>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 45)),
-          (SH2ADD (XLenVT (SH3ADD GPR:$r, GPR:$r)),
-                  (XLenVT (SH3ADD GPR:$r, GPR:$r)))>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 81)),
-          (SH3ADD (XLenVT (SH3ADD GPR:$r, GPR:$r)),
-                  (XLenVT (SH3ADD GPR:$r, GPR:$r)))>;
 } // Predicates = [HasStdExtZba]
 
 let Predicates = [HasStdExtZba, IsRV64] in {
@@ -881,6 +864,8 @@ foreach i = {1,2,3} in {
   defvar shxadd = !cast<Instruction>("SH"#i#"ADD");
   def : Pat<(i32 (add_like_non_imm12 (shl GPR:$rs1, (i64 i)), GPR:$rs2)),
             (shxadd GPR:$rs1, GPR:$rs2)>;
+  def : Pat<(i32 (riscv_shl_add GPR:$rs1, (i32 i), GPR:$rs2)),
+            (shxadd GPR:$rs1, GPR:$rs2)>;
 }
 }
 
diff --git a/llvm/test/CodeGen/RISCV/addimm-mulimm.ll b/llvm/test/CodeGen/RISCV/addimm-mulimm.ll
index 48fa69e1045656..736c8e7d55c75b 100644
--- a/llvm/test/CodeGen/RISCV/addimm-mulimm.ll
+++ b/llvm/test/CodeGen/RISCV/addimm-mulimm.ll
@@ -251,10 +251,12 @@ define i64 @add_mul_combine_reject_c3(i64 %x) {
 ; RV32IMB-LABEL: add_mul_combine_reject_c3:
 ; RV32IMB:       # %bb.0:
 ; RV32IMB-NEXT:    li a2, 73
-; RV32IMB-NEXT:    mul a1, a1, a2
-; RV32IMB-NEXT:    mulhu a3, a0, a2
-; RV32IMB-NEXT:    add a1, a3, a1
-; RV32IMB-NEXT:    mul a2, a0, a2
+; RV32IMB-NEXT:    mulhu a2, a0, a2
+; RV32IMB-NEXT:    sh3add a3, a1, a1
+; RV32IMB-NEXT:    sh3add a1, a3, a1
+; RV32IMB-NEXT:    add a1, a2, a1
+; RV32IMB-NEXT:    sh3add a2, a0, a0
+; RV32IMB-NEXT:    sh3add a2, a2, a0
 ; RV32IMB-NEXT:    lui a0, 18
 ; RV32IMB-NEXT:    addi a0, a0, -728
 ; RV32IMB-NEXT:    add a0, a2, a0
@@ -518,10 +520,12 @@ define i64 @add_mul_combine_reject_g3(i64 %x) {
 ; RV32IMB-LABEL: add_mul_combine_reject_g3:
 ; RV32IMB:       # %bb.0:
 ; RV32IMB-NEXT:    li a2, 73
-; RV32IMB-NEXT:    mul a1, a1, a2
-; RV32IMB-NEXT:    mulhu a3, a0, a2
-; RV32IMB-NEXT:    add a1, a3, a1
-; RV32IMB-NEXT:    mul a2, a0, a2
+; RV32IMB-NEXT:    mulhu a2, a0, a2
+; RV32IMB-NEXT:    sh3add a3, a1, a1
+; RV32IMB-NEXT:    sh3add a1, a3, a1
+; RV32IMB-NEXT:    add a1, a2, a1
+; RV32IMB-NEXT:    sh3add a2, a0, a0
+; RV32IMB-NEXT:    sh3add a2, a2, a0
 ; RV32IMB-NEXT:    lui a0, 2
 ; RV32IMB-NEXT:    addi a0, a0, -882
 ; RV32IMB-NEXT:    add a0, a2, a0
diff --git a/llvm/test/CodeGen/RISCV/rv32zba.ll b/llvm/test/CodeGen/RISCV/rv32zba.ll
index a78f823d318418..2a72c1288f65cc 100644
--- a/llvm/test/CodeGen/RISCV/rv32zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv32zba.ll
@@ -407,8 +407,8 @@ define i32 @mul25(i32 %a) {
 ;
 ; RV32ZBA-LABEL: mul25:
 ; RV32ZBA:       # %bb.0:
-; RV32ZBA-NEXT:    sh1add a1, a0, a0
-; RV32ZBA-NEXT:    sh3add a0, a1, a0
+; RV32ZBA-NEXT:    sh2add a0, a0, a0
+; RV32ZBA-NEXT:    sh2add a0, a0, a0
 ; RV32ZBA-NEXT:    ret
   %c = mul i32 %a, 25
   ret i32 %c
@@ -455,8 +455,8 @@ define i32 @mul27(i32 %a) {
 ;
 ; RV32ZBA-LABEL: mul27:
 ; RV32ZBA:       # %bb.0:
-; RV32ZBA-NEXT:    sh3add a0, a0, a0
 ; RV32ZBA-NEXT:    sh1add a0, a0, a0
+; RV32ZBA-NEXT:    sh3add a0, a0, a0
 ; RV32ZBA-NEXT:    ret
   %c = mul i32 %a, 27
   ret i32 %c
@@ -471,8 +471,8 @@ define i32 @mul45(i32 %a) {
 ;
 ; RV32ZBA-LABEL: mul45:
 ; RV32ZBA:       # %bb.0:
-; RV32ZBA-NEXT:    sh3add a0, a0, a0
 ; RV32ZBA-NEXT:    sh2add a0, a0, a0
+; RV32ZBA-NEXT:    sh3add a0, a0, a0
 ; RV32ZBA-NEXT:    ret
   %c = mul i32 %a, 45
   ret i32 %c
diff --git a/llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll
index ee9b73ca82f213..9f06a9dd124cef 100644
--- a/llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll
@@ -963,8 +963,8 @@ define i64 @mul25(i64 %a) {
 ;
 ; RV64ZBA-LABEL: mul25:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sh1add a1, a0, a0
-; RV64ZBA-NEXT:    sh3add a0, a1, a0
+; RV64ZBA-NEXT:    sh2add a0, a0, a0
+; RV64ZBA-NEXT:    sh2add a0, a0, a0
 ; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 25
   ret i64 %c
@@ -1011,8 +1011,8 @@ define i64 @mul27(i64 %a) {
 ;
 ; RV64ZBA-LABEL: mul27:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    sh1add a0, a0, a0
+; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 27
   ret i64 %c
@@ -1027,8 +1027,8 @@ define i64 @mul45(i64 %a) {
 ;
 ; RV64ZBA-LABEL: mul45:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    sh2add a0, a0, a0
+; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 45
   ret i64 %c
diff --git a/llvm/test/CodeGen/RISCV/rv64-legal-i32/xaluo.ll b/llvm/test/CodeGen/RISCV/rv64-legal-i32/xaluo.ll
index 3c1b76818781a1..a1de326d16b536 100644
--- a/llvm/test/CodeGen/RISCV/rv64-legal-i32/xaluo.ll
+++ b/llvm/test/CodeGen/RISCV/rv64-legal-i32/xaluo.ll
@@ -731,12 +731,13 @@ define zeroext i1 @smulo2.i64(i64 %v1, ptr %res) {
 ; RV64ZBA-LABEL: smulo2.i64:
 ; RV64ZBA:       # %bb.0: # %entry
 ; RV64ZBA-NEXT:    li a2, 13
-; RV64ZBA-NEXT:    mulh a3, a0, a2
-; RV64ZBA-NEXT:    mul a2, a0, a2
-; RV64ZBA-NEXT:    srai a0, a2, 63
-; RV64ZBA-NEXT:    xor a0, a3, a0
+; RV64ZBA-NEXT:    mulh a2, a0, a2
+; RV64ZBA-NEXT:    sh1add a3, a0, a0
+; RV64ZBA-NEXT:    sh2add a3, a3, a0
+; RV64ZBA-NEXT:    srai a0, a3, 63
+; RV64ZBA-NEXT:    xor a0, a2, a0
 ; RV64ZBA-NEXT:    snez a0, a0
-; RV64ZBA-NEXT:    sd a2, 0(a1)
+; RV64ZBA-NEXT:    sd a3, 0(a1)
 ; RV64ZBA-NEXT:    ret
 ;
 ; RV64ZICOND-LABEL: smulo2.i64:
@@ -925,10 +926,11 @@ define zeroext i1 @umulo2.i64(i64 %v1, ptr %res) {
 ;
 ; RV64ZBA-LABEL: umulo2.i64:
 ; RV64ZBA:       # %bb.0: # %entry
-; RV64ZBA-NEXT:    li a3, 13
-; RV64ZBA-NEXT:    mulhu a2, a0, a3
+; RV64ZBA-NEXT:    li a2, 13
+; RV64ZBA-NEXT:    mulhu a2, a0, a2
 ; RV64ZBA-NEXT:    snez a2, a2
-; RV64ZBA-NEXT:    mul a0, a0, a3
+; RV64ZBA-NEXT:    sh1add a3, a0, a0
+; RV64ZBA-NEXT:    sh2add a0, a3, a0
 ; RV64ZBA-NEXT:    sd a0, 0(a1)
 ; RV64ZBA-NEXT:    mv a0, a2
 ; RV64ZBA-NEXT:    ret
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
index ccb23fc2bbfa36..f31de84b8b047c 100644
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -1140,8 +1140,8 @@ define i64 @mul25(i64 %a) {
 ;
 ; RV64ZBA-LABEL: mul25:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sh1add a1, a0, a0
-; RV64ZBA-NEXT:    sh3add a0, a1, a0
+; RV64ZBA-NEXT:    sh2add a0, a0, a0
+; RV64ZBA-NEXT:    sh2add a0, a0, a0
 ; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 25
   ret i64 %c
@@ -1188,8 +1188,8 @@ define i64 @mul27(i64 %a) {
 ;
 ; RV64ZBA-LABEL: mul27:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    sh1add a0, a0, a0
+; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 27
   ret i64 %c
@@ -1204,8 +1204,8 @@ define i64 @mul45(i64 %a) {
 ;
 ; RV64ZBA-LABEL: mul45:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    sh2add a0, a0, a0
+; RV64ZBA-NEXT:    sh3add a0, a0, a0
 ; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 45
   ret i64 %c
diff --git a/llvm/test/CodeGen/RISCV/xaluo.ll b/llvm/test/CodeGen/RISCV/xaluo.ll
index ac67c0769f7056..1a88563c0ea2ed 100644
--- a/llvm/test/CodeGen/RISCV/xaluo.ll
+++ b/llvm/test/CodeGen/RISCV/xaluo.ll
@@ -1268,12 +1268,13 @@ define zeroext i1 @smulo2.i32(i32 signext %v1, ptr %res) {
 ; RV32ZBA-LABEL: smulo2.i32:
 ; RV32ZBA:       # %bb.0: # %entry
 ; RV32ZBA-NEXT:    li a2, 13
-; RV32ZBA-NEXT:    mulh a3, a0, a2
-; RV32ZBA-NEXT:    mul a2, a0, a2
-; RV32ZBA-NEXT:    srai a0, a2, 31
-; RV32ZBA-NEXT:    xor a0, a3, a0
+; RV32ZBA-NEXT:    mulh a2, a0, a2
+; RV32ZBA-NEXT:    sh1add a3, a0, a0
+; RV32ZBA-NEXT:    sh2add a3, a3, a0
+; RV32ZBA-NEXT:    srai a0, a3, 31
+; RV32ZBA-NEXT:    xor a0, a2, a0
 ; RV32ZBA-NEXT:    snez a0, a0
-; RV32ZBA-NEXT:    sw a2, 0(a1)
+; RV32ZBA-NEXT:    sw a3, 0(a1)
 ; RV32ZBA-NEXT:    ret
 ;
 ; RV64ZBA-LABEL: smulo2.i32:
@@ -1577,13 +1578,15 @@ define zeroext i1 @smulo2.i64(i64 %v1, ptr %res) {
 ; RV32ZBA:       # %bb.0: # %entry
 ; RV32ZBA-NEXT:    li a3, 13
 ; RV32ZBA-NEXT:    mulhu a4, a0, a3
-; RV32ZBA-NEXT:    mul a5, a1, a3
+; RV32ZBA-NEXT:    sh1add a5, a1, a1
+; RV32ZBA-NEXT:    sh2add a5, a5, a1
 ; RV32ZBA-NEXT:    add a4, a5, a4
 ; RV32ZBA-NEXT:    sltu a5, a4, a5
 ; RV32ZBA-NEXT:    mulhu a6, a1, a3
 ; RV32ZBA-NEXT:    add a5, a6, a5
 ; RV32ZBA-NEXT:    srai a1, a1, 31
-; RV32ZBA-NEXT:    mul a6, a1, a3
+; RV32ZBA-NEXT:    sh1add a6, a1, a1
+; RV32ZBA-NEXT:    sh2add a6, a6, a1
 ; RV32ZBA-NEXT:    add a6, a5, a6
 ; RV32ZBA-NEXT:    srai a7, a4, 31
 ; RV32ZBA-NEXT:    xor t0, a6, a7
@@ -1593,7 +1596,8 @@ define zeroext i1 @smulo2.i64(i64 %v1, ptr %res) {
 ; RV32ZBA-NEXT:    xor a1, a1, a7
 ; RV32ZBA-NEXT:    or a1, t0, a1
 ; RV32ZBA-NEXT:    snez a1, a1
-; RV32ZBA-NEXT:    mul a0, a0, a3
+; RV32ZBA-NEXT:    sh1add a3, a0, a0
+; RV32ZBA-NEXT:    sh2add a0, a3, a0
 ; RV32ZBA-NEXT:    sw a0, 0(a2)
 ; RV32ZBA-NEXT:    sw a4, 4(a2)
 ; RV32ZBA-NEXT:    mv a0, a1
@@ -1602,12 +1606,13 @@ define zeroext i1 @smulo2.i64(i64 %v1, ptr %res) {
 ; RV64ZBA-LABEL: smulo2.i64:
 ; RV64ZBA:       # %bb.0: # %entry
 ; RV64ZBA-NEXT:    li a2, 13
-; RV64ZBA-NEXT:    mulh a3, a0, a2
-; RV64ZBA-NEXT:    mul a2, a0, a2
-; RV64ZBA-NEXT:    srai a0, a2, 63
-; RV64ZBA-NEXT:    xor a0, a3, a0
+; RV64ZBA-NEXT:    mulh a2, a0, a2
+; RV64ZBA-NEXT:    sh1add a3, a0, a0
+; RV64ZBA-NEXT:    sh2add a3, a3, a0
+; RV64ZBA-NEXT:    srai a0, a3, 63
+; RV64ZBA-NEXT:    xor a0, a2, a0
 ; RV64ZBA-NEXT:    snez a0, a0
-; RV64ZBA-NEXT:    sd a2, 0(a1)
+; RV64ZBA-NEXT:    sd a3, 0(a1)
 ; RV64ZBA-NEXT:    ret
 ;
 ; RV32ZICOND-LABEL: smulo2.i64:
@@ -1743,10 +1748,11 @@ define zeroext i1 @umulo2.i32(i32 signext %v1, ptr %res) {
 ;
 ; RV32ZBA-LABEL: umulo2.i32:
 ; RV32ZBA:       # %bb.0: # %entry
-; RV32ZBA-NEXT:    li a3, 13
-; RV32ZBA-NEXT:    mulhu a2, a0, a3
+; RV32ZBA-NEXT:    li a2, 13
+; RV32ZBA-NEXT:    mulhu a2, a0, a2
 ; RV32ZBA-NEXT:    snez a2, a2
-; RV32ZBA-NEXT:    mul a0, a0, a3
+; RV32ZBA-NEXT:    sh1add a3, a0, a0
+; RV32ZBA-NEXT:    sh2add a0, a3, a0
 ; RV32ZBA-NEXT:    sw a0, 0(a1)
 ; RV32ZBA-NEXT:    mv a0, a2
 ; RV32ZBA-NEXT:    ret
@@ -1995,25 +2001,28 @@ define zeroext i1 @umulo2.i64(i64 %v1, ptr %res) {
 ; RV32ZBA-LABEL: umulo2.i64:
 ; RV32ZBA:       # %bb.0: # %entry
 ; RV32ZBA-NEXT:    li a3, 13
-; RV32ZBA-NEXT:    mul a4, a1, a3
-; RV32ZBA-NEXT:    mulhu a5, a0, a3
-; RV32ZBA-NEXT:    add a4, a5, a4
-; RV32ZBA-NEXT:    sltu a5, a4, a5
+; RV32ZBA-NEXT:    mulhu a4, a0, a3
+; RV32ZBA-NEXT:    sh1add a5, a1, a1
+; RV32ZBA-NEXT:    sh2add a5, a5, a1
+; RV32ZBA-NEXT:    add a5, a4, a5
+; RV32ZBA-NEXT:    sltu a4, a5, a4
 ; RV32ZBA-NEXT:    mulhu a1, a1, a3
 ; RV32ZBA-NEXT:    snez a1, a1
-; RV32ZBA-NEXT:    or a1, a1, a5
-; RV32ZBA-NEXT:    mul a0, a0, a3
+; RV32ZBA-NEXT:    or a1, a1, a4
+; RV32ZBA-NEXT:    sh1add a3, a0, a0
+; RV32ZBA-NEXT:    sh2add a0, a3, a0
 ; RV32ZBA-NEXT:    sw a0, 0(a2)
-; RV32ZBA-NEXT:    sw a4, 4(a2)
+; RV32ZBA-NEXT:    sw a5, 4(a2)
 ; RV32ZBA-NEXT:    mv a0, a1
 ; RV32ZBA-NEXT:    ret
 ;
 ; RV64ZBA-LABEL: umulo2.i64:
 ; RV64ZBA:       # %bb.0: # %entry
-; RV64ZBA-NEXT:    li a3, 13
-; RV64ZBA-NEXT:    mulhu a2, a0, a3
+; RV64ZBA-NEXT:    li a2, 13
+; RV64ZBA-NEXT:    mulhu a2, a0, a2
 ; RV64ZBA-NEXT:    snez a2, a2
-; RV64ZBA-NEXT:    mul a0, a0, a3
+; RV64ZBA-NEXT:    sh1add a3, a0, a0
+; RV64ZBA-NEXT:    sh2add a0, a3, a0
 ; RV64ZBA-NEXT:    sd a0, 0(a1)
 ; RV64ZBA-NEXT:    mv a0, a2
 ; RV64ZBA-NEXT:    ret

From 6bae46294572c79600927f49a4a4d83d3ebd5f01 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 18 Apr 2024 12:16:36 -0700
Subject: [PATCH 2/2] Freeze operands when adding multiple uses
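
The combine creates two uses of N->getOperand(0).  Each use of an
undef/poison value may independently take a different value, so without a
freeze (X << C) + X is not guaranteed to equal (C + 1) * X.  Freezing the
operand once and reusing the frozen value pins every duplicated use to the
same bits, as in the patch below:

  SDValue X = DAG.getFreeze(N->getOperand(0)); // freeze once...
  SDValue Mul359 =
      DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
                  DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); // ...reuse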

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 075f8c227f3dc8..1957091ebf8aa7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13428,9 +13428,10 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
     // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
     if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
       SDLoc DL(N);
-      SDValue Mul359 = DAG.getNode(
-          RISCVISD::SHL_ADD, DL, VT, N->getOperand(0),
-          DAG.getConstant(Log2_64(Divisor - 1), DL, VT), N->getOperand(0));
+      SDValue X = DAG.getFreeze(N->getOperand(0));
+      SDValue Mul359 =
+          DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+                      DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
       return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
                          DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT), Mul359);
     }
@@ -13464,11 +13465,12 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
     unsigned TZ = llvm::countr_zero(C);
     if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
       SDLoc DL(N);
-      SDValue Mul359 = DAG.getNode(
-          RISCVISD::SHL_ADD, DL, VT, N->getOperand(0),
-          DAG.getConstant(Log2_64(Divisor - 1), DL, VT), N->getOperand(0));
+      SDValue X = DAG.getFreeze(N->getOperand(0));
+      SDValue Mul359 =
+          DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+                      DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
       return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
-                         DAG.getConstant(TZ, DL, VT), N->getOperand(0));
+                         DAG.getConstant(TZ, DL, VT), X);
     }
   }
 


