[llvm] 1b76120 - [APInt] Add a simpler overload of multiplicativeInverse (#87610)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 4 08:11:10 PDT 2024


Author: Jay Foad
Date: 2024-04-04T16:11:06+01:00
New Revision: 1b761205f2686516cebadbcbc37f798197d9c482

URL: https://github.com/llvm/llvm-project/commit/1b761205f2686516cebadbcbc37f798197d9c482
DIFF: https://github.com/llvm/llvm-project/commit/1b761205f2686516cebadbcbc37f798197d9c482.diff

LOG: [APInt] Add a simpler overload of multiplicativeInverse (#87610)

The current APInt::multiplicativeInverse takes a modulus which can be
any value, but all in-tree callers use a power of two. Moreover, most
callers want to use two to the power of the width of an existing APInt,
which is awkward because 2^N is not representable as an N-bit APInt.

Add a new overload of multiplicativeInverse which implicitly uses
2^BitWidth as the modulus.

Added: 
    

Modified: 
    llvm/include/llvm/ADT/APInt.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Support/APInt.cpp
    llvm/unittests/ADT/APIntTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index b9b39f3b9dfbc4..bd1716219ee5fc 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -1743,6 +1743,9 @@ class [[nodiscard]] APInt {
   /// \returns the multiplicative inverse for a given modulo.
   APInt multiplicativeInverse(const APInt &modulo) const;
 
+  /// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
+  APInt multiplicativeInverse() const;
+
   /// @}
   /// \name Building-block Operations for APInt and APFloat
   /// @{

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 515b9d0744f6e3..e030b9fc7dac4f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -944,10 +944,7 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
   // Calculate the multiplicative inverse of K! / 2^T;
   // this multiplication factor will perform the exact division by
   // K! / 2^T.
-  APInt Mod = APInt::getSignedMinValue(W+1);
-  APInt MultiplyFactor = OddFactorial.zext(W+1);
-  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
-  MultiplyFactor = MultiplyFactor.trunc(W);
+  APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
 
   // Calculate the product, at width T+W
   IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
@@ -10086,10 +10083,8 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
   // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
   // (N / D) in general. The inverse itself always fits into BW bits, though,
   // so we immediately truncate it.
-  APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
-  APInt Mod(BW + 1, 0);
-  Mod.setBit(BW - Mult2);  // Mod = N / D
-  APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
+  APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
+  APInt I = AD.multiplicativeInverse().zext(BW);
 
   // 4. Compute the minimum unsigned root of the equation:
   // I * (B / D) mod (N / D)

diff  --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 062132c8304b06..719209e0edd5fb 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -5201,10 +5201,7 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
 
     // Calculate the multiplicative inverse modulo BW.
     // 2^W requires W + 1 bits, so we have to extend and then truncate.
-    unsigned W = Divisor.getBitWidth();
-    APInt Factor = Divisor.zext(W + 1)
-                       .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
-                       .trunc(W);
+    APInt Factor = Divisor.multiplicativeInverse();
     Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
     Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
     return true;

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 5e053f97675d7f..409d66adfd67d1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -6071,11 +6071,7 @@ static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N,
       Divisor.ashrInPlace(Shift);
       UseSRA = true;
     }
-    // Calculate the multiplicative inverse, using Newton's method.
-    APInt t;
-    APInt Factor = Divisor;
-    while ((t = Divisor * Factor) != 1)
-      Factor *= APInt(Divisor.getBitWidth(), 2) - t;
+    APInt Factor = Divisor.multiplicativeInverse();
     Shifts.push_back(DAG.getConstant(Shift, dl, ShSVT));
     Factors.push_back(DAG.getConstant(Factor, dl, SVT));
     return true;
@@ -6664,10 +6660,7 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
     // P = inv(D0, 2^W)
     // 2^W requires W + 1 bits, so we have to extend and then truncate.
     unsigned W = D.getBitWidth();
-    APInt P = D0.zext(W + 1)
-                  .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
-                  .trunc(W);
-    assert(!P.isZero() && "No multiplicative inverse!"); // unreachable
+    APInt P = D0.multiplicativeInverse();
     assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");
 
     // Q = floor((2^W - 1) u/ D)
@@ -6922,10 +6915,7 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
     // P = inv(D0, 2^W)
     // 2^W requires W + 1 bits, so we have to extend and then truncate.
     unsigned W = D.getBitWidth();
-    APInt P = D0.zext(W + 1)
-                  .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
-                  .trunc(W);
-    assert(!P.isZero() && "No multiplicative inverse!"); // unreachable
+    APInt P = D0.multiplicativeInverse();
     assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");
 
     // A = floor((2^(W - 1) - 1) / D0) & -2^K
@@ -7651,7 +7641,7 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
 //
 // For division, we can compute the remainder using the algorithm described
 // above, subtract it from the dividend to get an exact multiple of Constant.
-// Then multiply that extact multiply by the multiplicative inverse modulo
+// Then multiply that exact multiple by the multiplicative inverse modulo
 // (1 << (BitWidth / 2)) to get the quotient.
 
 // If Constant is even, we can shift right the dividend and the divisor by the
@@ -7786,10 +7776,7 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,
 
     // Multiply by the multiplicative inverse of the divisor modulo
     // (1 << BitWidth).
-    APInt Mod = APInt::getSignedMinValue(BitWidth + 1);
-    APInt MulFactor = Divisor.zext(BitWidth + 1);
-    MulFactor = MulFactor.multiplicativeInverse(Mod);
-    MulFactor = MulFactor.trunc(BitWidth);
+    APInt MulFactor = Divisor.multiplicativeInverse();
 
     SDValue Quotient = DAG.getNode(ISD::MUL, dl, VT, Dividend,
                                    DAG.getConstant(MulFactor, dl, VT));

diff  --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index c20609748dc97c..f8f699f8f6ccd7 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -1289,6 +1289,19 @@ APInt APInt::multiplicativeInverse(const APInt& modulo) const {
   return std::move(t[i]);
 }
 
+/// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
+APInt APInt::multiplicativeInverse() const {
+  assert((*this)[0] &&
+         "multiplicative inverse is only defined for odd numbers!");
+
+  // Newton's method: refine Factor until *this * Factor == 1 (mod 2^BitWidth).
+  APInt Factor = *this;
+  APInt T;
+  while (!(T = *this * Factor).isOne())
+    Factor *= 2 - T;
+  return Factor;
+}
+
 /// Implementation of Knuth's Algorithm D (Division of nonnegative integers)
 /// from "Art of Computer Programming, Volume 2", section 4.3.1, p. 272. The
 /// variables here have the same names as in the algorithm. Comments explain

diff  --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index d5ef63e38e2790..23f9ee2d39c441 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -3257,9 +3257,10 @@ TEST(APIntTest, MultiplicativeInverseExaustive) {
               .multiplicativeInverse(APInt::getSignedMinValue(BitWidth + 1))
               .trunc(BitWidth);
       APInt One = V * MulInv;
-      if (!V.isZero() && V.countr_zero() == 0) {
+      if (V[0]) {
         // Multiplicative inverse exists for all odd numbers.
         EXPECT_TRUE(One.isOne());
+        EXPECT_TRUE((V * V.multiplicativeInverse()).isOne());
       } else {
         // Multiplicative inverse does not exist for even numbers (and 0).
         EXPECT_TRUE(MulInv.isZero());


        


More information about the llvm-commits mailing list