[llvm] [RISCV] Optimize divide by constant for VP intrinsics (PR #125991)

Jesse Huang via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 26 00:26:37 PST 2025


================
@@ -27219,6 +27278,233 @@ SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
   return SDValue();
 }
 
+/// Combine a VP_UDIV node (operands: dividend, divisor, mask, EVL).
+///
+/// Folds attempted, all predicated on the node's own mask and EVL:
+///   - (vp.udiv X, -1) -> vp.select(X == -1, 1, 0)
+///   - any rewrite produced by visitVPUDIVLike. When such a rewrite succeeds
+///     and a VP_UREM with identical operands already exists, that remainder
+///     is rewritten as X - (Quotient * Divisor) so it reuses the new quotient
+///     (skipped for exact divisions, see below).
+///
+/// \returns the replacement value, or an empty SDValue if no fold applied.
+SDValue DAGCombiner::visitVPUDIV(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
+  // fold (vp.udiv X, -1) -> vp.select(X == -1, 1, 0)
+  // An unsigned quotient by all-ones is 1 only when X is also all-ones.
+  if (N1C && N1C->isAllOnes()) {
+    EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                VT.getVectorElementCount());
+    return DAG.getNode(ISD::VP_SELECT, DL, VT,
+                       DAG.getSetCCVP(DL, CCVT, N0, N1, ISD::SETEQ, Mask, VL),
+                       DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT),
+                       VL);
+  }
+
+  if (SDValue V = visitVPUDIVLike(N0, N1, N)) {
+    // If the corresponding remainder node exists, update its users with
+    // (Dividend - (Quotient * Divisor)).
+    if (SDNode *RemNode = DAG.getNodeIfExists(ISD::VP_UREM, N->getVTList(),
+                                              {N0, N1, Mask, VL})) {
+      // An exact udiv asserts that the remainder is zero, so a quotient
+      // derived under that assumption must not be used to recompute the
+      // remainder; leave the existing vp.urem untouched in that case, as the
+      // scalar UDIV combine does.
+      if (!N->getFlags().hasExact()) {
+        SDValue Mul = DAG.getNode(ISD::VP_MUL, DL, VT, V, N1, Mask, VL);
+        SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, N0, Mul, Mask, VL);
+        AddToWorklist(Mul.getNode());
+        AddToWorklist(Sub.getNode());
+        CombineTo(RemNode, Sub);
+      }
+    }
+    return V;
+  }
+
+  return SDValue();
+}
+
+SDValue DAGCombiner::visitVPUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
+  SDLoc DL(N);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+  EVT VT = N->getValueType(0);
+
+  // fold (vp.udiv x, (1 << c)) -> vp.lshr(x, c)
+  if (isConstantOrConstantVector(N1, /*NoOpaques=*/true) &&
+      DAG.isKnownToBeAPowerOfTwo(N1)) {
+    SDValue LogBase2 = BuildLogBase2(N1, DL);
+    AddToWorklist(LogBase2.getNode());
+
+    EVT ShiftVT = getShiftAmountTy(N0.getValueType());
+    SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
+    AddToWorklist(Trunc.getNode());
+    return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Trunc, Mask, VL);
+  }
+
+  // fold (vp.udiv x, (vp.shl c, y)) -> vp.lshr(x, vp.add(log2(c)+y)) iff c is
+  // power of 2
+  if (N1.getOpcode() == ISD::VP_SHL && N1->getOperand(2) == Mask &&
+      N1->getOperand(3) == VL) {
+    SDValue N10 = N1.getOperand(0);
+    if (isConstantOrConstantVector(N10, /*NoOpaques=*/true) &&
+        DAG.isKnownToBeAPowerOfTwo(N10)) {
+      SDValue LogBase2 = BuildLogBase2(N10, DL);
+      AddToWorklist(LogBase2.getNode());
+
+      EVT ADDVT = N1.getOperand(1).getValueType();
+      SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
+      AddToWorklist(Trunc.getNode());
+      SDValue Add = DAG.getNode(ISD::VP_ADD, DL, ADDVT, N1.getOperand(1), Trunc,
+                                Mask, VL);
+      AddToWorklist(Add.getNode());
+      return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Add, Mask, VL);
+    }
+  }
+
+  // fold (vp.udiv x, Splat(shl c, y)) -> vp.lshr(x, add(log2(c)+y)) iff c is
+  // power of 2
+  if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
+    SDValue N10 = N1.getOperand(0);
+    if (N10.getOpcode() == ISD::SHL) {
+      SDValue N0SHL = N10.getOperand(0);
+      if (isa<ConstantSDNode>(N0SHL) && DAG.isKnownToBeAPowerOfTwo(N0SHL)) {
+        SDValue LogBase2 = BuildLogBase2(N0SHL, DL);
+        AddToWorklist(LogBase2.getNode());
+
+        EVT ADDVT = N10.getOperand(1).getValueType();
+        SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
+        AddToWorklist(Trunc.getNode());
+        SDValue Add =
+            DAG.getNode(ISD::ADD, DL, ADDVT, N10.getOperand(1), Trunc);
----------------
jaidTw wrote:

This folds `(vp.udiv x, Splat(shl c, y)) -> vp.lshr(x, add(log2(c)+y))` iff `c` is a power of 2, where `c` and `y` are scalars, so using the scalar `ISD::ADD` (rather than `ISD::VP_ADD`) here is intended.

https://github.com/llvm/llvm-project/pull/125991


More information about the llvm-commits mailing list