[llvm] [RISCV] Use FSHR in LowerShiftRightParts for P extension on RV64. (PR #181234)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 12 12:47:55 PST 2026


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/181234

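We can't do the NSRLI trick on RV64, but we can use srx, similar to what we do in LowerShiftLeftParts. We need an additional fixup step for the FSHR result that NSRLI doesn't need: FSHR masks the shift amount to 63, so when the shift amount is >= 64 the low result must be replaced with the high result.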

>From 63455c0c43ec82d0ccb41a8cdf50a15f628ca388 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 12 Feb 2026 11:58:51 -0800
Subject: [PATCH] [RISCV] Use FSHR in LowerShiftRightParts for P extension on
 RV64.

We can't do the NSRLI trick on RV64, but we can use srx similar
to what we do in LowerShiftLeftParts. We need an additional fixup
step for the FSHR result that NSRLI doesn't need.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 54 +++++++++++++++------
 llvm/test/CodeGen/RISCV/rv64p.ll            | 40 +++++----------
 2 files changed, 51 insertions(+), 43 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c12426458c3a5..ffd15483b246e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -10328,31 +10328,57 @@ SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
   SDValue Shamt = Op.getOperand(2);
   EVT VT = Lo.getValueType();
 
-  // With P extension on RV32, use NSRL/NSRA for the low part.
-  if (Subtarget.hasStdExtP() && !Subtarget.is64Bit()) {
-    SDValue LoRes = DAG.getNode(IsSRA ? RISCVISD::NSRA : RISCVISD::NSRL, DL, VT,
-                                Lo, Hi, Shamt);
-    // Mask shift amount to avoid UB when Shamt >= 32.
+  // With P extension, use NSRL/NSRA for RV32 or FSHR (SRX) for RV64.
+  if (Subtarget.hasStdExtP()) {
+    unsigned XLen = Subtarget.getXLen();
+
+    SDValue LoRes;
+    if (Subtarget.is64Bit()) {
+      // On RV64, use FSHR (SRX instruction) for the low part. We will need
+      // to fix this later if ShAmt >= 64.
+      LoRes = DAG.getNode(ISD::FSHR, DL, VT, Hi, Lo, Shamt);
+    } else {
+      // On RV32, use NSRL/NSRA for the low part.
+      // NSRL/NSRA read 6 bits of shift amount, so they handle Shamt >= 32
+      // correctly.
+      LoRes = DAG.getNode(IsSRA ? RISCVISD::NSRA : RISCVISD::NSRL, DL, VT, Lo,
+                          Hi, Shamt);
+    }
+
+    // Mask shift amount to avoid UB when Shamt >= XLen.
     SDValue ShamtMasked =
-        DAG.getNode(ISD::AND, DL, VT, Shamt, DAG.getConstant(31, DL, VT));
+        DAG.getNode(ISD::AND, DL, VT, Shamt, DAG.getConstant(XLen - 1, DL, VT));
     SDValue HiRes =
         DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, VT, Hi, ShamtMasked);
 
-    // Create a mask that is -1 when Shamt >= 32, 0 otherwise.
+    // Create a mask that is -1 when Shamt >= XLen, 0 otherwise.
     // FIXME: We should use a select and let LowerSelect make the
     // optimizations.
     SDValue ShAmtExt =
-        DAG.getNode(ISD::SHL, DL, VT, Shamt, DAG.getConstant(26, DL, VT));
-    SDValue Mask =
-        DAG.getNode(ISD::SRA, DL, VT, ShAmtExt, DAG.getConstant(31, DL, VT));
+        DAG.getNode(ISD::SHL, DL, VT, Shamt,
+                    DAG.getConstant(XLen - Log2_32(XLen) - 1, DL, VT));
+    SDValue Mask = DAG.getNode(ISD::SRA, DL, VT, ShAmtExt,
+                               DAG.getConstant(XLen - 1, DL, VT));
+
+    if (Subtarget.is64Bit()) {
+      // On RV64, FSHR masks shift amount to 63. We need to replace LoRes
+      // with HiRes when Shamt >= 64.
+      // LoRes = (LoRes & ~Mask) | (HiRes & Mask)
+      SDValue LoMasked =
+          DAG.getNode(ISD::AND, DL, VT, LoRes, DAG.getNOT(DL, Mask, VT));
+      SDValue HiMasked = DAG.getNode(ISD::AND, DL, VT, HiRes, Mask);
+      LoRes = DAG.getNode(ISD::OR, DL, VT, LoMasked, HiMasked,
+                          SDNodeFlags::Disjoint);
+    }
 
+    // If ShAmt >= XLen, we need to replace HiRes with 0 or sign bits.
     if (IsSRA) {
-      // sra hi, hi, (mask & 31) - shifts by 31 when shamt >= 32
-      SDValue MaskAmt =
-          DAG.getNode(ISD::AND, DL, VT, Mask, DAG.getConstant(31, DL, VT));
+      // sra hi, hi, (mask & (XLen-1)) - shifts by XLen-1 when shamt >= XLen
+      SDValue MaskAmt = DAG.getNode(ISD::AND, DL, VT, Mask,
+                                    DAG.getConstant(XLen - 1, DL, VT));
       HiRes = DAG.getNode(ISD::SRA, DL, VT, HiRes, MaskAmt);
     } else {
-      // andn hi, hi, mask - clears hi when shamt >= 32
+      // andn hi, hi, mask - clears hi when shamt >= XLen
       HiRes = DAG.getNode(ISD::AND, DL, VT, HiRes, DAG.getNOT(DL, Mask, VT));
     }
 
diff --git a/llvm/test/CodeGen/RISCV/rv64p.ll b/llvm/test/CodeGen/RISCV/rv64p.ll
index 670022a537e00..747a676b134fa 100644
--- a/llvm/test/CodeGen/RISCV/rv64p.ll
+++ b/llvm/test/CodeGen/RISCV/rv64p.ll
@@ -391,21 +391,12 @@ define i128 @slli_i128_large(i128 %x) {
 define i128 @srl_i128(i128 %x, i128 %y) {
 ; CHECK-LABEL: srl_i128:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a4, a2, -64
 ; CHECK-NEXT:    srl a3, a1, a2
-; CHECK-NEXT:    bltz a4, .LBB32_2
-; CHECK-NEXT:  # %bb.1:
-; CHECK-NEXT:    mv a0, a3
-; CHECK-NEXT:    j .LBB32_3
-; CHECK-NEXT:  .LBB32_2:
-; CHECK-NEXT:    srl a0, a0, a2
-; CHECK-NEXT:    not a2, a2
-; CHECK-NEXT:    slli a1, a1, 1
-; CHECK-NEXT:    sll a1, a1, a2
-; CHECK-NEXT:    or a0, a0, a1
-; CHECK-NEXT:  .LBB32_3:
-; CHECK-NEXT:    srai a1, a4, 63
-; CHECK-NEXT:    and a1, a1, a3
+; CHECK-NEXT:    srx a0, a1, a2
+; CHECK-NEXT:    slli a2, a2, 57
+; CHECK-NEXT:    srai a2, a2, 63
+; CHECK-NEXT:    mvm a0, a3, a2
+; CHECK-NEXT:    andn a1, a3, a2
 ; CHECK-NEXT:    ret
   %b = lshr i128 %x, %y
   ret i128 %b
@@ -461,21 +452,12 @@ define i128 @srli_i128_large(i128 %x) {
 define i128 @sra_i128(i128 %x, i128 %y) {
 ; CHECK-LABEL: sra_i128:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    mv a3, a1
-; CHECK-NEXT:    addi a4, a2, -64
-; CHECK-NEXT:    sra a1, a1, a2
-; CHECK-NEXT:    bltz a4, .LBB37_2
-; CHECK-NEXT:  # %bb.1:
-; CHECK-NEXT:    srai a3, a3, 63
-; CHECK-NEXT:    mv a0, a1
-; CHECK-NEXT:    mv a1, a3
-; CHECK-NEXT:    ret
-; CHECK-NEXT:  .LBB37_2:
-; CHECK-NEXT:    srl a0, a0, a2
-; CHECK-NEXT:    not a2, a2
-; CHECK-NEXT:    slli a3, a3, 1
-; CHECK-NEXT:    sll a2, a3, a2
-; CHECK-NEXT:    or a0, a0, a2
+; CHECK-NEXT:    sra a3, a1, a2
+; CHECK-NEXT:    srx a0, a1, a2
+; CHECK-NEXT:    slli a2, a2, 57
+; CHECK-NEXT:    srai a2, a2, 63
+; CHECK-NEXT:    mvm a0, a3, a2
+; CHECK-NEXT:    sra a1, a3, a2
 ; CHECK-NEXT:    ret
   %b = ashr i128 %x, %y
   ret i128 %b



More information about the llvm-commits mailing list