[llvm] [RISCV] Add DAG combine for forming VAADDU_VL from VP intrinsics. (PR #124848)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 29 09:01:02 PST 2025
================
@@ -16373,6 +16371,101 @@ static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
VPStore->isTruncatingStore(), VPStore->isCompressingStore());
}
+// Peephole avgceil pattern.
+// %1 = zext <N x i8> %a to <N x i32>
+// %2 = zext <N x i8> %b to <N x i32>
+// %3 = add nuw nsw <N x i32> %1, splat (i32 1)
+// %4 = add nuw nsw <N x i32> %3, %2
+// %5 = lshr <N x i32> %4, splat (i32 1)
+// %6 = trunc <N x i32> %5 to <N x i8>
+static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ EVT VT = N->getValueType(0);
+
+ // Ignore fixed vectors.
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (!VT.isScalableVector() || !TLI.isTypeLegal(VT))
+ return SDValue();
+
+ SDValue In = N->getOperand(0);
+ SDValue Mask = N->getOperand(1);
+ SDValue VL = N->getOperand(2);
+
+ // Input should be a vp_srl with same mask and VL.
+ if (In.getOpcode() != ISD::VP_SRL || In.getOperand(2) != Mask ||
+ In.getOperand(3) != VL)
+ return SDValue();
+
+ // Shift amount should be 1.
+ if (!isOneOrOneSplat(In.getOperand(1)))
+ return SDValue();
+
+ // Shifted value should be a vp_add with same mask and VL.
+ SDValue LHS = In.getOperand(0);
+ if (LHS.getOpcode() != ISD::VP_ADD || LHS.getOperand(2) != Mask ||
+ LHS.getOperand(3) != VL)
+ return SDValue();
+
+ SDValue Operands[3];
+ Operands[0] = LHS.getOperand(0);
+ Operands[1] = LHS.getOperand(1);
+
+ // Matches another VP_ADD with same VL and Mask.
+ auto FindAdd = [&](SDValue V, SDValue &Op0, SDValue &Op1) {
+ if (V.getOpcode() != ISD::VP_ADD || V.getOperand(2) != Mask ||
+ V.getOperand(3) != VL)
+ return false;
+
+ Op0 = V.getOperand(0);
+ Op1 = V.getOperand(1);
+ return true;
+ };
+
+ // We need to find another VP_ADD in one of the operands.
+ SDValue Op0, Op1;
+ if (FindAdd(Operands[0], Op0, Op1))
+ Operands[0] = Operands[1];
+ else if (!FindAdd(Operands[1], Op0, Op1))
+ return SDValue();
+ Operands[2] = Op0;
+ Operands[1] = Op1;
+
+ // Now we have three operands of two additions. Check that one of them is a
+ // constant vector with ones.
+ auto I = llvm::find_if(Operands,
+ [](const SDValue &Op) { return isOneOrOneSplat(Op); });
+ if (I == std::end(Operands))
+ return SDValue();
+ // We found a vector with ones, move it to the end of the Operands array.
+ std::swap(Operands[I - std::begin(Operands)], Operands[2]);
+
+ // Make sure the other 2 operands can be promoted from the result type.
+ for (int i = 0; i < 2; ++i) {
+ if (Operands[i].getOpcode() != ISD::VP_ZERO_EXTEND ||
+ Operands[i].getOperand(1) != Mask || Operands[i].getOperand(2) != VL)
+ return SDValue();
+ // Input must be smaller than our result.
+ if (Operands[i].getOperand(0).getScalarValueSizeInBits() >
----------------
preames wrote:
Zero extend should always be less than or equal, right? If so, is this check needed? And don't you need to check for the degenerate equal-sized type case? Or do we simplify those at construction?
https://github.com/llvm/llvm-project/pull/124848
More information about the llvm-commits
mailing list