[llvm] [RISCV] Optimize divide by constant for VP intrinsics (PR #125991)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 25 14:50:43 PST 2025
================
@@ -27219,6 +27278,233 @@ SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitVPUDIV(SDNode *N) { // DAG combine entry point for vp.udiv nodes.
+ SDValue N0 = N->getOperand(0); // dividend
+ SDValue N1 = N->getOperand(1); // divisor
+ SDValue Mask = N->getOperand(2);
+ SDValue VL = N->getOperand(3);
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+ // Returns the replacement value, or an empty SDValue when no fold applies.
+ ConstantSDNode *N1C = isConstOrConstSplat(N1);
+ // fold (vp.udiv X, -1) -> vp.select(X == -1, 1, 0)
+ if (N1C && N1C->isAllOnes()) {
+ EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+ VT.getVectorElementCount());
+ return DAG.getNode(ISD::VP_SELECT, DL, VT,
+ DAG.getSetCCVP(DL, CCVT, N0, N1, ISD::SETEQ, Mask, VL),
+ DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT),
+ VL);
+ }
+
+ if (SDValue V = visitVPUDIVLike(N0, N1, N)) {
+ // If a VP_UREM node with the same operands, mask and VL exists, update its
+ // users with Dividend - (Quotient * Divisor), reusing the new quotient V.
+ if (SDNode *RemNode = DAG.getNodeIfExists(ISD::VP_UREM, N->getVTList(),
+ {N0, N1, Mask, VL})) {
+ SDValue Mul = DAG.getNode(ISD::VP_MUL, DL, VT, V, N1, Mask, VL);
+ SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, N0, Mul, Mask, VL);
+ AddToWorklist(Mul.getNode());
+ AddToWorklist(Sub.getNode());
+ CombineTo(RemNode, Sub);
+ }
+ return V;
+ }
+
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitVPUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
+ SDLoc DL(N);
+ SDValue Mask = N->getOperand(2);
+ SDValue VL = N->getOperand(3);
+ EVT VT = N->getValueType(0);
+ // Folds an unsigned VP divide of N0 by N1; mask, VL and result type come from N.
+ // fold (vp.udiv x, (1 << c)) -> vp.lshr(x, c)
+ if (isConstantOrConstantVector(N1, /*NoOpaques=*/true) &&
+ DAG.isKnownToBeAPowerOfTwo(N1)) {
+ SDValue LogBase2 = BuildLogBase2(N1, DL);
+ AddToWorklist(LogBase2.getNode());
+
+ EVT ShiftVT = getShiftAmountTy(N0.getValueType());
+ SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
+ AddToWorklist(Trunc.getNode());
+ return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Trunc, Mask, VL);
+ }
+
+ // fold (vp.udiv x, (vp.shl c, y)) -> vp.lshr(x, vp.add(y, log2(c))) iff c is a
+ // power of 2 and the vp.shl uses the same mask and VL as the division
+ if (N1.getOpcode() == ISD::VP_SHL && N1->getOperand(2) == Mask &&
+ N1->getOperand(3) == VL) {
+ SDValue N10 = N1.getOperand(0);
+ if (isConstantOrConstantVector(N10, /*NoOpaques=*/true) &&
+ DAG.isKnownToBeAPowerOfTwo(N10)) {
+ SDValue LogBase2 = BuildLogBase2(N10, DL);
+ AddToWorklist(LogBase2.getNode());
+
+ EVT ADDVT = N1.getOperand(1).getValueType(); // the add is done in y's type
+ SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
+ AddToWorklist(Trunc.getNode());
+ SDValue Add = DAG.getNode(ISD::VP_ADD, DL, ADDVT, N1.getOperand(1), Trunc,
+ Mask, VL);
+ AddToWorklist(Add.getNode());
+ return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Add, Mask, VL);
+ }
+ }
+
+ // fold (vp.udiv x, Splat(shl c, y)) -> vp.lshr(x, Splat(add(y, log2(c)))) iff c
+ // is a scalar constant power of 2
+ if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
+ SDValue N10 = N1.getOperand(0);
+ if (N10.getOpcode() == ISD::SHL) {
+ SDValue N0SHL = N10.getOperand(0);
+ if (isa<ConstantSDNode>(N0SHL) && DAG.isKnownToBeAPowerOfTwo(N0SHL)) {
+ SDValue LogBase2 = BuildLogBase2(N0SHL, DL);
+ AddToWorklist(LogBase2.getNode());
+
+ EVT ADDVT = N10.getOperand(1).getValueType(); // scalar add in y's type
+ SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
+ AddToWorklist(Trunc.getNode());
+ SDValue Add =
+ DAG.getNode(ISD::ADD, DL, ADDVT, N10.getOperand(1), Trunc);
+ AddToWorklist(Add.getNode());
+ SDValue Splat = DAG.getSplatVector(VT, DL, Add);
+ AddToWorklist(Splat.getNode());
+ return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Splat, Mask, VL);
+ }
+ }
+ }
+
+ // fold (vp.udiv x, c) -> result of BuildUDIV, skipped when TLI.isIntDivCheap returns true for the result type
+ AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+ if (isConstantOrConstantVector(N1) &&
+ !TLI.isIntDivCheap(N->getValueType(0), Attr)) {
+ if (SDValue Op = BuildUDIV(N))
+ return Op;
+ }
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitVPSDIV(SDNode *N) { // DAG combine entry point for vp.sdiv nodes.
+ SDValue N0 = N->getOperand(0); // dividend
+ SDValue N1 = N->getOperand(1); // divisor
+ SDValue Mask = N->getOperand(2);
+ SDValue VL = N->getOperand(3);
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+ // Returns the replacement value, or an empty SDValue when no fold applies.
+ // fold (vp.sdiv X, -1) -> vp.sub(0, X)
+ ConstantSDNode *N1C = isConstOrConstSplat(N1);
+ if (N1C && N1C->isAllOnes())
+ return DAG.getNode(ISD::VP_SUB, DL, VT, DAG.getConstant(0, DL, VT), N0,
+ Mask, VL);
+
+ // fold (vp.sdiv X, MIN_SIGNED) -> vp.select(X == MIN_SIGNED, 1, 0)
+ if (N1C && N1C->getAPIntValue().isMinSignedValue()) {
+ EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+ VT.getVectorElementCount());
+ return DAG.getNode(ISD::VP_SELECT, DL, VT,
+ DAG.getSetCCVP(DL, CCVT, N0, N1, ISD::SETEQ, Mask, VL),
+ DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT),
+ VL);
+ }
+
+ // If we know the sign bits of both operands are zero, strength reduce to a
+ // vp.udiv instead. Handles (X&15) /s 4 -> (X&15) >> 2
+ if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
+ return DAG.getNode(ISD::VP_UDIV, DL, N1.getValueType(), N0, N1, Mask, VL);
+
+ if (SDValue V = visitVPSDIVLike(N0, N1, N)) {
+ // If a VP_SREM node with the same operands, mask and VL exists, update its
+ // users with Dividend - (Quotient * Divisor), reusing the new quotient V.
+ if (SDNode *RemNode = DAG.getNodeIfExists(ISD::VP_SREM, N->getVTList(),
+ {N0, N1, Mask, VL})) {
+ SDValue Mul = DAG.getNode(ISD::VP_MUL, DL, VT, V, N1, Mask, VL);
+ SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, N0, Mul, Mask, VL);
+ AddToWorklist(Mul.getNode());
+ AddToWorklist(Sub.getNode());
+ CombineTo(RemNode, Sub);
+ }
+ return V;
+ }
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitVPSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
+ SDLoc DL(N);
+ SDValue Mask = N->getOperand(2);
+ SDValue VL = N->getOperand(3);
+ EVT VT = N->getValueType(0);
+ unsigned BitWidth = VT.getScalarSizeInBits();
+
+ // fold (vp.sdiv X, V of pow 2)
+ if (N1.getOpcode() == ISD::SPLAT_VECTOR &&
+ isDivisorPowerOfTwo(N1.getOperand(0))) {
+ // Create constants that are functions of the shift amount value.
+ SDValue N = N1.getOperand(0);
+ EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+ VT.getVectorElementCount());
+ EVT ScalarShiftAmtTy =
+ getShiftAmountTy(N0.getValueType().getVectorElementType());
+ SDValue Bits = DAG.getConstant(BitWidth, DL, ScalarShiftAmtTy);
+ SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT.getVectorElementType(), N);
+ C1 = DAG.getZExtOrTrunc(C1, DL, ScalarShiftAmtTy);
+ SDValue Inexact = DAG.getNode(ISD::SUB, DL, ScalarShiftAmtTy, Bits, C1);
+ if (!isa<ConstantSDNode>(Inexact))
+ return SDValue();
+
+ // Splat the sign bit into the register
+ EVT VecShiftAmtTy = EVT::getVectorVT(*DAG.getContext(), ScalarShiftAmtTy,
+ VT.getVectorElementCount());
+ SDValue Sign =
+ DAG.getNode(ISD::VP_SRA, DL, VT, N0,
+ DAG.getConstant(BitWidth - 1, DL, VecShiftAmtTy), Mask, VL);
+ AddToWorklist(Sign.getNode());
+
+ // Add N0, ((N0 < 0) ? abs(N1) - 1 : 0);
+ Inexact = DAG.getSplat(VT, DL, Inexact);
+ C1 = DAG.getSplat(VT, DL, C1);
+ SDValue Srl = DAG.getNode(ISD::VP_SRL, DL, VT, Sign, Inexact, Mask, VL);
+ AddToWorklist(Srl.getNode());
+ SDValue Add = DAG.getNode(ISD::VP_ADD, DL, VT, N0, Srl, Mask, VL);
+ AddToWorklist(Add.getNode());
+ SDValue Sra = DAG.getNode(ISD::VP_SRA, DL, VT, Add, C1, Mask, VL);
+ AddToWorklist(Sra.getNode());
+
+ // Special case: (sdiv X, 1) -> X
+ // Special Case: (sdiv X, -1) -> 0-X
+ SDValue One = DAG.getConstant(1, DL, VT);
+ SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
+ SDValue IsOne = DAG.getSetCCVP(DL, CCVT, N1, One, ISD::SETEQ, Mask, VL);
+ SDValue IsAllOnes =
+ DAG.getSetCCVP(DL, CCVT, N1, AllOnes, ISD::SETEQ, Mask, VL);
+ SDValue IsOneOrAllOnes =
+ DAG.getNode(ISD::VP_OR, DL, CCVT, IsOne, IsAllOnes, Mask, VL);
+ Sra = DAG.getNode(ISD::VP_SELECT, DL, VT, IsOneOrAllOnes, N0, Sra, VL);
+
+ // If dividing by a positive value, we're done. Otherwise, the result must
+ // be negated.
+ SDValue Zero = DAG.getConstant(0, DL, VT);
+ SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, Zero, Sra, Mask, VL);
+
+ // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
----------------
topperc wrote:
I don't think this FIXME applies since we don't have a VP_SELECT_CC.
https://github.com/llvm/llvm-project/pull/125991
More information about the llvm-commits
mailing list