[llvm] [CodeGen] Emit a more efficient magic number multiplication for exact udivs (PR #87161)

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 23 02:06:34 PDT 2024


================
@@ -6126,6 +6125,70 @@ static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N,
   return DAG.getNode(ISD::MUL, dl, VT, Res, Factor);
 }
 
+/// Given an exact UDIV by a constant, create a multiplication
+/// with the multiplicative inverse of the constant.
+static SDValue BuildExactUDIV(const TargetLowering &TLI, SDNode *N,
+                              const SDLoc &dl, SelectionDAG &DAG,
+                              SmallVectorImpl<SDNode *> &Created) {
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  EVT SVT = VT.getScalarType();
+  EVT ShVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
+  EVT ShSVT = ShVT.getScalarType();
+
+  bool UseSRL = false;
+  SmallVector<SDValue, 16> Shifts, Factors;
+
+  auto BuildUDIVPattern = [&](ConstantSDNode *C) {
+    if (C->isZero())
+      return false;
+    APInt Divisor = C->getAPIntValue();
+    unsigned Shift = Divisor.countr_zero();
+    if (Shift) {
+      Divisor.lshrInPlace(Shift);
+      UseSRL = true;
+    }
+    // Calculate the multiplicative inverse modulo BW.
+    APInt Factor = Divisor.multiplicativeInverse();
+    Shifts.push_back(DAG.getConstant(Shift, dl, ShSVT));
+    Factors.push_back(DAG.getConstant(Factor, dl, SVT));
+    return true;
+  };
+
+  // Collect all magic values from the build vector.
+  if (!ISD::matchUnaryPredicate(Op1, BuildUDIVPattern))
+    return SDValue();
+
+  SDValue Shift, Factor;
+  if (Op1.getOpcode() == ISD::BUILD_VECTOR) {
+    Shift = DAG.getBuildVector(ShVT, dl, Shifts);
+    Factor = DAG.getBuildVector(VT, dl, Factors);
+  } else if (Op1.getOpcode() == ISD::SPLAT_VECTOR) {
+    assert(Shifts.size() == 1 && Factors.size() == 1 &&
+           "Expected matchUnaryPredicate to return one element for scalable "
+           "vectors");
+    Shift = DAG.getSplatVector(ShVT, dl, Shifts[0]);
+    Factor = DAG.getSplatVector(VT, dl, Factors[0]);
+  } else {
+    assert(isa<ConstantSDNode>(Op1) && "Expected a constant");
+    Shift = Shifts[0];
+    Factor = Factors[0];
+  }
+
+  SDValue Res = Op0;
+
+  // Shift the value upfront if it is even, so the LSB is one.
----------------
jayfoad wrote:

The comment is confusing. We actually shift Op0 if Op1 is even, and this does not guarantee that the LSB of Op0 is one. Maybe just remove it?

https://github.com/llvm/llvm-project/pull/87161


More information about the llvm-commits mailing list