[llvm] e6a72a1 - [RISCV] Combine ADDD+WMULSU to WMACCSU (#180454)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 9 08:51:32 PST 2026
Author: Craig Topper
Date: 2026-02-09T08:51:27-08:00
New Revision: e6a72a1d42cdc605dd3c9bc3e0ee8add930b19f9
URL: https://github.com/llvm/llvm-project/commit/e6a72a1d42cdc605dd3c9bc3e0ee8add930b19f9
DIFF: https://github.com/llvm/llvm-project/commit/e6a72a1d42cdc605dd3c9bc3e0ee8add930b19f9.diff
LOG: [RISCV] Combine ADDD+WMULSU to WMACCSU (#180454)
Extend the existing combineADDDToWMACC DAG combine to also match
RISCVISD::WMULSU and produce RISCVISD::WMACCSU. This is similar to
how ADDD+UMUL_LOHI is combined to WMACCU and ADDD+SMUL_LOHI is
combined to WMACC.
This patch was generated by AI, but I reviewed it.
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVInstrInfoP.td
llvm/test/CodeGen/RISCV/rv32p.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 04258cd804888..db65d6ac1a5df 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -1910,12 +1910,13 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
CurDAG->RemoveDeadNode(Node);
return;
}
+ case RISCVISD::WMACCSU:
case RISCVISD::WMACCU:
case RISCVISD::WMACC: {
assert(!Subtarget->is64Bit() && Subtarget->hasStdExtP() &&
"Unexpected opcode");
- // WMACCU/WMACC has 4 operands: (m1, m2, addlo, addhi) -> (lo, hi)
+ // WMACCU/WMACC/WMACCSU has 4 operands: (m1, m2, addlo, addhi) -> (lo, hi)
SDValue M1 = Node->getOperand(0);
SDValue M2 = Node->getOperand(1);
SDValue AddLo = Node->getOperand(2);
@@ -1930,8 +1931,20 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
MVT::Untyped, AccOps),
0);
- unsigned Opc =
- Node->getOpcode() == RISCVISD::WMACCU ? RISCV::WMACCU : RISCV::WMACC;
+ unsigned Opc;
+ switch (Node->getOpcode()) {
+ default:
+ llvm_unreachable("Unexpected WMACC opcode");
+ case RISCVISD::WMACCU:
+ Opc = RISCV::WMACCU;
+ break;
+ case RISCVISD::WMACC:
+ Opc = RISCV::WMACC;
+ break;
+ case RISCVISD::WMACCSU:
+ Opc = RISCV::WMACCSU;
+ break;
+ }
// Instruction format: WMACCU rd, rs1, rs2 (rd is accumulator, comes first)
MachineSDNode *New =
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 9b88bc5c39ce4..975baa7e2e504 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -21067,6 +21067,8 @@ static SDValue performSHLCombine(SDNode *N,
// (WMACCU x, y, a, b).
// Combine (ADDD (SMUL_LOHI x, y).0, (SMUL_LOHI x, y).1, a, b) into
// (WMACC x, y, a, b).
+// Combine (ADDD (WMULSU x, y).0, (WMULSU x, y).1, a, b) into
+// (WMACCSU x, y, a, b).
static SDValue combineADDDToWMACC(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(N->getOpcode() == RISCVISD::ADDD && "Expected ADDD");
@@ -21074,28 +21076,30 @@ static SDValue combineADDDToWMACC(SDNode *N, SelectionDAG &DAG,
"ADDD requires RV32 with P extension");
// ADDD has 4 operands: (op0_lo, op0_hi, op1_lo, op1_hi)
- // Try to match UMUL_LOHI or SMUL_LOHI in either operand pair due to
+ // Try to match UMUL_LOHI, SMUL_LOHI, or WMULSU in either operand pair due to
// commutativity
SDValue Op0Lo = N->getOperand(0);
SDValue Op0Hi = N->getOperand(1);
SDValue Op1Lo = N->getOperand(2);
SDValue Op1Hi = N->getOperand(3);
+ auto IsSupportedMul = [](unsigned Opc) {
+ return Opc == ISD::UMUL_LOHI || Opc == ISD::SMUL_LOHI ||
+ Opc == RISCVISD::WMULSU;
+ };
+
SDNode *MulNode = nullptr;
SDValue AddLo, AddHi;
- // Check if first operand pair is UMUL_LOHI or SMUL_LOHI
- if ((Op0Lo.getOpcode() == ISD::UMUL_LOHI ||
- Op0Lo.getOpcode() == ISD::SMUL_LOHI) &&
- Op0Lo.getNode() == Op0Hi.getNode() && Op0Lo.getResNo() == 0 &&
- Op0Hi.getResNo() == 1) {
+ // Check if first operand pair is a supported multiply
+ if (IsSupportedMul(Op0Lo.getOpcode()) && Op0Lo.getNode() == Op0Hi.getNode() &&
+ Op0Lo.getResNo() == 0 && Op0Hi.getResNo() == 1) {
MulNode = Op0Lo.getNode();
AddLo = Op1Lo;
AddHi = Op1Hi;
}
- // Check if second operand pair is UMUL_LOHI or SMUL_LOHI (commutative case)
- else if ((Op1Lo.getOpcode() == ISD::UMUL_LOHI ||
- Op1Lo.getOpcode() == ISD::SMUL_LOHI) &&
+ // Check if second operand pair is a supported multiply (commutative case)
+ else if (IsSupportedMul(Op1Lo.getOpcode()) &&
Op1Lo.getNode() == Op1Hi.getNode() && Op1Lo.getResNo() == 0 &&
Op1Hi.getResNo() == 1) {
MulNode = Op1Lo.getNode();
@@ -21113,10 +21117,22 @@ static SDValue combineADDDToWMACC(SDNode *N, SelectionDAG &DAG,
SDValue MulOp0 = MulNode->getOperand(0);
SDValue MulOp1 = MulNode->getOperand(1);
- // Create WMACCU or WMACC node: (m1, m2, addlo, addhi) -> (lo, hi)
+ // Create WMACCU, WMACC, or WMACCSU node: (m1, m2, addlo, addhi) -> (lo, hi)
SDLoc DL(N);
- bool IsSigned = MulNode->getOpcode() == ISD::SMUL_LOHI;
- unsigned Opc = IsSigned ? RISCVISD::WMACC : RISCVISD::WMACCU;
+ unsigned Opc;
+ switch (MulNode->getOpcode()) {
+ default:
+ llvm_unreachable("Unexpected multiply opcode");
+ case ISD::UMUL_LOHI:
+ Opc = RISCVISD::WMACCU;
+ break;
+ case ISD::SMUL_LOHI:
+ Opc = RISCVISD::WMACC;
+ break;
+ case RISCVISD::WMULSU:
+ Opc = RISCVISD::WMACCSU;
+ break;
+ }
return DAG.getNode(Opc, DL, DAG.getVTList(MVT::i32, MVT::i32), MulOp0, MulOp1,
AddLo, AddHi);
}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 560320465d1c9..774e1e024a4be 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1510,6 +1510,7 @@ def riscv_wmacc : RVSDNode<"WMACC", SDT_RISCVWideningMAccW,
[SDNPCommutative]>;
def riscv_wmaccu : RVSDNode<"WMACCU", SDT_RISCVWideningMAccW,
[SDNPCommutative]>;
+def riscv_wmaccsu : RVSDNode<"WMACCSU", SDT_RISCVWideningMAccW>;
// MULH/MULHU/MULHSU with rounding.
def riscv_mulhr : RVSDNode<"MULHR", SDTIntBinOp>;
diff --git a/llvm/test/CodeGen/RISCV/rv32p.ll b/llvm/test/CodeGen/RISCV/rv32p.ll
index 8c45d0dac5baf..1e31983df0b8c 100644
--- a/llvm/test/CodeGen/RISCV/rv32p.ll
+++ b/llvm/test/CodeGen/RISCV/rv32p.ll
@@ -678,6 +678,34 @@ define i64 @wmacc_commute(i32 %a, i32 %b, i64 %c) nounwind {
ret i64 %result
}
+define i64 @wmaccsu(i32 %a, i32 %b, i64 %c) nounwind {
+; CHECK-LABEL: wmaccsu:
+; CHECK: # %bb.0:
+; CHECK-NEXT: wmaccsu a2, a0, a1
+; CHECK-NEXT: mv a0, a2
+; CHECK-NEXT: mv a1, a3
+; CHECK-NEXT: ret
+ %aext = sext i32 %a to i64
+ %bext = zext i32 %b to i64
+ %mul = mul i64 %aext, %bext
+ %result = add i64 %c, %mul
+ ret i64 %result
+}
+
+define i64 @wmaccsu_commute(i32 %a, i32 %b, i64 %c) nounwind {
+; CHECK-LABEL: wmaccsu_commute:
+; CHECK: # %bb.0:
+; CHECK-NEXT: wmaccsu a2, a0, a1
+; CHECK-NEXT: mv a0, a2
+; CHECK-NEXT: mv a1, a3
+; CHECK-NEXT: ret
+ %aext = sext i32 %a to i64
+ %bext = zext i32 %b to i64
+ %mul = mul i64 %aext, %bext
+ %result = add i64 %mul, %c
+ ret i64 %result
+}
+
; Negative test: multiply result has multiple uses, should not combine
define void @wmaccu_multiple_uses(i32 %a, i32 %b, i64 %c, ptr %out1, ptr %out2) nounwind {
; CHECK-LABEL: wmaccu_multiple_uses:
More information about the llvm-commits mailing list