[llvm] [RISCV][P-ext] Custom legalize i64 SHL to WSLL(I)/WSLA(I) (PR #185079)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 6 10:47:19 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

<details>
<summary>Changes</summary>

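On RV32 with the P extension, custom legalize an i64 SHL to WSLL(I) when the shifted input is known to be zero-extended from 32 bits, or to WSLA(I) when it is known to be sign-extended from 32 bits. The immediate forms (WSLLI/WSLAI) are selected when the shift amount is a constant less than 64.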

---
Full diff: https://github.com/llvm/llvm-project/pull/185079.diff


4 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp (+29) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+25-5) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoP.td (+7) 
- (modified) llvm/test/CodeGen/RISCV/rv32p.ll (+40) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 3492e60662380..8e4913218c1df 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -1871,6 +1871,35 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
     CurDAG->RemoveDeadNode(Node);
     return;
   }
+  case RISCVISD::WSLL:
+  case RISCVISD::WSLA: {
+    // Custom select WSLL/WSLA to WSLL(I)/WSLA(I) for RV32P.
+    assert(Subtarget->hasStdExtP() && !Subtarget->is64Bit() && VT == MVT::i32 &&
+           "Unexpected opcode");
+
+    bool IsSigned = Node->getOpcode() == RISCVISD::WSLA;
+
+    SDValue ShAmt = Node->getOperand(1);
+
+    unsigned Opc;
+
+    auto *ShAmtC = dyn_cast<ConstantSDNode>(ShAmt);
+    if (ShAmtC && ShAmtC->getZExtValue() < 64) {
+      Opc = IsSigned ? RISCV::WSLAI : RISCV::WSLLI;
+      ShAmt = CurDAG->getTargetConstant(ShAmtC->getZExtValue(), DL, XLenVT);
+    } else {
+      Opc = IsSigned ? RISCV::WSLA : RISCV::WSLL;
+    }
+
+    SDNode *WShift = CurDAG->getMachineNode(Opc, DL, MVT::Untyped,
+                                            Node->getOperand(0), ShAmt);
+
+    auto [Lo, Hi] = extractGPRPair(CurDAG, DL, SDValue(WShift, 0));
+    ReplaceUses(SDValue(Node, 0), Lo);
+    ReplaceUses(SDValue(Node, 1), Hi);
+    CurDAG->RemoveDeadNode(Node);
+    return;
+  }
   case ISD::LOAD: {
     if (tryIndexedLoad(Node))
       return;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 891bc22a7463d..29ff12aa96efd 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15418,6 +15418,26 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
       assert(!Subtarget.is64Bit() && Subtarget.hasStdExtP() &&
              "Unexpected custom legalisation");
 
+      SDValue LHS = N->getOperand(0);
+      SDValue ShAmt = N->getOperand(1);
+
+      unsigned WideOpc = 0;
+      APInt HighMask = APInt::getHighBitsSet(64, 32);
+      if (DAG.MaskedValueIsZero(LHS, HighMask))
+        WideOpc = RISCVISD::WSLL;
+      else if (DAG.ComputeMaxSignificantBits(LHS) <= 32)
+        WideOpc = RISCVISD::WSLA;
+
+      if (WideOpc) {
+        SDValue Res =
+            DAG.getNode(WideOpc, DL, DAG.getVTList(MVT::i32, MVT::i32),
+                        DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, LHS),
+                        DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, ShAmt));
+        Results.push_back(DAG.getNode(ISD::BUILD_PAIR, DL, N->getValueType(0),
+                                      Res, Res.getValue(1)));
+        return;
+      }
+
       // Only handle constant shifts < 32. Non-constant shifts are handled by
       // lowerShiftLeftParts/lowerShiftRightParts, and shifts >= 32 use default
       // legalization.
@@ -15425,22 +15445,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
       if (!ShAmtC || ShAmtC->getZExtValue() >= 32)
         break;
 
-      auto [Lo, Hi] = DAG.SplitScalar(N->getOperand(0), DL, MVT::i32, MVT::i32);
+      auto [Lo, Hi] = DAG.SplitScalar(LHS, DL, MVT::i32, MVT::i32);
 
       SDValue LoRes, HiRes;
       if (N->getOpcode() == ISD::SHL) {
         // Lo = slli Lo, shamt
         // Hi = nsrli {Hi, Lo}, (32 - shamt)
         uint64_t ShAmtVal = ShAmtC->getZExtValue();
-        LoRes = DAG.getNode(ISD::SHL, DL, MVT::i32, Lo, N->getOperand(1));
+        LoRes = DAG.getNode(ISD::SHL, DL, MVT::i32, Lo, ShAmt);
         HiRes = DAG.getNode(RISCVISD::NSRL, DL, MVT::i32, Lo, Hi,
                             DAG.getConstant(32 - ShAmtVal, DL, MVT::i32));
       } else {
         bool IsSRA = N->getOpcode() == ISD::SRA;
         LoRes = DAG.getNode(IsSRA ? RISCVISD::NSRA : RISCVISD::NSRL, DL,
-                            MVT::i32, Lo, Hi, N->getOperand(1));
-        HiRes = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, MVT::i32, Hi,
-                            N->getOperand(1));
+                            MVT::i32, Lo, Hi, ShAmt);
+        HiRes =
+            DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, MVT::i32, Hi, ShAmt);
       }
       SDValue Res = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, LoRes, HiRes);
       Results.push_back(Res);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 7bb9ad5feb219..ea16cf28bfd7c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1633,6 +1633,13 @@ def riscv_wsubau : RVSDNode<"WSUBAU", SDT_RISCVWideningAddSubAccumulate>;
 
 def riscv_wmulsu : RVSDNode<"WMULSU", SDTIntBinHiLoOp>;
 
+def SDT_RISCVWideningShiftLeft : SDTypeProfile<2, 2, [SDTCisVT<0, i32>,
+                                                      SDTCisSameAs<0, 1>,
+                                                      SDTCisSameAs<0, 2>,
+                                                      SDTCisSameAs<0, 3>]>;
+def riscv_wsll : RVSDNode<"WSLL", SDT_RISCVWideningShiftLeft>;
+def riscv_wsla : RVSDNode<"WSLA", SDT_RISCVWideningShiftLeft>;
+
 // Narrowing shift: res = nsrl(lo, hi, shamt) is equivalent to
 // res = truncate (srl (build_pair lo, hi), shamt), XLenVT
 def SDT_RISCVNarrowingShift : SDTypeProfile<1, 3, [SDTCisVT<0, i32>,
diff --git a/llvm/test/CodeGen/RISCV/rv32p.ll b/llvm/test/CodeGen/RISCV/rv32p.ll
index cc00f427126ba..fdc7d98e5d833 100644
--- a/llvm/test/CodeGen/RISCV/rv32p.ll
+++ b/llvm/test/CodeGen/RISCV/rv32p.ll
@@ -781,6 +781,46 @@ define i64 @wmulsu_i32(i32 %x, i32 %y) {
   ret i64 %c
 }
 
+define i64 @wsla_i32(i32 %x, i64 %y) {
+; CHECK-LABEL: wsla_i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    wsla a0, a0, a1
+; CHECK-NEXT:    ret
+  %a = sext i32 %x to i64
+  %b = shl i64 %a, %y
+  ret i64 %b
+}
+
+define i64 @wsll_i32(i32 %x, i64 %y) {
+; CHECK-LABEL: wsll_i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    wsll a0, a0, a1
+; CHECK-NEXT:    ret
+  %a = zext i32 %x to i64
+  %b = shl i64 %a, %y
+  ret i64 %b
+}
+
+define i64 @wslai_i32(i32 %x) {
+; CHECK-LABEL: wslai_i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    wslai a0, a0, 23
+; CHECK-NEXT:    ret
+  %a = sext i32 %x to i64
+  %b = shl i64 %a, 23
+  ret i64 %b
+}
+
+define i64 @wslli_i32(i32 %x, i64 %y) {
+; CHECK-LABEL: wslli_i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    wslli a0, a0, 10
+; CHECK-NEXT:    ret
+  %a = zext i32 %x to i64
+  %b = shl i64 %a, 10
+  ret i64 %b
+}
+
 ; Test that mulh continues to be used with P.
 define i32 @mulh_i32(i32 %x, i32 %y) {
 ; CHECK-LABEL: mulh_i32:

``````````

</details>


https://github.com/llvm/llvm-project/pull/185079


More information about the llvm-commits mailing list