[llvm] e6a72a1 - [RISCV] Combine ADDD+WMULSU to WMACCSU (#180454)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 9 08:51:32 PST 2026


Author: Craig Topper
Date: 2026-02-09T08:51:27-08:00
New Revision: e6a72a1d42cdc605dd3c9bc3e0ee8add930b19f9

URL: https://github.com/llvm/llvm-project/commit/e6a72a1d42cdc605dd3c9bc3e0ee8add930b19f9
DIFF: https://github.com/llvm/llvm-project/commit/e6a72a1d42cdc605dd3c9bc3e0ee8add930b19f9.diff

LOG: [RISCV] Combine ADDD+WMULSU to WMACCSU (#180454)

Extend the existing combineADDDToWMACC DAG combine to also match
RISCVISD::WMULSU and produce RISCVISD::WMACCSU. This is similar to
how ADDD+UMUL_LOHI is combined to WMACCU and ADDD+SMUL_LOHI is
combined to WMACC.

This patch was generated by AI, but I reviewed it.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfoP.td
    llvm/test/CodeGen/RISCV/rv32p.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 04258cd804888..db65d6ac1a5df 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -1910,12 +1910,13 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
     CurDAG->RemoveDeadNode(Node);
     return;
   }
+  case RISCVISD::WMACCSU:
   case RISCVISD::WMACCU:
   case RISCVISD::WMACC: {
     assert(!Subtarget->is64Bit() && Subtarget->hasStdExtP() &&
            "Unexpected opcode");
 
-    // WMACCU/WMACC has 4 operands: (m1, m2, addlo, addhi) -> (lo, hi)
+    // WMACCU/WMACC/WMACCSU has 4 operands: (m1, m2, addlo, addhi) -> (lo, hi)
     SDValue M1 = Node->getOperand(0);
     SDValue M2 = Node->getOperand(1);
     SDValue AddLo = Node->getOperand(2);
@@ -1930,8 +1931,20 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
                                                  MVT::Untyped, AccOps),
                           0);
 
