[llvm-branch-commits] [llvm] [RISCV] Combine vwaddu_wv+vabd(u) to vwabda(u) (PR #184603)

Pengcheng Wang via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Mar 9 19:38:45 PDT 2026


================
@@ -19146,6 +19147,88 @@ static SDValue performVWABDACombine(SDNode *N, SelectionDAG &DAG,
   return Result;
 }
 
+// vwaddu_wv C (vabd A B) -> vwabda(A B C)
+// vwaddu_wv C (zext (vabd A B)) -> vwabda((sext A) (sext B) C)
+// vwaddu_wv C (vabdu A B) -> vwabdau(A B C)
+// vwaddu_wv C (zext (vabdu A B)) -> vwabdau((zext A) (zext B) C)
+static SDValue performVWABDACombineWV(SDNode *N, SelectionDAG &DAG,
+                                      const RISCVSubtarget &Subtarget) {
+  if (!Subtarget.hasStdExtZvabd())
+    return SDValue();
+
+  MVT VT = N->getSimpleValueType(0);
+  // The result is widened, so we can accept i16/i32 here.
+  if (VT.getVectorElementType() != MVT::i16 &&
+      VT.getVectorElementType() != MVT::i32)
+    return SDValue();
+
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Passthru = N->getOperand(2);
+  if (!Passthru->isUndef())
+    return SDValue();
+
+  SDValue Mask = N->getOperand(3);
+  SDValue VL = N->getOperand(4);
+  unsigned ExtOpc = 0;
+  MVT ExtVT;
+  auto GetDiff = [&](SDValue Op) {
+    unsigned Opc = Op.getOpcode();
+    if (Opc == RISCVISD::VZEXT_VL) {
+      SDValue Src = Op->getOperand(0);
+      unsigned SrcOpc = Src.getOpcode();
+      switch (SrcOpc) {
+      default:
+        return SDValue();
+      case ISD::ABDS:
+      case RISCVISD::ABDS_VL:
+        ExtOpc = RISCVISD::VSEXT_VL;
+        break;
+      case ISD::ABDU:
+      case RISCVISD::ABDU_VL:
+        ExtOpc = RISCVISD::VZEXT_VL;
+        break;
+      }
+      ExtVT = Op->getSimpleValueType(0);
+      return Src;
+    }
+
+    if (Opc != ISD::ABDS && Opc != ISD::ABDU && Opc != RISCVISD::ABDS_VL &&
+        Opc != RISCVISD::ABDU_VL)
+      return SDValue();
+    return Op;
+  };
+
+  auto ExtractOps = [&](SDValue Op0,
+                        SDValue Op1) -> std::pair<SDValue, SDValue> {
+    SDValue Diff = GetDiff(Op0);
+    if (Diff)
+      return {Op1, Diff};
+    Diff = GetDiff(Op1);
+    if (Diff)
+      return {Op0, Diff};
+    return {};
+  };
+
+  auto [Acc, Diff] = ExtractOps(Op0, Op1);
+  if (!Diff)
+    return SDValue();
----------------
wangpc-pp wrote:

Done in https://github.com/llvm/llvm-project/pull/184603/commits/5111b034482743376fd40b535bbac3f2de3ca017.

https://github.com/llvm/llvm-project/pull/184603


More information about the llvm-branch-commits mailing list