[llvm] [RISCV] Add DAG combine for forming VAADDU_VL from VP intrinsics. (PR #124848)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 29 09:01:03 PST 2025
================
@@ -16373,6 +16371,101 @@ static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
VPStore->isTruncatingStore(), VPStore->isCompressingStore());
}
+// Peephole avgceil pattern.
+// %1 = zext <N x i8> %a to <N x i32>
+// %2 = zext <N x i8> %b to <N x i32>
+// %3 = add nuw nsw <N x i32> %1, splat (i32 1)
+// %4 = add nuw nsw <N x i32> %3, %2
+// %5 = lshr <N x i32> %4, <i32 1 x N>
+// %6 = trunc <N x i32> %5 to <N x i8>
+static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ EVT VT = N->getValueType(0);
+
+ // Ignore fixed vectors.
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (!VT.isScalableVector() || !TLI.isTypeLegal(VT))
+ return SDValue();
+
+ SDValue In = N->getOperand(0);
+ SDValue Mask = N->getOperand(1);
+ SDValue VL = N->getOperand(2);
+
+ // Input should be a vp_srl with same mask and VL.
+ if (In.getOpcode() != ISD::VP_SRL || In.getOperand(2) != Mask ||
+ In.getOperand(3) != VL)
+ return SDValue();
+
+ // Shift amount should be 1.
+ if (!isOneOrOneSplat(In.getOperand(1)))
+ return SDValue();
+
+ // Shifted value should be a vp_add with same mask and VL.
+ SDValue LHS = In.getOperand(0);
+ if (LHS.getOpcode() != ISD::VP_ADD || LHS.getOperand(2) != Mask ||
+ LHS.getOperand(3) != VL)
+ return SDValue();
+
+ SDValue Operands[3];
+ Operands[0] = LHS.getOperand(0);
+ Operands[1] = LHS.getOperand(1);
+
+ // Matches another VP_ADD with same VL and Mask.
+ auto FindAdd = [&](SDValue V, SDValue &Op0, SDValue &Op1) {
+ if (V.getOpcode() != ISD::VP_ADD || V.getOperand(2) != Mask ||
+ V.getOperand(3) != VL)
+ return false;
+
+ Op0 = V.getOperand(0);
+ Op1 = V.getOperand(1);
+ return true;
+ };
+
+ // We need to find another VP_ADD in one of the operands.
+ SDValue Op0, Op1;
+ if (FindAdd(Operands[0], Op0, Op1))
+ Operands[0] = Operands[1];
+ else if (!FindAdd(Operands[1], Op0, Op1))
+ return SDValue();
+ Operands[2] = Op0;
+ Operands[1] = Op1;
+
+ // Now we have three operands of two additions. Check that one of them is a
+ // constant vector with ones.
+ auto I = llvm::find_if(Operands,
+ [](const SDValue &Op) { return isOneOrOneSplat(Op); });
+ if (I == std::end(Operands))
+ return SDValue();
+ // We found a vector with ones, move it to the end of the Operands array.
+ std::swap(Operands[I - std::begin(Operands)], Operands[2]);
+
+ // Make sure the other 2 operands can be promoted from the result type.
+ for (int i = 0; i < 2; ++i) {
+ if (Operands[i].getOpcode() != ISD::VP_ZERO_EXTEND ||
+ Operands[i].getOperand(1) != Mask || Operands[i].getOperand(2) != VL)
+ return SDValue();
+ // Input must be smaller than our result.
+ if (Operands[i].getOperand(0).getScalarValueSizeInBits() >
+ VT.getScalarSizeInBits())
+ return SDValue();
+ }
+
+ // Pattern is detected.
+ Op0 = Operands[0].getOperand(0);
+ Op1 = Operands[1].getOperand(0);
+ // Rebuild the zero extends if the inputs are smaller than our result.
+ if (Op0.getValueType() != VT)
+ Op0 =
+ DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[0]), VT, Op0, Mask, VL);
+ if (Op1.getValueType() != VT)
----------------
preames wrote:
Surely, we short circuit these on construction and can just unconditionally call the getNode API?
https://github.com/llvm/llvm-project/pull/124848
More information about the llvm-commits
mailing list