[llvm] [RISCV] Strength reduce mul by 2^n + 2/4/8 + 1 (PR #88911)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 16 10:01:01 PDT 2024


https://github.com/preames updated https://github.com/llvm/llvm-project/pull/88911

>From 160eef9c2593370723e0dc75d0ea8b505a6e69d6 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Fri, 12 Apr 2024 08:44:40 -0700
Subject: [PATCH 1/3] [RISCV] Strength reduce mul by 2^n + 2/4/8 + 1

With zba, we can expand this to (add (shl X, C1), (shXadd X, X)).

Note that this is our first expansion to a three-instruction sequence.
I believe this to generally be a reasonable tradeoff for most architectures,
but we may want to (someday) consider a tuning flag here.

I plan to support 2^n + (2/4/8 + 1) eventually as well, but that will come
after 2^N - 2^M.  Both are also three-instruction sequences.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 28 +++++++++++
 llvm/test/CodeGen/RISCV/rv64zba.ll          | 51 +++++++++++++++------
 2 files changed, 64 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 765838aafb58d2..388a67a0575f63 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13431,6 +13431,34 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
       return DAG.getNode(ISD::ADD, DL, VT, Shift1, Shift2);
     }
   }
+
+  // (2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x)
+  // Matched in tablegen, avoid perturbing patterns.
+  switch (MulAmt) {
+  case 11: case 13: case 19: case 21: case 25: case 27: case 29:
+  case 37: case 41: case 45: case 73: case 91:
+    return SDValue();
+  default:
+    break;
+  }
+
+  // 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X))
+  if (MulAmt > 2 &&
+      isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
+    unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
+    if (ScaleShift >= 1 && ScaleShift < 4) {
+      unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
+      SDLoc DL(N);
+      SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
+                                   DAG.getConstant(ShiftAmt, DL, VT));
+      SDValue Shift2 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
+                                   DAG.getConstant(ScaleShift, DL, VT));
+      return DAG.getNode(ISD::ADD, DL, VT, Shift1,
+                         DAG.getNode(ISD::ADD, DL, VT, Shift2,
+                                     N->getOperand(0)));
+    }
+  }
+
   return SDValue();
 }
 
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
index a84b9e5e7962f6..d452c9d1b0f0c5 100644
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -598,31 +598,52 @@ define i64 @mul125(i64 %a) {
 }
 
 define i64 @mul131(i64 %a) {
-; CHECK-LABEL: mul131:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    li a1, 131
-; CHECK-NEXT:    mul a0, a0, a1
-; CHECK-NEXT:    ret
+; RV64I-LABEL: mul131:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    li a1, 131
+; RV64I-NEXT:    mul a0, a0, a1
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: mul131:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    sh1add a1, a0, a0
+; RV64ZBA-NEXT:    slli a0, a0, 7
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 131
   ret i64 %c
 }
 
 define i64 @mul133(i64 %a) {
-; CHECK-LABEL: mul133:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    li a1, 133
-; CHECK-NEXT:    mul a0, a0, a1
-; CHECK-NEXT:    ret
+; RV64I-LABEL: mul133:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    li a1, 133
+; RV64I-NEXT:    mul a0, a0, a1
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: mul133:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    sh2add a1, a0, a0
+; RV64ZBA-NEXT:    slli a0, a0, 7
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 133
   ret i64 %c
 }
 
 define i64 @mul137(i64 %a) {
-; CHECK-LABEL: mul137:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    li a1, 137
-; CHECK-NEXT:    mul a0, a0, a1
-; CHECK-NEXT:    ret
+; RV64I-LABEL: mul137:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    li a1, 137
+; RV64I-NEXT:    mul a0, a0, a1
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: mul137:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    sh3add a1, a0, a0
+; RV64ZBA-NEXT:    slli a0, a0, 7
+; RV64ZBA-NEXT:    add a0, a0, a1
+; RV64ZBA-NEXT:    ret
   %c = mul i64 %a, 137
   ret i64 %c
 }

>From ee56852fe3437902742a54b0dadb6163b3e30809 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Tue, 16 Apr 2024 09:54:48 -0700
Subject: [PATCH 2/3] Update llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Co-authored-by: Min-Yih Hsu <min at myhsu.dev>
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 388a67a0575f63..19da066e069a1f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13432,7 +13432,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
     }
   }
 
-  // (2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x)
+  // 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x)
   // Matched in tablegen, avoid perturbing patterns.
   switch (MulAmt) {
   case 11: case 13: case 19: case 21: case 25: case 27: case 29:

>From 38b3b51ef0c062cd574a485ebaa7b8e01ddb3bb3 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Tue, 16 Apr 2024 09:59:21 -0700
Subject: [PATCH 3/3] clang-format

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 23 ++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 19da066e069a1f..743ce71275e29a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13435,16 +13435,25 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
   // 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x)
   // Matched in tablegen, avoid perturbing patterns.
   switch (MulAmt) {
-  case 11: case 13: case 19: case 21: case 25: case 27: case 29:
-  case 37: case 41: case 45: case 73: case 91:
+  case 11:
+  case 13:
+  case 19:
+  case 21:
+  case 25:
+  case 27:
+  case 29:
+  case 37:
+  case 41:
+  case 45:
+  case 73:
+  case 91:
     return SDValue();
   default:
     break;
   }
 
   // 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X))
-  if (MulAmt > 2 &&
-      isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
+  if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
     unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
     if (ScaleShift >= 1 && ScaleShift < 4) {
       unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
@@ -13453,9 +13462,9 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
                                    DAG.getConstant(ShiftAmt, DL, VT));
       SDValue Shift2 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
                                    DAG.getConstant(ScaleShift, DL, VT));
-      return DAG.getNode(ISD::ADD, DL, VT, Shift1,
-                         DAG.getNode(ISD::ADD, DL, VT, Shift2,
-                                     N->getOperand(0)));
+      return DAG.getNode(
+          ISD::ADD, DL, VT, Shift1,
+          DAG.getNode(ISD::ADD, DL, VT, Shift2, N->getOperand(0)));
     }
   }
 



More information about the llvm-commits mailing list