-    unsigned Opc =
-        Node->getOpcode() == RISCVISD::WMACCU ? RISCV::WMACCU : RISCV::WMACC;
+    unsigned Opc;
+    switch (Node->getOpcode()) {
+    default:
+      llvm_unreachable("Unexpected WMACC opcode");
+    case RISCVISD::WMACCU:
+      Opc = RISCV::WMACCU;
+      break;
+    case RISCVISD::WMACC:
+      Opc = RISCV::WMACC;
+      break;
+    case RISCVISD::WMACCSU:
+      Opc = RISCV::WMACCSU;
+      break;
+    }
 
     // Instruction format: WMACCU rd, rs1, rs2 (rd is accumulator, comes first)
     MachineSDNode *New =

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 9b88bc5c39ce4..975baa7e2e504 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -21067,6 +21067,8 @@ static SDValue performSHLCombine(SDNode *N,
 // (WMACCU x, y, a, b).
 // Combine (ADDD (SMUL_LOHI x, y).0, (SMUL_LOHI x, y).1, a, b) into
 // (WMACC x, y, a, b).
+// Combine (ADDD (WMULSU x, y).0, (WMULSU x, y).1, a, b) into
+// (WMACCSU x, y, a, b).
 static SDValue combineADDDToWMACC(SDNode *N, SelectionDAG &DAG,
                                   const RISCVSubtarget &Subtarget) {
   assert(N->getOpcode() == RISCVISD::ADDD && "Expected ADDD");
@@ -21074,28 +21076,30 @@ static SDValue combineADDDToWMACC(SDNode *N, SelectionDAG &DAG,
          "ADDD requires RV32 with P extension");
 
   // ADDD has 4 operands: (op0_lo, op0_hi, op1_lo, op1_hi)
-  // Try to match UMUL_LOHI or SMUL_LOHI in either operand pair due to
+  // Try to match UMUL_LOHI, SMUL_LOHI, or WMULSU in either operand pair due to
   // commutativity
   SDValue Op0Lo = N->getOperand(0);
   SDValue Op0Hi = N->getOperand(1);
   SDValue Op1Lo = N->getOperand(2);
   SDValue Op1Hi = N->getOperand(3);
 
+  auto IsSupportedMul = [](unsigned Opc) {
+    return Opc == ISD::UMUL_LOHI || Opc == ISD::SMUL_LOHI ||
+           Opc == RISCVISD::WMULSU;
+  };
+
   SDNode *MulNode = nullptr;
   SDValue AddLo, AddHi;
 
-  // Check if first operand pair is UMUL_LOHI or SMUL_LOHI
-  if ((Op0Lo.getOpcode() == ISD::UMUL_LOHI ||
-       Op0Lo.getOpcode() == ISD::SMUL_LOHI) &&
-      Op0Lo.getNode() == Op0Hi.getNode() && Op0Lo.getResNo() == 0 &&
-      Op0Hi.getResNo() == 1) {
+  // Check if first operand pair is a supported multiply
+  if (IsSupportedMul(Op0Lo.getOpcode()) && Op0Lo.getNode() == Op0Hi.getNode() &&
+      Op0Lo.getResNo() == 0 && Op0Hi.getResNo() == 1) {
     MulNode = Op0Lo.getNode();
     AddLo = Op1Lo;
     AddHi = Op1Hi;
   }
-  // Check if second operand pair is UMUL_LOHI or SMUL_LOHI (commutative case)
-  else if ((Op1Lo.getOpcode() == ISD::UMUL_LOHI ||
-            Op1Lo.getOpcode() == ISD::SMUL_LOHI) &&
+  // Check if second operand pair is a supported multiply (commutative case)
+  else if (IsSupportedMul(Op1Lo.getOpcode()) &&
            Op1Lo.getNode() == Op1Hi.getNode() && Op1Lo.getResNo() == 0 &&
            Op1Hi.getResNo() == 1) {
     MulNode = Op1Lo.getNode();
@@ -21113,10 +21117,22 @@ static SDValue combineADDDToWMACC(SDNode *N, SelectionDAG &DAG,
   SDValue MulOp0 = MulNode->getOperand(0);
   SDValue MulOp1 = MulNode->getOperand(1);
 
-  // Create WMACCU or WMACC node: (m1, m2, addlo, addhi) -> (lo, hi)
+  // Create WMACCU, WMACC, or WMACCSU node: (m1, m2, addlo, addhi) -> (lo, hi)
   SDLoc DL(N);
-  bool IsSigned = MulNode->getOpcode() == ISD::SMUL_LOHI;
-  unsigned Opc = IsSigned ? RISCVISD::WMACC : RISCVISD::WMACCU;
+  unsigned Opc;
+  switch (MulNode->getOpcode()) {
+  default:
+    llvm_unreachable("Unexpected multiply opcode");
+  case ISD::UMUL_LOHI:
+    Opc = RISCVISD::WMACCU;
+    break;
+  case ISD::SMUL_LOHI:
+    Opc = RISCVISD::WMACC;
+    break;
+  case RISCVISD::WMULSU:
+    Opc = RISCVISD::WMACCSU;
+    break;
+  }
   return DAG.getNode(Opc, DL, DAG.getVTList(MVT::i32, MVT::i32), MulOp0, MulOp1,
                      AddLo, AddHi);
 }

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 560320465d1c9..774e1e024a4be 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1510,6 +1510,7 @@ def riscv_wmacc  : RVSDNode<"WMACC", SDT_RISCVWideningMAccW,
                             [SDNPCommutative]>;
 def riscv_wmaccu : RVSDNode<"WMACCU", SDT_RISCVWideningMAccW,
                             [SDNPCommutative]>;
+def riscv_wmaccsu : RVSDNode<"WMACCSU", SDT_RISCVWideningMAccW>;
 
 // MULH/MULHU/MULHSU with rounding.
 def riscv_mulhr : RVSDNode<"MULHR", SDTIntBinOp>;

diff  --git a/llvm/test/CodeGen/RISCV/rv32p.ll b/llvm/test/CodeGen/RISCV/rv32p.ll
index 8c45d0dac5baf..1e31983df0b8c 100644
--- a/llvm/test/CodeGen/RISCV/rv32p.ll
+++ b/llvm/test/CodeGen/RISCV/rv32p.ll
@@ -678,6 +678,34 @@ define i64 @wmacc_commute(i32 %a, i32 %b, i64 %c) nounwind {
   ret i64 %result
 }
 
+define i64 @wmaccsu(i32 %a, i32 %b, i64 %c) nounwind {
+; CHECK-LABEL: wmaccsu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    wmaccsu a2, a0, a1
+; CHECK-NEXT:    mv a0, a2
+; CHECK-NEXT:    mv a1, a3
+; CHECK-NEXT:    ret
+  %aext = sext i32 %a to i64
+  %bext = zext i32 %b to i64
+  %mul = mul i64 %aext, %bext
+  %result = add i64 %c, %mul
+  ret i64 %result
+}
+
+define i64 @wmaccsu_commute(i32 %a, i32 %b, i64 %c) nounwind {
+; CHECK-LABEL: wmaccsu_commute:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    wmaccsu a2, a0, a1
+; CHECK-NEXT:    mv a0, a2
+; CHECK-NEXT:    mv a1, a3
+; CHECK-NEXT:    ret
+  %aext = sext i32 %a to i64
+  %bext = zext i32 %b to i64
+  %mul = mul i64 %aext, %bext
+  %result = add i64 %mul, %c
+  ret i64 %result
+}
+
 ; Negative test: multiply result has multiple uses, should not combine
 define void @wmaccu_multiple_uses(i32 %a, i32 %b, i64 %c, ptr %out1, ptr %out2) nounwind {
 ; CHECK-LABEL: wmaccu_multiple_uses:


        


More information about the llvm-commits mailing list