[llvm] [RISCV] Optimize divide by constant for VP intrinsics (PR #125991)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 25 14:50:44 PST 2025


================
@@ -6692,6 +6823,143 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG,
   return DAG.getSelect(dl, VT, IsOne, N0, Q);
 }
 
+/// Given an ISD::VP_UDIV node expressing a divide by constant,
+/// return a DAG expression to select that will generate the same value by
+/// multiplying by a magic number.
+/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
+SDValue TargetLowering::BuildVPUDIV(SDNode *N, SelectionDAG &DAG,
+                                    bool IsAfterLegalization,
+                                    SmallVectorImpl<SDNode *> &Created) const {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  EVT SVT = VT.getScalarType();
+  EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
+  EVT ShSVT = ShVT.getScalarType();
+  unsigned EltBits = VT.getScalarSizeInBits();
+
+  // Check to see if we can do this.
+  if (!isTypeLegal(VT) ||
+      !isOperationLegalOrCustom(ISD::VP_MULHU, VT, IsAfterLegalization))
+    return SDValue();
+
+  bool UseNPQ = false, UsePreShift = false, UsePostShift = false;
+
+  SmallVector<SDValue, 16> PreShifts, PostShifts, MagicFactors, NPQFactors;
+
+  auto BuildUDIVPattern = [&](ConstantSDNode *C) {
+    if (C->isZero())
+      return false;
+    // FIXME: We should use a narrower constant when the upper
+    // bits are known to be zero.
+    const APInt &Divisor = C->getAPIntValue();
+    SDValue PreShift, MagicFactor, NPQFactor, PostShift;
+
+    // Magic algorithm doesn't work for division by 1. We need to emit a select
+    // at the end.
+    if (Divisor.isOne()) {
+      PreShift = PostShift = DAG.getUNDEF(ShSVT);
+      MagicFactor = NPQFactor = DAG.getUNDEF(SVT);
+    } else {
+      UnsignedDivisionByConstantInfo magics =
+          UnsignedDivisionByConstantInfo::get(Divisor);
+
+      MagicFactor = DAG.getConstant(magics.Magic, DL, SVT);
+
+      assert(magics.PreShift < Divisor.getBitWidth() &&
+             "We shouldn't generate an undefined shift!");
+      assert(magics.PostShift < Divisor.getBitWidth() &&
+             "We shouldn't generate an undefined shift!");
+      assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift");
+      PreShift = DAG.getConstant(magics.PreShift, DL, ShSVT);
+      PostShift = DAG.getConstant(magics.PostShift, DL, ShSVT);
+      NPQFactor = DAG.getConstant(
+          magics.IsAdd ? APInt::getOneBitSet(EltBits, EltBits - 1)
+                       : APInt::getZero(EltBits),
+          DL, SVT);
+      UseNPQ |= magics.IsAdd;
+      UsePreShift |= magics.PreShift != 0;
+      UsePostShift |= magics.PostShift != 0;
+    }
+
+    PreShifts.push_back(PreShift);
+    MagicFactors.push_back(MagicFactor);
+    NPQFactors.push_back(NPQFactor);
+    PostShifts.push_back(PostShift);
+    return true;
+  };
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+
+  // Collect the shifts/magic values from each element.
+  if (!ISD::matchUnaryPredicate(N1, BuildUDIVPattern))
+    return SDValue();
+
+  SDValue PreShift, PostShift, MagicFactor, NPQFactor;
+  if (N1.getOpcode() == ISD::BUILD_VECTOR) {
+    PreShift = DAG.getBuildVector(ShVT, DL, PreShifts);
+    MagicFactor = DAG.getBuildVector(VT, DL, MagicFactors);
+    NPQFactor = DAG.getBuildVector(VT, DL, NPQFactors);
+    PostShift = DAG.getBuildVector(ShVT, DL, PostShifts);
+  } else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
+    assert(PreShifts.size() == 1 && MagicFactors.size() == 1 &&
+           NPQFactors.size() == 1 && PostShifts.size() == 1 &&
+           "Expected matchUnaryPredicate to return one for scalable vectors");
+    PreShift = DAG.getSplatVector(ShVT, DL, PreShifts[0]);
+    MagicFactor = DAG.getSplatVector(VT, DL, MagicFactors[0]);
+    NPQFactor = DAG.getSplatVector(VT, DL, NPQFactors[0]);
+    PostShift = DAG.getSplatVector(ShVT, DL, PostShifts[0]);
+  } else {
+    assert(isa<ConstantSDNode>(N1) && "Expected a constant");
----------------
topperc wrote:

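This `else` branch should be unreachable for a VP opcode: VP nodes only operate on vectors, so the divisor is always a BUILD_VECTOR or SPLAT_VECTOR and never a scalar ConstantSDNode.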

https://github.com/llvm/llvm-project/pull/125991


More information about the llvm-commits mailing list