[llvm] [RISCV] Add DAG combine for forming VAADDU_VL from VP intrinsics. (PR #124848)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 30 07:58:21 PST 2025


================
@@ -16373,6 +16371,92 @@ static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
       VPStore->isTruncatingStore(), VPStore->isCompressingStore());
 }
 
+// Peephole avgceil pattern.
+//   %1 = zext <N x i8> %a to <N x i32>
+//   %2 = zext <N x i8> %b to <N x i32>
+//   %3 = add nuw nsw <N x i32> %1, splat (i32 1)
+//   %4 = add nuw nsw <N x i32> %3, %2
+//   %5 = lshr <N x i32> %4, splat (i32 1)
+//   %6 = trunc <N x i32> %5 to <N x i8>
+static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
+                                         const RISCVSubtarget &Subtarget) {
+  EVT VT = N->getValueType(0);
+
+  // Ignore fixed vectors.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (!VT.isScalableVector() || !TLI.isTypeLegal(VT))
+    return SDValue();
+
+  SDValue In = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+  SDValue VL = N->getOperand(2);
+
+  // Input should be a vp_srl with same mask and VL.
+  if (In.getOpcode() != ISD::VP_SRL || In.getOperand(2) != Mask ||
+      In.getOperand(3) != VL)
+    return SDValue();
+
+  // Shift amount should be 1.
+  if (!isOneOrOneSplat(In.getOperand(1)))
+    return SDValue();
+
+  // Shifted value should be a vp_add with same mask and VL.
+  SDValue LHS = In.getOperand(0);
+  if (LHS.getOpcode() != ISD::VP_ADD || LHS.getOperand(2) != Mask ||
+      LHS.getOperand(3) != VL)
+    return SDValue();
+
+  SDValue Operands[3];
+
+  // Matches another VP_ADD with same VL and Mask.
+  auto FindAdd = [&](SDValue V, SDValue Other) {
+    if (V.getOpcode() != ISD::VP_ADD || V.getOperand(2) != Mask ||
+        V.getOperand(3) != VL)
+      return false;
+
+    Operands[0] = Other;
+    Operands[1] = V.getOperand(1);
+    Operands[2] = V.getOperand(0);
+    return true;
+  };
+
+  // We need to find another VP_ADD in one of the operands.
+  SDValue LHS0 = LHS.getOperand(0);
+  SDValue LHS1 = LHS.getOperand(1);
+  if (!FindAdd(LHS0, LHS1) && !FindAdd(LHS1, LHS0))
+    return SDValue();
+
+  // Now we have three operands of two additions. Check that one of them is a
+  // constant vector with ones.
+  auto I = llvm::find_if(Operands,
+                         [](const SDValue &Op) { return isOneOrOneSplat(Op); });
----------------
lukel97 wrote:

Can you pass the predicate directly?
```suggestion
  auto I = llvm::find_if(Operands, isOneOrOneSplat);
```

https://github.com/llvm/llvm-project/pull/124848


More information about the llvm-commits mailing list