[llvm] [RISCV] Expand `X * (2^N - 2^M)` where `N < M` (PR #168843)

Piotr Fusik via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 20 02:01:14 PST 2025


https://github.com/pfusik created https://github.com/llvm/llvm-project/pull/168843

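(No PR description was provided.) The patch extends `expandMulToAddOrSubOfShl` to also expand `X * (2^N - 2^M)` with `N < M`, i.e. a negative constant such as -6 = 2^1 - 2^3, into `(X << N) - (X << M)` instead of materializing the constant and emitting a `mul`.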

>From 4005d46040bb0479fad31f3a2def0a3ed8bb5bb3 Mon Sep 17 00:00:00 2001
From: Piotr Fusik <p.fusik at samsung.com>
Date: Thu, 20 Nov 2025 10:41:42 +0100
Subject: [PATCH 1/3] [RISCV] Expand `X * (2^N - 2^M)` where `N < M`

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp |  4 ++++
 llvm/test/CodeGen/RISCV/mul.ll              | 14 +++++++++-----
 llvm/test/CodeGen/RISCV/rv64xtheadba.ll     |  5 +++--
 llvm/test/CodeGen/RISCV/rv64zba.ll          |  5 +++--
 llvm/test/CodeGen/RISCV/srem-vector-lkk.ll  | 21 +++++++++++----------
 5 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6020fb6ca16ce..68bf3db7b7f1f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16874,6 +16874,10 @@ static SDValue expandMulToAddOrSubOfShl(SDNode *N, SelectionDAG &DAG,
   } else if (CanSub) {
     Op = ISD::SUB;
     ShiftAmt1 = MulAmt + MulAmtLowBit;
+  } else if (isPowerOf2_64(MulAmtLowBit - MulAmt)) {
+    Op = ISD::SUB;
+    ShiftAmt1 = MulAmtLowBit;
+    MulAmtLowBit -= MulAmt;
   } else {
     return SDValue();
   }
diff --git a/llvm/test/CodeGen/RISCV/mul.ll b/llvm/test/CodeGen/RISCV/mul.ll
index 4533e14c672e7..940fe598fc0f0 100644
--- a/llvm/test/CodeGen/RISCV/mul.ll
+++ b/llvm/test/CodeGen/RISCV/mul.ll
@@ -1679,13 +1679,17 @@ define i128 @muli128_m3840(i128 %a) nounwind {
 ;
 ; RV64IM-LABEL: muli128_m3840:
 ; RV64IM:       # %bb.0:
+; RV64IM-NEXT:    slli a2, a1, 12
+; RV64IM-NEXT:    slli a1, a1, 8
+; RV64IM-NEXT:    sub a1, a1, a2
 ; RV64IM-NEXT:    li a2, -15
 ; RV64IM-NEXT:    slli a2, a2, 8
-; RV64IM-NEXT:    mul a1, a1, a2
-; RV64IM-NEXT:    mulhu a3, a0, a2
-; RV64IM-NEXT:    sub a3, a3, a0
-; RV64IM-NEXT:    add a1, a3, a1
-; RV64IM-NEXT:    mul a0, a0, a2
+; RV64IM-NEXT:    mulhu a2, a0, a2
+; RV64IM-NEXT:    sub a1, a0, a1
+; RV64IM-NEXT:    sub a1, a2, a1
+; RV64IM-NEXT:    slli a2, a0, 12
+; RV64IM-NEXT:    slli a0, a0, 8
+; RV64IM-NEXT:    sub a0, a0, a2
 ; RV64IM-NEXT:    ret
   %1 = mul i128 %a, -3840
   ret i128 %1
diff --git a/llvm/test/CodeGen/RISCV/rv64xtheadba.ll b/llvm/test/CodeGen/RISCV/rv64xtheadba.ll
index c57dfca1389b6..b4de214250b22 100644
--- a/llvm/test/CodeGen/RISCV/rv64xtheadba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64xtheadba.ll
@@ -2024,8 +2024,9 @@ define i64 @mul_neg5(i64 %a) {
 define i64 @mul_neg6(i64 %a) {
 ; CHECK-LABEL: mul_neg6:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    li a1, -6
-; CHECK-NEXT:    mul a0, a0, a1
+; CHECK-NEXT:    slli a1, a0, 3
+; CHECK-NEXT:    slli a0, a0, 1
+; CHECK-NEXT:    sub a0, a0, a1
 ; CHECK-NEXT:    ret
   %c = mul i64 %a, -6
   ret i64 %c
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
index fb26b8b16a290..73a886f6d7c60 100644
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -4157,8 +4157,9 @@ define i64 @mul_neg5(i64 %a) {
 define i64 @mul_neg6(i64 %a) {
 ; CHECK-LABEL: mul_neg6:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    li a1, -6
-; CHECK-NEXT:    mul a0, a0, a1
+; CHECK-NEXT:    slli a1, a0, 3
+; CHECK-NEXT:    slli a0, a0, 1
+; CHECK-NEXT:    sub a0, a0, a1
 ; CHECK-NEXT:    ret
   %c = mul i64 %a, -6
   ret i64 %c
diff --git a/llvm/test/CodeGen/RISCV/srem-vector-lkk.ll b/llvm/test/CodeGen/RISCV/srem-vector-lkk.ll
index 7548885f8405b..0a1946b5978f5 100644
--- a/llvm/test/CodeGen/RISCV/srem-vector-lkk.ll
+++ b/llvm/test/CodeGen/RISCV/srem-vector-lkk.ll
@@ -145,8 +145,8 @@ define <4 x i16> @fold_srem_vec_1(<4 x i16> %x) nounwind {
 ;
 ; RV64IM-LABEL: fold_srem_vec_1:
 ; RV64IM:       # %bb.0:
-; RV64IM-NEXT:    lh a3, 0(a1)
-; RV64IM-NEXT:    lh a2, 8(a1)
+; RV64IM-NEXT:    lh a2, 0(a1)
+; RV64IM-NEXT:    lh a3, 8(a1)
 ; RV64IM-NEXT:    lh a4, 16(a1)
 ; RV64IM-NEXT:    lh a1, 24(a1)
 ; RV64IM-NEXT:    lui a5, %hi(.LCPI0_0)
@@ -161,8 +161,8 @@ define <4 x i16> @fold_srem_vec_1(<4 x i16> %x) nounwind {
 ; RV64IM-NEXT:    mulh a6, a2, a6
 ; RV64IM-NEXT:    mulh a7, a4, a7
 ; RV64IM-NEXT:    mulh t0, a1, t0
-; RV64IM-NEXT:    add a5, a5, a3
-; RV64IM-NEXT:    sub a6, a6, a2
+; RV64IM-NEXT:    sub a5, a5, a3
+; RV64IM-NEXT:    add a6, a6, a2
 ; RV64IM-NEXT:    srli t1, a7, 63
 ; RV64IM-NEXT:    srli a7, a7, 5
 ; RV64IM-NEXT:    add a7, a7, t1
@@ -170,7 +170,7 @@ define <4 x i16> @fold_srem_vec_1(<4 x i16> %x) nounwind {
 ; RV64IM-NEXT:    srli t0, t0, 7
 ; RV64IM-NEXT:    add t0, t0, t1
 ; RV64IM-NEXT:    srli t1, a5, 63
-; RV64IM-NEXT:    srli a5, a5, 6
+; RV64IM-NEXT:    srai a5, a5, 6
 ; RV64IM-NEXT:    add a5, a5, t1
 ; RV64IM-NEXT:    srli t1, a6, 63
 ; RV64IM-NEXT:    srli a6, a6, 6
@@ -180,15 +180,16 @@ define <4 x i16> @fold_srem_vec_1(<4 x i16> %x) nounwind {
 ; RV64IM-NEXT:    li t1, -1003
 ; RV64IM-NEXT:    mul t0, t0, t1
 ; RV64IM-NEXT:    li t1, 95
-; RV64IM-NEXT:    mul a5, a5, t1
-; RV64IM-NEXT:    li t1, -124
 ; RV64IM-NEXT:    mul a6, a6, t1
 ; RV64IM-NEXT:    sub a4, a4, a7
 ; RV64IM-NEXT:    sub a1, a1, t0
-; RV64IM-NEXT:    sub a3, a3, a5
+; RV64IM-NEXT:    slli a7, a5, 2
+; RV64IM-NEXT:    slli a5, a5, 7
+; RV64IM-NEXT:    sub a5, a5, a7
 ; RV64IM-NEXT:    sub a2, a2, a6
-; RV64IM-NEXT:    sh a3, 0(a0)
-; RV64IM-NEXT:    sh a2, 2(a0)
+; RV64IM-NEXT:    add a3, a3, a5
+; RV64IM-NEXT:    sh a2, 0(a0)
+; RV64IM-NEXT:    sh a3, 2(a0)
 ; RV64IM-NEXT:    sh a4, 4(a0)
 ; RV64IM-NEXT:    sh a1, 6(a0)
 ; RV64IM-NEXT:    ret

>From 301db0fa701d29e5d64f16acdeb0525ce319744b Mon Sep 17 00:00:00 2001
From: Piotr Fusik <p.fusik at samsung.com>
Date: Thu, 20 Nov 2025 10:53:32 +0100
Subject: [PATCH 2/3] [RISCV][NFC] Rename variables

In the new N < M case, MulAmtLowBit ends up holding the higher bit (2^M), so the
name "LowBit" is misleading. ShiftAmt1 held an isolated bit value (a power of two
passed to Log2_64), not a shift amount.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 26 ++++++++++-----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 68bf3db7b7f1f..2741286ce73f5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16854,39 +16854,39 @@ static SDValue expandMulToNAFSequence(SDNode *N, SelectionDAG &DAG,
 // X * (2^N +/- 2^M) -> (add/sub (shl X, C1), (shl X, C2))
 static SDValue expandMulToAddOrSubOfShl(SDNode *N, SelectionDAG &DAG,
                                         uint64_t MulAmt) {
-  uint64_t MulAmtLowBit = MulAmt & (-MulAmt);
+  uint64_t MulAmtBit2 = MulAmt & (-MulAmt); // lowest set bit
   SDValue X = N->getOperand(0);
   ISD::NodeType Op;
-  uint64_t ShiftAmt1;
-  bool CanSub = isPowerOf2_64(MulAmt + MulAmtLowBit);
-  auto PreferSub = [X, MulAmtLowBit]() {
+  uint64_t MulAmtBit1;
+  bool CanSub = isPowerOf2_64(MulAmt + MulAmtBit2);
+  auto PreferSub = [X, MulAmtBit2]() {
     // For MulAmt == 3 << M both (X << M + 2) - (X << M)
     // and (X << M + 1) + (X << M) are valid expansions.
     // Prefer SUB if we can get (X << M + 2) for free,
     // because X is exact (Y >> M + 2).
-    uint64_t ShAmt = Log2_64(MulAmtLowBit) + 2;
+    uint64_t ShAmt = Log2_64(MulAmtBit2) + 2;
     using namespace SDPatternMatch;
     return sd_match(X, m_ExactSr(m_Value(), m_SpecificInt(ShAmt)));
   };
-  if (isPowerOf2_64(MulAmt - MulAmtLowBit) && !(CanSub && PreferSub())) {
+  if (isPowerOf2_64(MulAmt - MulAmtBit2) && !(CanSub && PreferSub())) {
     Op = ISD::ADD;
-    ShiftAmt1 = MulAmt - MulAmtLowBit;
+    MulAmtBit1 = MulAmt - MulAmtBit2;
   } else if (CanSub) {
     Op = ISD::SUB;
-    ShiftAmt1 = MulAmt + MulAmtLowBit;
-  } else if (isPowerOf2_64(MulAmtLowBit - MulAmt)) {
+    MulAmtBit1 = MulAmt + MulAmtBit2;
+  } else if (isPowerOf2_64(MulAmtBit2 - MulAmt)) {
     Op = ISD::SUB;
-    ShiftAmt1 = MulAmtLowBit;
-    MulAmtLowBit -= MulAmt;
+    MulAmtBit1 = MulAmtBit2;
+    MulAmtBit2 -= MulAmt;
   } else {
     return SDValue();
   }
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
   SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, X,
-                               DAG.getConstant(Log2_64(ShiftAmt1), DL, VT));
+                               DAG.getConstant(Log2_64(MulAmtBit1), DL, VT));
   SDValue Shift2 = DAG.getNode(ISD::SHL, DL, VT, X,
-                               DAG.getConstant(Log2_64(MulAmtLowBit), DL, VT));
+                               DAG.getConstant(Log2_64(MulAmtBit2), DL, VT));
   return DAG.getNode(Op, DL, VT, Shift1, Shift2);
 }
 

>From 003255ffd90f8b50ad83b49dd7e0a6b8e3da06dc Mon Sep 17 00:00:00 2001
From: Piotr Fusik <p.fusik at samsung.com>
Date: Thu, 20 Nov 2025 10:56:12 +0100
Subject: [PATCH 3/3] [RISCV][NFC] Comment the two SUB cases

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 2741286ce73f5..3a0755d6af31d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16872,9 +16872,11 @@ static SDValue expandMulToAddOrSubOfShl(SDNode *N, SelectionDAG &DAG,
     Op = ISD::ADD;
     MulAmtBit1 = MulAmt - MulAmtBit2;
   } else if (CanSub) {
+    // N > M
     Op = ISD::SUB;
     MulAmtBit1 = MulAmt + MulAmtBit2;
   } else if (isPowerOf2_64(MulAmtBit2 - MulAmt)) {
+    // N < M
     Op = ISD::SUB;
     MulAmtBit1 = MulAmtBit2;
     MulAmtBit2 -= MulAmt;



More information about the llvm-commits mailing list