[llvm] [RISCV][NFC] Avoid iteration and division while selecting SHXADD instructions (PR #158851)

Piotr Fusik via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 16 02:01:09 PDT 2025


https://github.com/pfusik created https://github.com/llvm/llvm-project/pull/158851

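Replace most of the trial divisions by 3, 5 and 9 (and the loops around them) with a trailing-zero count and a switch; this should improve compilation time.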

>From 498acc94e3df8735be943b31e3123d7caea5449a Mon Sep 17 00:00:00 2001
From: Piotr Fusik <p.fusik at samsung.com>
Date: Tue, 16 Sep 2025 10:51:03 +0200
Subject: [PATCH] [RISCV][NFC] Avoid iteration and division while selecting
 SHXADD instructions

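Replace most of the trial divisions by 3, 5 and 9 (and the loops around them) with a trailing-zero count and a switch; this should improve compilation time.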
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 134 +++++++++++---------
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp    |  29 ++---
 llvm/lib/Target/RISCV/RISCVInstrInfo.h      |  16 +++
 3 files changed, 101 insertions(+), 78 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 9d90eb0a65218..fab9a7e962158 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16371,43 +16371,60 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
   SDValue X = N->getOperand(0);
 
   if (Subtarget.hasShlAdd(3)) {
-    for (uint64_t Divisor : {3, 5, 9}) {
-      if (MulAmt % Divisor != 0)
-        continue;
-      uint64_t MulAmt2 = MulAmt / Divisor;
-      // 3/5/9 * 2^N ->  shl (shXadd X, X), N
-      if (isPowerOf2_64(MulAmt2)) {
-        SDLoc DL(N);
-        SDValue X = N->getOperand(0);
-        // Put the shift first if we can fold a zext into the
-        // shift forming a slli.uw.
-        if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
-            X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
-          SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
-                                    DAG.getConstant(Log2_64(MulAmt2), DL, VT));
-          return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
-                             DAG.getConstant(Log2_64(Divisor - 1), DL, VT),
-                             Shl);
-        }
-        // Otherwise, put rhe shl second so that it can fold with following
-        // instructions (e.g. sext or add).
-        SDValue Mul359 =
-            DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
-                        DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
-        return DAG.getNode(ISD::SHL, DL, VT, Mul359,
-                           DAG.getConstant(Log2_64(MulAmt2), DL, VT));
-      }
-
-      // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
-      if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
-        SDLoc DL(N);
-        SDValue Mul359 =
-            DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
-                        DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
-        return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
-                           DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT),
-                           Mul359);
+    int Shift;
+    if (int ShXAmount = isShifted359(MulAmt, Shift)) {
+      // 3/5/9 * 2^N -> shl (shXadd X, X), N
+      SDLoc DL(N);
+      SDValue X = N->getOperand(0);
+      // Put the shift first if we can fold a zext into the shift forming
+      // a slli.uw.
+      if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
+          X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
+        SDValue Shl =
+            DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT));
+        return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
+                           DAG.getConstant(ShXAmount, DL, VT), Shl);
       }
+      // Otherwise, put the shl second so that it can fold with following
+      // instructions (e.g. sext or add).
+      SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+                                   DAG.getConstant(ShXAmount, DL, VT), X);
+      return DAG.getNode(ISD::SHL, DL, VT, Mul359,
+                         DAG.getConstant(Shift, DL, VT));
+    }
+
+    // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
+    int ShX;
+    int ShY;
+    switch (MulAmt) {
+    case 3 * 5:
+      ShY = 1;
+      ShX = 2;
+      break;
+    case 3 * 9:
+      ShY = 1;
+      ShX = 3;
+      break;
+    case 5 * 5:
+      ShX = ShY = 2;
+      break;
+    case 5 * 9:
+      ShY = 2;
+      ShX = 3;
+      break;
+    case 9 * 9:
+      ShX = ShY = 3;
+      break;
+    default:
+      ShX = ShY = 0;
+      break;
+    }
+    if (ShX) {
+      SDLoc DL(N);
+      SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+                                   DAG.getConstant(ShY, DL, VT), X);
+      return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
+                         DAG.getConstant(ShX, DL, VT), Mul359);
     }
 
     // If this is a power 2 + 2/4/8, we can use a shift followed by a single
@@ -16430,18 +16447,14 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
     // variants we could implement.  e.g.
     //   (2^(1,2,3) * 3,5,9 + 1) << C2
     //   2^(C1>3) * 3,5,9 +/- 1
-    for (uint64_t Divisor : {3, 5, 9}) {
-      uint64_t C = MulAmt - 1;
-      if (C <= Divisor)
-        continue;
-      unsigned TZ = llvm::countr_zero(C);
-      if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
+    if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) {
+      assert(Shift != 0 && "MulAmt=4,6,10 handled before");
+      if (Shift <= 3) {
         SDLoc DL(N);
-        SDValue Mul359 =
-            DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
-                        DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
+        SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+                                     DAG.getConstant(ShXAmount, DL, VT), X);
         return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
-                           DAG.getConstant(TZ, DL, VT), X);
+                           DAG.getConstant(Shift, DL, VT), X);
       }
     }
 
@@ -16449,7 +16462,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
     if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
       unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
       if (ScaleShift >= 1 && ScaleShift < 4) {
-        unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
+        unsigned ShiftAmt = llvm::countr_zero((MulAmt - 1) & (MulAmt - 2));
         SDLoc DL(N);
         SDValue Shift1 =
             DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
@@ -16462,7 +16475,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
     // 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
     for (uint64_t Offset : {3, 5, 9}) {
       if (isPowerOf2_64(MulAmt + Offset)) {
-        unsigned ShAmt = Log2_64(MulAmt + Offset);
+        unsigned ShAmt = llvm::countr_zero(MulAmt + Offset);
         if (ShAmt >= VT.getSizeInBits())
           continue;
         SDLoc DL(N);
@@ -16481,21 +16494,16 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
       uint64_t MulAmt2 = MulAmt / Divisor;
       // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples
       // of 25 which happen to be quite common.
-      for (uint64_t Divisor2 : {3, 5, 9}) {
-        if (MulAmt2 % Divisor2 != 0)
-          continue;
-        uint64_t MulAmt3 = MulAmt2 / Divisor2;
-        if (isPowerOf2_64(MulAmt3)) {
-          SDLoc DL(N);
-          SDValue Mul359A =
-              DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
-                          DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
-          SDValue Mul359B = DAG.getNode(
-              RISCVISD::SHL_ADD, DL, VT, Mul359A,
-              DAG.getConstant(Log2_64(Divisor2 - 1), DL, VT), Mul359A);
-          return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
-                             DAG.getConstant(Log2_64(MulAmt3), DL, VT));
-        }
+      if (int ShBAmount = isShifted359(MulAmt2, Shift)) {
+        SDLoc DL(N);
+        SDValue Mul359A =
+            DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+                        DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
+        SDValue Mul359B =
+            DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A,
+                        DAG.getConstant(ShBAmount, DL, VT), Mul359A);
+        return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
+                           DAG.getConstant(Shift, DL, VT));
       }
     }
   }
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index f816112f70140..794ec5f6cc3dd 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -4492,24 +4492,23 @@ void RISCVInstrInfo::mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
         .addReg(DestReg, RegState::Kill)
         .addImm(ShiftAmount)
         .setMIFlag(Flag);
-  } else if (STI.hasShlAdd(3) &&
-             ((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) ||
-              (Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) ||
-              (Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) {
+  } else if (int ShXAmount, ShiftAmount;
+             STI.hasShlAdd(3) &&
+             (ShXAmount = isShifted359(Amount, ShiftAmount)) != 0) {
     // We can use Zba SHXADD+SLLI instructions for multiply in some cases.
     unsigned Opc;
-    uint32_t ShiftAmount;
-    if (Amount % 9 == 0) {
-      Opc = RISCV::SH3ADD;
-      ShiftAmount = Log2_64(Amount / 9);
-    } else if (Amount % 5 == 0) {
-      Opc = RISCV::SH2ADD;
-      ShiftAmount = Log2_64(Amount / 5);
-    } else if (Amount % 3 == 0) {
+    switch (ShXAmount) {
+    case 1:
       Opc = RISCV::SH1ADD;
-      ShiftAmount = Log2_64(Amount / 3);
-    } else {
-      llvm_unreachable("implied by if-clause");
+      break;
+    case 2:
+      Opc = RISCV::SH2ADD;
+      break;
+    case 3:
+      Opc = RISCV::SH3ADD;
+      break;
+    default:
+      llvm_unreachable("unexpected result of isShifted359");
     }
     if (ShiftAmount)
       BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 57ec431749ebe..fe3f1bfd5e2a1 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -25,6 +25,22 @@
 
 namespace llvm {
 
+/// Returns 1, 2 or 3 if \p Value is 3, 5 or 9 (respectively) shifted left by
+/// some amount, i.e. if a multiply by \p Value can be done with one shXadd
+/// (X being the returned value) plus an optional left shift; returns 0
+/// otherwise. For any non-zero \p Value, \p Shift is set to its trailing-zero
+/// count even when the result is 0; it is left untouched when \p Value is 0.
+template <typename T> int isShifted359(T Value, int &Shift) {
+  if (Value == 0)
+    return 0;
+  Shift = llvm::countr_zero(Value);
+  const T Odd = Value >> Shift;
+  if (Odd != 3 && Odd != 5 && Odd != 9)
+    return 0;
+  // 3, 5 and 9 are 2^K + 1 with K = 1, 2, 3, and K is the shXadd amount.
+  return llvm::countr_zero(static_cast<T>(Odd - 1));
+}
+
 class RISCVSubtarget;
 
 static const MachineMemOperand::Flags MONontemporalBit0 =



More information about the llvm-commits mailing list