[llvm] Revert "[APInt] Remove multiplicativeInverse with explicit modulus (#87644)" (PR #87812)

Jeremy Kun via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 5 10:57:33 PDT 2024


https://github.com/j2kun created https://github.com/llvm/llvm-project/pull/87812

…87644)"

This reverts commit 0b293e8c36d97bbd7f85ed5b67ce510ff7fd86ee.

There are out-of-tree uses of this method, and it is also planned for use in a new polynomial dialect in MLIR; the first PR for that dialect is https://github.com/llvm/llvm-project/pull/72081, and later PRs will add lowerings that need the removed functionality.

>From 02329106f64c176507512202bb25ffb3cc6767f4 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 5 Apr 2024 10:55:01 -0700
Subject: [PATCH] Revert "[APInt] Remove multiplicativeInverse with explicit
 modulus (#87644)"

This reverts commit 0b293e8c36d97bbd7f85ed5b67ce510ff7fd86ee.

There are out-of-tree uses of this method, and it is also planned for
use in a new polynomial dialect in MLIR; the first PR for that dialect
is https://github.com/llvm/llvm-project/pull/72081, and later PRs will
add lowerings that need the removed functionality.
---
 llvm/include/llvm/ADT/APInt.h    |  3 ++
 llvm/lib/Support/APInt.cpp       | 49 ++++++++++++++++++++++++++++++++
 llvm/unittests/ADT/APIntTest.cpp | 19 ++++++++++---
 3 files changed, 67 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 8d3c029b2e7e91..bd1716219ee5fc 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -1740,6 +1740,9 @@ class [[nodiscard]] APInt {
     return *this;
   }
 
+  /// \returns the inverse of *this modulo \p modulo, or 0 if none exists.
+  APInt multiplicativeInverse(const APInt &modulo) const;
+
   /// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
   APInt multiplicativeInverse() const;
 
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 224ea0924f0aaa..f8f699f8f6ccd7 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -1240,6 +1240,55 @@ APInt APInt::sqrt() const {
   return x_old + 1;
 }
 
+/// Computes the multiplicative inverse of this APInt for a given modulo. The
+/// iterative extended Euclidean algorithm is used to solve for this value,
+/// however we simplify it to speed up calculating only the inverse, and take
+/// advantage of div+rem calculations. We also use some tricks to avoid copying
+/// (potentially large) APInts around.
+/// Requires *this <u modulo. WARNING: returns 0 if gcd(*this, modulo) != 1,
+///          signifying that no multiplicative inverse exists!
+APInt APInt::multiplicativeInverse(const APInt& modulo) const {
+  assert(ult(modulo) && "This APInt must be smaller than the modulo");
+
+  // Using the properties listed at the following web page (accessed 06/21/08):
+  //   http://www.numbertheory.org/php/euclid.html
+  // (especially the properties numbered 3, 4 and 9) it can be proved that
+  // BitWidth bits suffice for all the computations in the algorithm implemented
+  // below. More precisely, this number of bits suffice if the multiplicative
+  // inverse exists, but may not suffice for the general extended Euclidean
+  // algorithm.
+
+  APInt r[2] = { modulo, *this };
+  APInt t[2] = { APInt(BitWidth, 0), APInt(BitWidth, 1) };
+  APInt q(BitWidth, 0);
+
+  unsigned i;
+  for (i = 0; r[i^1] != 0; i ^= 1) {
+    // An overview of the math without the confusing bit-flipping:
+    // q = r[i-2] / r[i-1]
+    // r[i] = r[i-2] % r[i-1]
+    // t[i] = t[i-2] - t[i-1] * q
+    udivrem(r[i], r[i^1], q, r[i]);
+    t[i] -= t[i^1] * q;
+  }
+
+  // If this APInt and the modulo are not coprime, there is no multiplicative
+  // inverse, so return 0. We check this by looking at the next-to-last
+  // remainder, which is the gcd(*this,modulo) as calculated by the Euclidean
+  // algorithm.
+  if (r[i] != 1)
+    return APInt(BitWidth, 0);
+
+  // The next-to-last t is the multiplicative inverse. However, we are
+  // interested in a positive inverse. Calculate a positive one from a negative
+  // one if necessary. A simple addition of the modulo suffices because
+  // abs(t[i]) is known to be at most modulo/2 (see the link above).
+  if (t[i].isNegative())
+    t[i] += modulo;
+
+  return std::move(t[i]);
+}
+
 /// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
 APInt APInt::multiplicativeInverse() const {
   assert((*this)[0] &&
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 76fc26412407e7..23f9ee2d39c441 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -3249,11 +3249,22 @@ TEST(APIntTest, SolveQuadraticEquationWrap) {
 }
 
 TEST(APIntTest, MultiplicativeInverseExaustive) {
-  for (unsigned BitWidth = 1; BitWidth <= 8; ++BitWidth) {
-    for (unsigned Value = 1; Value < (1u << BitWidth); Value += 2) {
-      // Multiplicative inverse exists for all odd numbers.
+  for (unsigned BitWidth = 1; BitWidth <= 16; ++BitWidth) {
+    for (unsigned Value = 0; Value < (1u << BitWidth); ++Value) {
       APInt V = APInt(BitWidth, Value);
-      EXPECT_EQ(V * V.multiplicativeInverse(), 1);
+      APInt MulInv =
+          V.zext(BitWidth + 1)
+              .multiplicativeInverse(APInt::getSignedMinValue(BitWidth + 1))
+              .trunc(BitWidth);
+      APInt One = V * MulInv;
+      if (V[0]) {
+        // Odd values are coprime to 2^BitWidth, so both overloads invert V.
+        EXPECT_TRUE(One.isOne());
+        EXPECT_TRUE((V * V.multiplicativeInverse()).isOne());
+      } else {
+        // Even values (and 0) aren't coprime to 2^BitWidth; expect 0 back.
+        EXPECT_TRUE(MulInv.isZero());
+      }
     }
   }
 }



More information about the llvm-commits mailing list