[llvm] [RISCV] Optimize divide by constant for VP intrinsics (PR #125991)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 25 14:50:43 PST 2025


================
@@ -27219,6 +27278,233 @@ SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitVPUDIV(SDNode *N) { // DAG combine for vp.udiv; operands are (N0, N1, Mask, VL).
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
+  // fold (vp.udiv X, -1) -> vp.select(X == -1, 1, 0): an unsigned quotient by all-ones is 1 only when X is all-ones, else 0
+  if (N1C && N1C->isAllOnes()) {
+    EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                VT.getVectorElementCount());
+    return DAG.getNode(ISD::VP_SELECT, DL, VT,
+                       DAG.getSetCCVP(DL, CCVT, N0, N1, ISD::SETEQ, Mask, VL),
+                       DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT),
+                       VL);
+  }
+
+  if (SDValue V = visitVPUDIVLike(N0, N1, N)) {
+    // If the corresponding remainder node exists, update its users with
+    // (Dividend - (Quotient * Divisor)).
+    if (SDNode *RemNode = DAG.getNodeIfExists(ISD::VP_UREM, N->getVTList(),
+                                              {N0, N1, Mask, VL})) {
+      SDValue Mul = DAG.getNode(ISD::VP_MUL, DL, VT, V, N1, Mask, VL);
+      SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, N0, Mul, Mask, VL);
+      AddToWorklist(Mul.getNode());
+      AddToWorklist(Sub.getNode());
+      CombineTo(RemNode, Sub);
+    }
+    return V;
+  }
+
+  return SDValue(); // Neither fold applied.
+}
+
+SDValue DAGCombiner::visitVPUDIVLike(SDValue N0, SDValue N1, SDNode *N) { // Tries the vp.udiv folds below on N0/N1; Mask and VL are read from N.
+  SDLoc DL(N);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+  EVT VT = N->getValueType(0);
+
+  // fold (vp.udiv x, (1 << c)) -> vp.lshr(x, c)
+  if (isConstantOrConstantVector(N1, /*NoOpaques=*/true) &&
+      DAG.isKnownToBeAPowerOfTwo(N1)) {
+    SDValue LogBase2 = BuildLogBase2(N1, DL);
+    AddToWorklist(LogBase2.getNode());
+
+    EVT ShiftVT = getShiftAmountTy(N0.getValueType());
+    SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
+    AddToWorklist(Trunc.getNode());
+    return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Trunc, Mask, VL);
+  }
+
+  // fold (vp.udiv x, (vp.shl c, y)) -> vp.lshr(x, vp.add(log2(c), y)) iff c is a
+  // power of 2 and the vp.shl uses the same mask and VL as the division
+  if (N1.getOpcode() == ISD::VP_SHL && N1->getOperand(2) == Mask &&
+      N1->getOperand(3) == VL) {
+    SDValue N10 = N1.getOperand(0);
+    if (isConstantOrConstantVector(N10, /*NoOpaques=*/true) &&
+        DAG.isKnownToBeAPowerOfTwo(N10)) {
+      SDValue LogBase2 = BuildLogBase2(N10, DL);
+      AddToWorklist(LogBase2.getNode());
+
+      EVT ADDVT = N1.getOperand(1).getValueType();
+      SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
+      AddToWorklist(Trunc.getNode());
+      SDValue Add = DAG.getNode(ISD::VP_ADD, DL, ADDVT, N1.getOperand(1), Trunc,
+                                Mask, VL);
+      AddToWorklist(Add.getNode());
+      return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Add, Mask, VL);
+    }
+  }
+
+  // fold (vp.udiv x, Splat(shl c, y)) -> vp.lshr(x, Splat(add(log2(c), y)))
+  // iff c is a constant power of 2
+  if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
+    SDValue N10 = N1.getOperand(0);
+    if (N10.getOpcode() == ISD::SHL) {
+      SDValue N0SHL = N10.getOperand(0);
+      if (isa<ConstantSDNode>(N0SHL) && DAG.isKnownToBeAPowerOfTwo(N0SHL)) {
+        SDValue LogBase2 = BuildLogBase2(N0SHL, DL);
+        AddToWorklist(LogBase2.getNode());
+
+        EVT ADDVT = N10.getOperand(1).getValueType();
+        SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
+        AddToWorklist(Trunc.getNode());
+        SDValue Add =
+            DAG.getNode(ISD::ADD, DL, ADDVT, N10.getOperand(1), Trunc);
+        AddToWorklist(Add.getNode());
+        SDValue Splat = DAG.getSplatVector(VT, DL, Add);
+        AddToWorklist(Splat.getNode());
+        return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Splat, Mask, VL);
+      }
+    }
+  }
+
+  // fold (vp.udiv x, c) -> alternate sequence from BuildUDIV, only when the target reports integer division is not cheap for this type
+  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+  if (isConstantOrConstantVector(N1) &&
+      !TLI.isIntDivCheap(N->getValueType(0), Attr)) {
+    if (SDValue Op = BuildUDIV(N))
+      return Op;
+  }
+  return SDValue(); // None of the folds above applied.
+}
+
+SDValue DAGCombiner::visitVPSDIV(SDNode *N) { // DAG combine for vp.sdiv; operands are (N0, N1, Mask, VL).
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  // fold (vp.sdiv X, -1) -> vp.sub(0, X)
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
+  if (N1C && N1C->isAllOnes())
+    return DAG.getNode(ISD::VP_SUB, DL, VT, DAG.getConstant(0, DL, VT), N0,
+                       Mask, VL);
+
+  // fold (vp.sdiv X, MIN_SIGNED) -> vp.select(X == MIN_SIGNED, 1, 0)
+  if (N1C && N1C->getAPIntValue().isMinSignedValue()) {
+    EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                VT.getVectorElementCount());
+    return DAG.getNode(ISD::VP_SELECT, DL, VT,
+                       DAG.getSetCCVP(DL, CCVT, N0, N1, ISD::SETEQ, Mask, VL),
+                       DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT),
+                       VL);
+  }
+
+  // If we know the sign bits of both operands are zero, strength reduce to a
+  // vp.udiv instead.  Handles (X&15) /s 4 -> (X&15) >> 2
+  if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
+    return DAG.getNode(ISD::VP_UDIV, DL, N1.getValueType(), N0, N1, Mask, VL);
+
+  if (SDValue V = visitVPSDIVLike(N0, N1, N)) {
+    // If the corresponding remainder node exists, update its users with
+    // (Dividend - (Quotient * Divisor)).
+    if (SDNode *RemNode = DAG.getNodeIfExists(ISD::VP_SREM, N->getVTList(),
+                                              {N0, N1, Mask, VL})) {
+      SDValue Mul = DAG.getNode(ISD::VP_MUL, DL, VT, V, N1, Mask, VL);
+      SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, N0, Mul, Mask, VL);
+      AddToWorklist(Mul.getNode());
+      AddToWorklist(Sub.getNode());
+      CombineTo(RemNode, Sub);
+    }
+    return V;
+  }
+  return SDValue(); // No fold applied.
+}
+
+SDValue DAGCombiner::visitVPSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
+  SDLoc DL(N);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+  EVT VT = N->getValueType(0);
+  unsigned BitWidth = VT.getScalarSizeInBits();
+
+  // fold (vp.sdiv X, V of pow 2)
+  if (N1.getOpcode() == ISD::SPLAT_VECTOR &&
+      isDivisorPowerOfTwo(N1.getOperand(0))) {
+    // Create constants that are functions of the shift amount value.
+    SDValue N = N1.getOperand(0);
+    EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                VT.getVectorElementCount());
+    EVT ScalarShiftAmtTy =
+        getShiftAmountTy(N0.getValueType().getVectorElementType());
+    SDValue Bits = DAG.getConstant(BitWidth, DL, ScalarShiftAmtTy);
+    SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT.getVectorElementType(), N);
+    C1 = DAG.getZExtOrTrunc(C1, DL, ScalarShiftAmtTy);
+    SDValue Inexact = DAG.getNode(ISD::SUB, DL, ScalarShiftAmtTy, Bits, C1);
+    if (!isa<ConstantSDNode>(Inexact))
+      return SDValue();
+
+    // Splat the sign bit into the register
+    EVT VecShiftAmtTy = EVT::getVectorVT(*DAG.getContext(), ScalarShiftAmtTy,
+                                         VT.getVectorElementCount());
+    SDValue Sign =
+        DAG.getNode(ISD::VP_SRA, DL, VT, N0,
+                    DAG.getConstant(BitWidth - 1, DL, VecShiftAmtTy), Mask, VL);
+    AddToWorklist(Sign.getNode());
+
+    // Add N0, ((N0 < 0) ? abs(N1) - 1 : 0);
+    Inexact = DAG.getSplat(VT, DL, Inexact);
+    C1 = DAG.getSplat(VT, DL, C1);
+    SDValue Srl = DAG.getNode(ISD::VP_SRL, DL, VT, Sign, Inexact, Mask, VL);
+    AddToWorklist(Srl.getNode());
+    SDValue Add = DAG.getNode(ISD::VP_ADD, DL, VT, N0, Srl, Mask, VL);
+    AddToWorklist(Add.getNode());
+    SDValue Sra = DAG.getNode(ISD::VP_SRA, DL, VT, Add, C1, Mask, VL);
+    AddToWorklist(Sra.getNode());
+
+    // Special case: (sdiv X, 1) -> X
+    // Special Case: (sdiv X, -1) -> 0-X
+    SDValue One = DAG.getConstant(1, DL, VT);
+    SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
+    SDValue IsOne = DAG.getSetCCVP(DL, CCVT, N1, One, ISD::SETEQ, Mask, VL);
+    SDValue IsAllOnes =
+        DAG.getSetCCVP(DL, CCVT, N1, AllOnes, ISD::SETEQ, Mask, VL);
+    SDValue IsOneOrAllOnes =
+        DAG.getNode(ISD::VP_OR, DL, CCVT, IsOne, IsAllOnes, Mask, VL);
+    Sra = DAG.getNode(ISD::VP_SELECT, DL, VT, IsOneOrAllOnes, N0, Sra, VL);
+
+    // If dividing by a positive value, we're done. Otherwise, the result must
+    // be negated.
+    SDValue Zero = DAG.getConstant(0, DL, VT);
+    SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, Zero, Sra, Mask, VL);
+
+    // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
----------------
topperc wrote:

I don't think this FIXME applies, since we don't have a VP_SELECT_CC.

https://github.com/llvm/llvm-project/pull/125991


More information about the llvm-commits mailing list