[llvm] [RISCV][P-ext] Custom legalize i64 SHL to WSLL(I)/WSLA(I) (PR #185079)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 6 10:47:19 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Craig Topper (topperc)
<details>
<summary>Changes</summary>
Use the widening shifts when the i64 input is zero-extended or sign-extended from i32.
---
Full diff: https://github.com/llvm/llvm-project/pull/185079.diff
4 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp (+29)
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+25-5)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoP.td (+7)
- (modified) llvm/test/CodeGen/RISCV/rv32p.ll (+40)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 3492e60662380..8e4913218c1df 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -1871,6 +1871,35 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
CurDAG->RemoveDeadNode(Node);
return;
}
+ case RISCVISD::WSLL:
+ case RISCVISD::WSLA: {
+ // Custom select WSLL/WSLA to WSLL(I)/WSLA(I) for RV32P.
+ assert(Subtarget->hasStdExtP() && !Subtarget->is64Bit() && VT == MVT::i32 &&
+ "Unexpected opcode");
+
+ bool IsSigned = Node->getOpcode() == RISCVISD::WSLA;
+
+ SDValue ShAmt = Node->getOperand(1);
+
+ unsigned Opc;
+
+ auto *ShAmtC = dyn_cast<ConstantSDNode>(ShAmt);
+ if (ShAmtC && ShAmtC->getZExtValue() < 64) {
+ Opc = IsSigned ? RISCV::WSLAI : RISCV::WSLLI;
+ ShAmt = CurDAG->getTargetConstant(ShAmtC->getZExtValue(), DL, XLenVT);
+ } else {
+ Opc = IsSigned ? RISCV::WSLA : RISCV::WSLL;
+ }
+
+ SDNode *WShift = CurDAG->getMachineNode(Opc, DL, MVT::Untyped,
+ Node->getOperand(0), ShAmt);
+
+ auto [Lo, Hi] = extractGPRPair(CurDAG, DL, SDValue(WShift, 0));
+ ReplaceUses(SDValue(Node, 0), Lo);
+ ReplaceUses(SDValue(Node, 1), Hi);
+ CurDAG->RemoveDeadNode(Node);
+ return;
+ }
case ISD::LOAD: {
if (tryIndexedLoad(Node))
return;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 891bc22a7463d..29ff12aa96efd 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15418,6 +15418,28 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
assert(!Subtarget.is64Bit() && Subtarget.hasStdExtP() &&
"Unexpected custom legalisation");
+ SDValue LHS = N->getOperand(0);
+ SDValue ShAmt = N->getOperand(1);
+
+ // A left shift of an i64 that is zero/sign-extended from i32 is a single
+ // WSLL/WSLA. SRL/SRA must not take this path; they are handled below.
+ unsigned WideOpc = 0;
+ bool IsSHL = N->getOpcode() == ISD::SHL;
+ if (IsSHL && DAG.MaskedValueIsZero(LHS, APInt::getHighBitsSet(64, 32)))
+ WideOpc = RISCVISD::WSLL;
+ else if (IsSHL && DAG.ComputeMaxSignificantBits(LHS) <= 32)
+ WideOpc = RISCVISD::WSLA;
+
+ if (WideOpc) {
+ SDValue Res =
+ DAG.getNode(WideOpc, DL, DAG.getVTList(MVT::i32, MVT::i32),
+ DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, LHS),
+ DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, ShAmt));
+ Results.push_back(DAG.getNode(ISD::BUILD_PAIR, DL, N->getValueType(0),
+ Res, Res.getValue(1)));
+ return;
+ }
+
// Only handle constant shifts < 32. Non-constant shifts are handled by
// lowerShiftLeftParts/lowerShiftRightParts, and shifts >= 32 use default
// legalization.
@@ -15425,22 +15447,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
if (!ShAmtC || ShAmtC->getZExtValue() >= 32)
break;
- auto [Lo, Hi] = DAG.SplitScalar(N->getOperand(0), DL, MVT::i32, MVT::i32);
+ auto [Lo, Hi] = DAG.SplitScalar(LHS, DL, MVT::i32, MVT::i32);
SDValue LoRes, HiRes;
if (N->getOpcode() == ISD::SHL) {
// Lo = slli Lo, shamt
// Hi = nsrli {Hi, Lo}, (32 - shamt)
uint64_t ShAmtVal = ShAmtC->getZExtValue();
- LoRes = DAG.getNode(ISD::SHL, DL, MVT::i32, Lo, N->getOperand(1));
+ LoRes = DAG.getNode(ISD::SHL, DL, MVT::i32, Lo, ShAmt);
HiRes = DAG.getNode(RISCVISD::NSRL, DL, MVT::i32, Lo, Hi,
DAG.getConstant(32 - ShAmtVal, DL, MVT::i32));
} else {
bool IsSRA = N->getOpcode() == ISD::SRA;
LoRes = DAG.getNode(IsSRA ? RISCVISD::NSRA : RISCVISD::NSRL, DL,
- MVT::i32, Lo, Hi, N->getOperand(1));
- HiRes = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, MVT::i32, Hi,
- N->getOperand(1));
+ MVT::i32, Lo, Hi, ShAmt);
+ HiRes =
+ DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, MVT::i32, Hi, ShAmt);
}
SDValue Res = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, LoRes, HiRes);
Results.push_back(Res);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 7bb9ad5feb219..ea16cf28bfd7c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1633,6 +1633,15 @@ def riscv_wsubau : RVSDNode<"WSUBAU", SDT_RISCVWideningAddSubAccumulate>;
def riscv_wmulsu : RVSDNode<"WMULSU", SDTIntBinHiLoOp>;
+// Widening shift: (lo, hi) = wsll(x, shamt) is equivalent to
+// (lo, hi) = split (shl (zext x to i64), shamt); wsla uses sext instead.
+def SDT_RISCVWideningShiftLeft : SDTypeProfile<2, 2, [SDTCisVT<0, i32>,
+ SDTCisSameAs<0, 1>,
+ SDTCisSameAs<0, 2>,
+ SDTCisSameAs<0, 3>]>;
+def riscv_wsll : RVSDNode<"WSLL", SDT_RISCVWideningShiftLeft>;
+def riscv_wsla : RVSDNode<"WSLA", SDT_RISCVWideningShiftLeft>;
+
// Narrowing shift: res = nsrl(lo, hi, shamt) is equivalent to
// res = truncate (srl (build_pair lo, hi), shamt), XLenVT
def SDT_RISCVNarrowingShift : SDTypeProfile<1, 3, [SDTCisVT<0, i32>,
diff --git a/llvm/test/CodeGen/RISCV/rv32p.ll b/llvm/test/CodeGen/RISCV/rv32p.ll
index cc00f427126ba..fdc7d98e5d833 100644
--- a/llvm/test/CodeGen/RISCV/rv32p.ll
+++ b/llvm/test/CodeGen/RISCV/rv32p.ll
@@ -781,6 +781,46 @@ define i64 @wmulsu_i32(i32 %x, i32 %y) {
ret i64 %c
}
+define i64 @wsla_i32(i32 %x, i64 %y) {
+; CHECK-LABEL: wsla_i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: wsla a0, a0, a1
+; CHECK-NEXT: ret
+ %a = sext i32 %x to i64
+ %b = shl i64 %a, %y
+ ret i64 %b
+}
+
+define i64 @wsll_i32(i32 %x, i64 %y) {
+; CHECK-LABEL: wsll_i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: wsll a0, a0, a1
+; CHECK-NEXT: ret
+ %a = zext i32 %x to i64
+ %b = shl i64 %a, %y
+ ret i64 %b
+}
+
+define i64 @wslai_i32(i32 %x) {
+; CHECK-LABEL: wslai_i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: wslai a0, a0, 23
+; CHECK-NEXT: ret
+ %a = sext i32 %x to i64
+ %b = shl i64 %a, 23
+ ret i64 %b
+}
+
+define i64 @wslli_i32(i32 %x) {
+; CHECK-LABEL: wslli_i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: wslli a0, a0, 10
+; CHECK-NEXT: ret
+ %a = zext i32 %x to i64
+ %b = shl i64 %a, 10
+ ret i64 %b
+}
+
; Test that mulh continues to be used with P.
define i32 @mulh_i32(i32 %x, i32 %y) {
; CHECK-LABEL: mulh_i32:
``````````
</details>
https://github.com/llvm/llvm-project/pull/185079
More information about the llvm-commits
mailing list