[llvm] [RISCV] Optimize divide by constant for VP intrinsics (PR #125991)
Jesse Huang via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 26 00:26:37 PST 2025
================
@@ -27219,6 +27278,233 @@ SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitVPUDIV(SDNode *N) { // Tries to fold vp.udiv node N; yields an empty SDValue if nothing applies.
+ SDValue N0 = N->getOperand(0); // dividend
+ SDValue N1 = N->getOperand(1); // divisor
+ SDValue Mask = N->getOperand(2);
+ SDValue VL = N->getOperand(3);
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+
+ ConstantSDNode *N1C = isConstOrConstSplat(N1);
+ // fold (vp.udiv X, -1) -> vp.select(X == -1, 1, 0): an unsigned divide by all-ones gives 1 only when X is all-ones too
+ if (N1C && N1C->isAllOnes()) {
+ EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, // i1 vector with VT's element count, used for the compare result
+ VT.getVectorElementCount());
+ return DAG.getNode(ISD::VP_SELECT, DL, VT,
+ DAG.getSetCCVP(DL, CCVT, N0, N1, ISD::SETEQ, Mask, VL),
+ DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT),
+ VL);
+ }
+ // Otherwise defer to visitVPUDIVLike for the remaining folds.
+ if (SDValue V = visitVPUDIVLike(N0, N1, N)) {
+ // If the corresponding remainder node exists, update its users with
+ // (Dividend - (Quotient * Divisor)).
+ if (SDNode *RemNode = DAG.getNodeIfExists(ISD::VP_UREM, N->getVTList(), // looks up a VP_UREM with the same four operands
+ {N0, N1, Mask, VL})) {
+ SDValue Mul = DAG.getNode(ISD::VP_MUL, DL, VT, V, N1, Mask, VL); // Quotient * Divisor
+ SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, N0, Mul, Mask, VL); // Dividend - Mul
+ AddToWorklist(Mul.getNode());
+ AddToWorklist(Sub.getNode());
+ CombineTo(RemNode, Sub);
+ }
+ return V;
+ }
+ // No fold applied.
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitVPUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
+ SDLoc DL(N);
+ SDValue Mask = N->getOperand(2);
+ SDValue VL = N->getOperand(3);
+ EVT VT = N->getValueType(0);
+
+ // fold (vp.udiv x, (1 << c)) -> vp.lshr(x, c)
+ if (isConstantOrConstantVector(N1, /*NoOpaques=*/true) &&
+ DAG.isKnownToBeAPowerOfTwo(N1)) {
+ SDValue LogBase2 = BuildLogBase2(N1, DL);
+ AddToWorklist(LogBase2.getNode());
+
+ EVT ShiftVT = getShiftAmountTy(N0.getValueType());
+ SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
+ AddToWorklist(Trunc.getNode());
+ return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Trunc, Mask, VL);
+ }
+
+ // fold (vp.udiv x, (vp.shl c, y)) -> vp.lshr(x, vp.add(log2(c)+y)) iff c is
+ // power of 2
+ if (N1.getOpcode() == ISD::VP_SHL && N1->getOperand(2) == Mask &&
+ N1->getOperand(3) == VL) {
+ SDValue N10 = N1.getOperand(0);
+ if (isConstantOrConstantVector(N10, /*NoOpaques=*/true) &&
+ DAG.isKnownToBeAPowerOfTwo(N10)) {
+ SDValue LogBase2 = BuildLogBase2(N10, DL);
+ AddToWorklist(LogBase2.getNode());
+
+ EVT ADDVT = N1.getOperand(1).getValueType();
+ SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
+ AddToWorklist(Trunc.getNode());
+ SDValue Add = DAG.getNode(ISD::VP_ADD, DL, ADDVT, N1.getOperand(1), Trunc,
+ Mask, VL);
+ AddToWorklist(Add.getNode());
+ return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Add, Mask, VL);
+ }
+ }
+
+ // fold (vp.udiv x, Splat(shl c, y)) -> vp.lshr(x, add(log2(c)+y)) iff c is
+ // power of 2
+ if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
+ SDValue N10 = N1.getOperand(0);
+ if (N10.getOpcode() == ISD::SHL) {
+ SDValue N0SHL = N10.getOperand(0);
+ if (isa<ConstantSDNode>(N0SHL) && DAG.isKnownToBeAPowerOfTwo(N0SHL)) {
+ SDValue LogBase2 = BuildLogBase2(N0SHL, DL);
+ AddToWorklist(LogBase2.getNode());
+
+ EVT ADDVT = N10.getOperand(1).getValueType();
+ SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
+ AddToWorklist(Trunc.getNode());
+ SDValue Add =
+ DAG.getNode(ISD::ADD, DL, ADDVT, N10.getOperand(1), Trunc);
----------------
jaidTw wrote:
This folds `(vp.udiv x, Splat(shl c, y)) -> vp.lshr(x, add(log2(c)+y))` iff `c` is a power of 2, where `c` and `y` are scalars, so it's intended
https://github.com/llvm/llvm-project/pull/125991
More information about the llvm-commits
mailing list