[llvm] [RISCV] Add DAG combine for forming VAADDU_VL from VP intrinsics. (PR #124848)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 29 19:08:17 PST 2025


================
@@ -16373,6 +16371,92 @@ static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
       VPStore->isTruncatingStore(), VPStore->isCompressingStore());
 }
 
+// Peephole avgceil pattern.
+//   %1 = zext <N x i8> %a to <N x i32>
+//   %2 = zext <N x i8> %b to <N x i32>
+//   %3 = add nuw nsw <N x i32> %1, splat (i32 1)
+//   %4 = add nuw nsw <N x i32> %3, %2
+//   %5 = lshr <N x i32> %4, splat (i32 1)
+//   %6 = trunc <N x i32> %5 to <N x i8>
+static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
+                                         const RISCVSubtarget &Subtarget) {
+  EVT VT = N->getValueType(0);
+
+  // Ignore fixed vectors.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (!VT.isScalableVector() || !TLI.isTypeLegal(VT))
+    return SDValue();
+
+  SDValue In = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+  SDValue VL = N->getOperand(2);
+
+  // Input should be a vp_srl with same mask and VL.
+  if (In.getOpcode() != ISD::VP_SRL || In.getOperand(2) != Mask ||
+      In.getOperand(3) != VL)
+    return SDValue();
+
+  // Shift amount should be 1.
+  if (!isOneOrOneSplat(In.getOperand(1)))
+    return SDValue();
+
+  // Shifted value should be a vp_add with same mask and VL.
+  SDValue LHS = In.getOperand(0);
+  if (LHS.getOpcode() != ISD::VP_ADD || LHS.getOperand(2) != Mask ||
+      LHS.getOperand(3) != VL)
+    return SDValue();
+
+  SDValue Operands[3];
+
+  // Matches another VP_ADD with same VL and Mask.
+  auto FindAdd = [&](SDValue V, SDValue Other) {
+    if (V.getOpcode() != ISD::VP_ADD || V.getOperand(2) != Mask ||
+        V.getOperand(3) != VL)
+      return false;
+
+    Operands[0] = Other;
+    Operands[1] = V.getOperand(1);
+    Operands[2] = V.getOperand(0);
+    return true;
+  };
+
+  // We need to find another VP_ADD in one of the operands.
+  SDValue LHS0 = LHS.getOperand(0);
+  SDValue LHS1 = LHS.getOperand(1);
+  if (!FindAdd(LHS0, LHS1) && !FindAdd(LHS1, LHS0))
+    return SDValue();
+
+  // Now we have three operands of two additions. Check that one of them is a
+  // constant vector with ones.
+  auto I = llvm::find_if(Operands,
+                         [](const SDValue &Op) { return isOneOrOneSplat(Op); });
+  if (I == std::end(Operands))
+    return SDValue();
+  // We found a vector with ones, move it to the end of the Operands array.
+  std::swap(*I, Operands[2]);
+
+  // Make sure the other 2 operands can be promoted from the result type.
+  for (SDValue Op : drop_end(Operands)) {
+    if (Op.getOpcode() != ISD::VP_ZERO_EXTEND || Op.getOperand(1) != Mask ||
+        Op.getOperand(2) != VL)
+      return SDValue();
+    // Input must be smaller than our result.
----------------
preames wrote:

The comment says "smaller", but the code checks less than or equal. I think the comment just needs to be updated.

https://github.com/llvm/llvm-project/pull/124848


More information about the llvm-commits mailing list