[llvm] [RISCV] Fold (WADDAU -C, -1, rs1, 0) -> (WSUBU rs1, C) where C > 0 (PR #186638)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Mar 14 20:27:56 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Craig Topper (topperc)
<details>
<summary>Changes</summary>
Stacked on #186635
---
Full diff: https://github.com/llvm/llvm-project/pull/186638.diff
4 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp (+11-4)
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+28)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoP.td (+3)
- (modified) llvm/test/CodeGen/RISCV/rv32p.ll (+37-5)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 76ecd4fccfd85..4a81f87153ff2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -1842,8 +1842,9 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
}
case ISD::SMUL_LOHI:
case ISD::UMUL_LOHI:
- case RISCVISD::WMULSU: {
- // Custom select (S/U)MUL_LOHI to WMUL(U) for RV32P.
+ case RISCVISD::WMULSU:
+ case RISCVISD::WADDU:
+ case RISCVISD::WSUBU: {
assert(Subtarget->hasStdExtP() && !Subtarget->is64Bit() && VT == MVT::i32 &&
"Unexpected opcode");
@@ -1860,12 +1861,18 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
case RISCVISD::WMULSU:
Opc = RISCV::WMULSU;
break;
+ case RISCVISD::WADDU:
+ Opc = RISCV::WADDU;
+ break;
+ case RISCVISD::WSUBU:
+ Opc = RISCV::WSUBU;
+ break;
}
- SDNode *WMUL = CurDAG->getMachineNode(
+ SDNode *Result = CurDAG->getMachineNode(
Opc, DL, MVT::Untyped, Node->getOperand(0), Node->getOperand(1));
- auto [Lo, Hi] = extractGPRPair(CurDAG, DL, SDValue(WMUL, 0));
+ auto [Lo, Hi] = extractGPRPair(CurDAG, DL, SDValue(Result, 0));
ReplaceUses(SDValue(Node, 0), Lo);
ReplaceUses(SDValue(Node, 1), Hi);
CurDAG->RemoveDeadNode(Node);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5156145e35aa2..eb253052a0b2a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -21613,6 +21613,27 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
SDValue Op1 = N->getOperand(2);
SDValue Op2 = N->getOperand(3);
+ // (WADDAU lo, 0, rs1, 0) -> (WADDU lo, rs1)
+ if (isNullConstant(Op0Hi) && isNullConstant(Op2)) {
+ SDValue Result = DAG.getNode(
+ RISCVISD::WADDU, DL, DAG.getVTList(MVT::i32, MVT::i32), Op0Lo, Op1);
+ return DCI.CombineTo(N, Result.getValue(0), Result.getValue(1));
+ }
+
+ // (WADDAU -C, -1, rs1, 0) -> (WSUBU rs1, C) where C > 0
+ if (isNullConstant(Op2) && isAllOnesConstant(Op0Hi)) {
+ if (auto *C0 = dyn_cast<ConstantSDNode>(Op0Lo)) {
+ int64_t Val = C0->getSExtValue();
+ if (Val < 0) {
+ SDValue PosConst = DAG.getConstant(-Val, DL, MVT::i32);
+ SDValue Result =
+ DAG.getNode(RISCVISD::WSUBU, DL,
+ DAG.getVTList(MVT::i32, MVT::i32), Op1, PosConst);
+ return DCI.CombineTo(N, Result.getValue(0), Result.getValue(1));
+ }
+ }
+ }
+
// FIXME: Canonicalize zero Op1 to Op2.
if (isNullConstant(Op2) && Op0Lo.getNode() == Op0Hi.getNode() &&
Op0Lo.getResNo() == 0 && Op0Hi.getResNo() == 1 && Op0Lo.hasOneUse() &&
@@ -21644,6 +21665,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
SDValue Op1 = N->getOperand(2);
SDValue Op2 = N->getOperand(3);
+ // (WSUBAU lo, 0, 0, rs2) -> (WSUBU lo, rs2)
+ if (isNullConstant(Op0Hi) && isNullConstant(Op1)) {
+ SDValue Result = DAG.getNode(
+ RISCVISD::WSUBU, DL, DAG.getVTList(MVT::i32, MVT::i32), Op0Lo, Op2);
+ return DCI.CombineTo(N, Result.getValue(0), Result.getValue(1));
+ }
+
// (WSUBAU (WADDAU lo, hi, a, 0), 0, b) -> (WSUBAU lo, hi, a, b)
if (isNullConstant(Op1) && Op0Lo.getOpcode() == RISCVISD::WADDAU &&
Op0Lo.getNode() == Op0Hi.getNode() && Op0Lo.getResNo() == 0 &&
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 23950c4478a1b..788ecda2a1df1 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1631,6 +1631,9 @@ def riscv_waddau : RVSDNode<"WADDAU", SDT_RISCVWideningAddSubAccumulate>;
// Widening sub accumulate unsigned: rd = rd + zext(rs1) - zext(rs2)
def riscv_wsubau : RVSDNode<"WSUBAU", SDT_RISCVWideningAddSubAccumulate>;
+// Widening add unsigned: rd = zext(rs1) + zext(rs2)
+def riscv_waddu : RVSDNode<"WADDU", SDTIntBinHiLoOp, [SDNPCommutative]>;
+// Widening sub unsigned: rd = zext(rs1) - zext(rs2)
+def riscv_wsubu : RVSDNode<"WSUBU", SDTIntBinHiLoOp>;
+
def riscv_wmulsu : RVSDNode<"WMULSU", SDTIntBinHiLoOp>;
def SDT_RISCVWideningShiftLeft : SDTypeProfile<2, 2, [SDTCisVT<0, i32>,
diff --git a/llvm/test/CodeGen/RISCV/rv32p.ll b/llvm/test/CodeGen/RISCV/rv32p.ll
index fdc7d98e5d833..4d009ef9ca76b 100644
--- a/llvm/test/CodeGen/RISCV/rv32p.ll
+++ b/llvm/test/CodeGen/RISCV/rv32p.ll
@@ -186,15 +186,14 @@ define i64 @cls_i64(i64 %x) {
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: xor a0, a0, a2
; CHECK-NEXT: clz a0, a0
-; CHECK-NEXT: addi a2, a0, 32
+; CHECK-NEXT: addi a0, a0, 32
; CHECK-NEXT: j .LBB15_3
; CHECK-NEXT: .LBB15_2:
; CHECK-NEXT: xor a1, a1, a2
-; CHECK-NEXT: clz a2, a1
+; CHECK-NEXT: clz a0, a1
; CHECK-NEXT: .LBB15_3:
-; CHECK-NEXT: li a0, -1
-; CHECK-NEXT: mv a1, a0
-; CHECK-NEXT: waddau a0, a2, zero
+; CHECK-NEXT: li a1, 1
+; CHECK-NEXT: wsubu a0, a0, a1
; CHECK-NEXT: ret
%a = ashr i64 %x, 63
%b = xor i64 %x, %a
@@ -1210,3 +1209,36 @@ define i64 @wsubau_zext_chain_rev(i64 %acc, i32 %a, i32 %b) nounwind {
%sum = add i64 %sub, %ext_b
ret i64 %sum
}
+
+; zext(a) + zext(b), with both operands zero-extended from i32, is expected to
+; select a single waddu.
+define i64 @waddu(i32 %a, i32 %b) nounwind {
+; CHECK-LABEL: waddu:
+; CHECK: # %bb.0:
+; CHECK-NEXT: waddu a0, a0, a1
+; CHECK-NEXT: ret
+ %ext_a = zext i32 %a to i64
+ %ext_b = zext i32 %b to i64
+ %sum = add i64 %ext_a, %ext_b
+ ret i64 %sum
+}
+
+; zext(a) - zext(b), with both operands zero-extended from i32, is expected to
+; select a single wsubu.
+define i64 @wsubu(i32 %a, i32 %b) nounwind {
+; CHECK-LABEL: wsubu:
+; CHECK: # %bb.0:
+; CHECK-NEXT: wsubu a0, a0, a1
+; CHECK-NEXT: ret
+ %ext_a = zext i32 %a to i64
+ %ext_b = zext i32 %b to i64
+ %diff = sub i64 %ext_a, %ext_b
+ ret i64 %diff
+}
+
+; zext(a) + (-42): the i64 constant has an all-ones high word and a negative
+; low word, so this is expected to become wsubu a, 42. The alternative would
+; materialize -1 as the high half of a waddau accumulator.
+; NOTE(review): boundary cases are not covered here. These are a low word of
+; INT32_MIN (i64 -2147483648, presumably wsubu with 0x80000000) and a zero low
+; word (i64 -4294967296). For the zero case the Val < 0 guard must keep the
+; fold from firing; confirm what that case actually lowers to.
+define i64 @wsub_from_neg_const(i32 %a) nounwind {
+; CHECK-LABEL: wsub_from_neg_const:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a1, 42
+; CHECK-NEXT: wsubu a0, a0, a1
+; CHECK-NEXT: ret
+ %ext_a = zext i32 %a to i64
+ %sum = add i64 %ext_a, -42
+ ret i64 %sum
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/186638
More information about the llvm-commits mailing list