[llvm] [APInt] Restore multiplicativeInverse with explicit modulus and better testing (PR #87812)
Jeremy Kun via llvm-commits
llvm-commits at lists.llvm.org
Sun Apr 7 16:19:10 PDT 2024
https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/87812
>From 02329106f64c176507512202bb25ffb3cc6767f4 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 5 Apr 2024 10:55:01 -0700
Subject: [PATCH 1/4] Revert "[APInt] Remove multiplicativeInverse with
explicit modulus (#87644)"
This reverts commit 0b293e8c36d97bbd7f85ed5b67ce510ff7fd86ee.
There are out-of-tree uses of this method, and it is planned to be used
as part of a new polynomial dialect in MLIR, whose initial PR is
https://github.com/llvm/llvm-project/pull/72081 (later PRs will add
lowerings that need the removed functionality).
---
llvm/include/llvm/ADT/APInt.h | 3 ++
llvm/lib/Support/APInt.cpp | 49 ++++++++++++++++++++++++++++++++
llvm/unittests/ADT/APIntTest.cpp | 19 ++++++++++---
3 files changed, 67 insertions(+), 4 deletions(-)
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 8d3c029b2e7e91..bd1716219ee5fc 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -1740,6 +1740,9 @@ class [[nodiscard]] APInt {
return *this;
}
+ /// \returns the multiplicative inverse for a given modulo.
+ APInt multiplicativeInverse(const APInt &modulo) const;
+
/// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
APInt multiplicativeInverse() const;
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 224ea0924f0aaa..f8f699f8f6ccd7 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -1240,6 +1240,55 @@ APInt APInt::sqrt() const {
return x_old + 1;
}
+/// Computes the multiplicative inverse of this APInt for a given modulo. The
+/// iterative extended Euclidean algorithm is used to solve for this value,
+/// however we simplify it to speed up calculating only the inverse, and take
+/// advantage of div+rem calculations. We also use some tricks to avoid copying
+/// (potentially large) APInts around.
+/// WARNING: a value of '0' may be returned,
+/// signifying that no multiplicative inverse exists!
+APInt APInt::multiplicativeInverse(const APInt& modulo) const {
+ assert(ult(modulo) && "This APInt must be smaller than the modulo");
+
+ // Using the properties listed at the following web page (accessed 06/21/08):
+ // http://www.numbertheory.org/php/euclid.html
+ // (especially the properties numbered 3, 4 and 9) it can be proved that
+ // BitWidth bits suffice for all the computations in the algorithm implemented
+ // below. More precisely, this number of bits suffice if the multiplicative
+ // inverse exists, but may not suffice for the general extended Euclidean
+ // algorithm.
+
+ APInt r[2] = { modulo, *this };
+ APInt t[2] = { APInt(BitWidth, 0), APInt(BitWidth, 1) };
+ APInt q(BitWidth, 0);
+
+ unsigned i;
+ for (i = 0; r[i^1] != 0; i ^= 1) {
+ // An overview of the math without the confusing bit-flipping:
+ // q = r[i-2] / r[i-1]
+ // r[i] = r[i-2] % r[i-1]
+ // t[i] = t[i-2] - t[i-1] * q
+ udivrem(r[i], r[i^1], q, r[i]);
+ t[i] -= t[i^1] * q;
+ }
+
+ // If this APInt and the modulo are not coprime, there is no multiplicative
+ // inverse, so return 0. We check this by looking at the next-to-last
+ // remainder, which is the gcd(*this,modulo) as calculated by the Euclidean
+ // algorithm.
+ if (r[i] != 1)
+ return APInt(BitWidth, 0);
+
+ // The next-to-last t is the multiplicative inverse. However, we are
+ // interested in a positive inverse. Calculate a positive one from a negative
+ // one if necessary. A simple addition of the modulo suffices because
+ // abs(t[i]) is known to be less than *this/2 (see the link above).
+ if (t[i].isNegative())
+ t[i] += modulo;
+
+ return std::move(t[i]);
+}
+
/// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
APInt APInt::multiplicativeInverse() const {
assert((*this)[0] &&
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 76fc26412407e7..23f9ee2d39c441 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -3249,11 +3249,22 @@ TEST(APIntTest, SolveQuadraticEquationWrap) {
}
TEST(APIntTest, MultiplicativeInverseExaustive) {
- for (unsigned BitWidth = 1; BitWidth <= 8; ++BitWidth) {
- for (unsigned Value = 1; Value < (1u << BitWidth); Value += 2) {
- // Multiplicative inverse exists for all odd numbers.
+ for (unsigned BitWidth = 1; BitWidth <= 16; ++BitWidth) {
+ for (unsigned Value = 0; Value < (1u << BitWidth); ++Value) {
APInt V = APInt(BitWidth, Value);
- EXPECT_EQ(V * V.multiplicativeInverse(), 1);
+ APInt MulInv =
+ V.zext(BitWidth + 1)
+ .multiplicativeInverse(APInt::getSignedMinValue(BitWidth + 1))
+ .trunc(BitWidth);
+ APInt One = V * MulInv;
+ if (V[0]) {
+ // Multiplicative inverse exists for all odd numbers.
+ EXPECT_TRUE(One.isOne());
+ EXPECT_TRUE((V * V.multiplicativeInverse()).isOne());
+ } else {
+ // Multiplicative inverse does not exist for even numbers (and 0).
+ EXPECT_TRUE(MulInv.isZero());
+ }
}
}
}
>From 8ff652be8252107744b34808faec1a46fe6d9fef Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 5 Apr 2024 12:37:03 -0700
Subject: [PATCH 2/4] Address review comments: use LLVM naming, split and extend tests
---
llvm/include/llvm/ADT/APInt.h | 4 +-
llvm/lib/Support/APInt.cpp | 45 +++++++++++------------
llvm/unittests/ADT/APIntTest.cpp | 63 +++++++++++++++++++++++++-------
3 files changed, 74 insertions(+), 38 deletions(-)
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index bd1716219ee5fc..fbf5cdefb7f488 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -1740,8 +1740,8 @@ class [[nodiscard]] APInt {
return *this;
}
- /// \returns the multiplicative inverse for a given modulo.
- APInt multiplicativeInverse(const APInt &modulo) const;
+ /// \returns the multiplicative inverse for a given modulus.
+ APInt multiplicativeInverse(const APInt &Modulus) const;
/// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
APInt multiplicativeInverse() const;
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index f8f699f8f6ccd7..fcfc0fba01ec25 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -1240,15 +1240,15 @@ APInt APInt::sqrt() const {
return x_old + 1;
}
-/// Computes the multiplicative inverse of this APInt for a given modulo. The
-/// iterative extended Euclidean algorithm is used to solve for this value,
+/// Computes the multiplicative inverse of this APInt for a given modulus,
+/// or returns 0 if no multiplicative inverse exists.
+///
+/// The iterative extended Euclidean algorithm is used to solve for this value,
/// however we simplify it to speed up calculating only the inverse, and take
/// advantage of div+rem calculations. We also use some tricks to avoid copying
/// (potentially large) APInts around.
-/// WARNING: a value of '0' may be returned,
-/// signifying that no multiplicative inverse exists!
-APInt APInt::multiplicativeInverse(const APInt& modulo) const {
- assert(ult(modulo) && "This APInt must be smaller than the modulo");
+APInt APInt::multiplicativeInverse(const APInt& Modulus) const {
+ assert(ult(Modulus) && "This APInt must be smaller than the modulus");
// Using the properties listed at the following web page (accessed 06/21/08):
// http://www.numbertheory.org/php/euclid.html
@@ -1257,36 +1257,35 @@ APInt APInt::multiplicativeInverse(const APInt& modulo) const {
// below. More precisely, this number of bits suffice if the multiplicative
// inverse exists, but may not suffice for the general extended Euclidean
// algorithm.
-
- APInt r[2] = { modulo, *this };
- APInt t[2] = { APInt(BitWidth, 0), APInt(BitWidth, 1) };
- APInt q(BitWidth, 0);
+ APInt R[2] = { Modulus, *this };
+ APInt T[2] = { APInt(BitWidth, 0), APInt(BitWidth, 1) };
+ APInt Q(BitWidth, 0);
unsigned i;
- for (i = 0; r[i^1] != 0; i ^= 1) {
+ for (i = 0; R[i^1] != 0; i ^= 1) {
// An overview of the math without the confusing bit-flipping:
- // q = r[i-2] / r[i-1]
- // r[i] = r[i-2] % r[i-1]
- // t[i] = t[i-2] - t[i-1] * q
- udivrem(r[i], r[i^1], q, r[i]);
- t[i] -= t[i^1] * q;
+ // Q = R[i-2] / R[i-1]
+ // R[i] = R[i-2] % R[i-1]
+ // T[i] = T[i-2] - T[i-1] * Q
+ udivrem(R[i], R[i^1], Q, R[i]);
+ T[i] -= T[i^1] * Q;
}
- // If this APInt and the modulo are not coprime, there is no multiplicative
+ // If this APInt and the modulus are not coprime, there is no multiplicative
// inverse, so return 0. We check this by looking at the next-to-last
- // remainder, which is the gcd(*this,modulo) as calculated by the Euclidean
+ // remainder, which is the gcd(*this, modulus) as calculated by the Euclidean
// algorithm.
- if (r[i] != 1)
+ if (R[i] != 1)
return APInt(BitWidth, 0);
// The next-to-last t is the multiplicative inverse. However, we are
// interested in a positive inverse. Calculate a positive one from a negative
- // one if necessary. A simple addition of the modulo suffices because
+ // one if necessary. A simple addition of the modulus suffices because
// abs(t[i]) is known to be less than *this/2 (see the link above).
- if (t[i].isNegative())
- t[i] += modulo;
+ if (T[i].isNegative())
+ T[i] += Modulus;
- return std::move(t[i]);
+ return std::move(T[i]);
}
/// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 23f9ee2d39c441..e15d3b9cdf3167 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -3248,21 +3248,58 @@ TEST(APIntTest, SolveQuadraticEquationWrap) {
Iterate(i);
}
-TEST(APIntTest, MultiplicativeInverseExaustive) {
- for (unsigned BitWidth = 1; BitWidth <= 16; ++BitWidth) {
- for (unsigned Value = 0; Value < (1u << BitWidth); ++Value) {
+TEST(APIntTest, MultiplicativeInverseExaustivePowerOfTwo) {
+ for (unsigned BitWidth = 1; BitWidth <= 8; ++BitWidth) {
+ for (unsigned Value = 1; Value < (1u << BitWidth); Value += 2) {
+ // Multiplicative inverse exists for all odd numbers.
APInt V = APInt(BitWidth, Value);
- APInt MulInv =
- V.zext(BitWidth + 1)
- .multiplicativeInverse(APInt::getSignedMinValue(BitWidth + 1))
- .trunc(BitWidth);
- APInt One = V * MulInv;
- if (V[0]) {
- // Multiplicative inverse exists for all odd numbers.
- EXPECT_TRUE(One.isOne());
- EXPECT_TRUE((V * V.multiplicativeInverse()).isOne());
+ EXPECT_EQ(V * V.multiplicativeInverse(), 1);
+ }
+ }
+}
+
+TEST(APIntTest, ModularMultiplicativeInverseSpecific) {
+ // Test a single modulus for all known inverses and non-inverses.
+ int BitWidth = 8;
+ APInt Modulus(BitWidth, 26);
+ int Values[12] = {1, 3, 5, 7, 9, 11, 15, 17, 19, 21, 23, 25};
+ int Inverses[12] = {1, 9, 21, 15, 3, 19, 7, 23, 11, 5, 17, 25};
+ int NonInvertibleElements[14] = {0, 2, 4, 6, 8, 10, 12,
+ 13, 14, 16, 18, 20, 22, 24};
+
+ for (size_t i = 0; i < 12; ++i) {
+ APInt V(BitWidth, Values[i]);
+ APInt Inv = V.multiplicativeInverse(Modulus);
+ EXPECT_EQ(Inv, Inverses[i]);
+ }
+
+ for (size_t i = 0; i < 14; ++i) {
+ APInt V(BitWidth, NonInvertibleElements[i]);
+ APInt Inv = V.multiplicativeInverse(Modulus);
+ EXPECT_EQ(Inv, 0);
+ }
+}
+
+TEST(APIntTest, ModularMultiplicativeInverseExaustive) {
+ // Test all moduli and all values up to 8 bits using a gcd test to determine
+ // if a multiplicative inverse exists.
+ int BitWidth = 8;
+ for (unsigned Modulus = 2; Modulus < (1u << BitWidth); ++Modulus) {
+ for (unsigned Value = 0; Value < Modulus; ++Value) {
+ APInt M(BitWidth, Modulus);
+ APInt V(BitWidth, Value);
+ EXPECT_TRUE(V.ult(M))
+ << "Expected " << V << " ult " << M << ", but it was not";
+ APInt MulInv = V.multiplicativeInverse(M);
+ if (APIntOps::GreatestCommonDivisor(V, M).isOne()) {
+ EXPECT_FALSE(MulInv.isZero());
+ // Multiplication verification must take place in a larger bit width
+ APInt Actual = (V.zext(2 * BitWidth) * MulInv.zext(2 * BitWidth))
+ .urem(M.zext(2 * BitWidth));
+ EXPECT_TRUE(Actual.isOne())
+ << "Expected " << V << " * " << MulInv << " = 1 mod " << M
+ << ", but it was " << Actual;
} else {
- // Multiplicative inverse does not exist for even numbers (and 0).
EXPECT_TRUE(MulInv.isZero());
}
}
>From f0cd987e0b24520becc2b7a95391741b1d9ef0d6 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 5 Apr 2024 13:16:16 -0700
Subject: [PATCH 3/4] clang-format
---
llvm/lib/Support/APInt.cpp | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index fcfc0fba01ec25..635bf8793ec230 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -1247,7 +1247,7 @@ APInt APInt::sqrt() const {
/// however we simplify it to speed up calculating only the inverse, and take
/// advantage of div+rem calculations. We also use some tricks to avoid copying
/// (potentially large) APInts around.
-APInt APInt::multiplicativeInverse(const APInt& Modulus) const {
+APInt APInt::multiplicativeInverse(const APInt &Modulus) const {
assert(ult(Modulus) && "This APInt must be smaller than the modulus");
// Using the properties listed at the following web page (accessed 06/21/08):
@@ -1257,18 +1257,18 @@ APInt APInt::multiplicativeInverse(const APInt& Modulus) const {
// below. More precisely, this number of bits suffice if the multiplicative
// inverse exists, but may not suffice for the general extended Euclidean
// algorithm.
- APInt R[2] = { Modulus, *this };
- APInt T[2] = { APInt(BitWidth, 0), APInt(BitWidth, 1) };
+ APInt R[2] = {Modulus, *this};
+ APInt T[2] = {APInt(BitWidth, 0), APInt(BitWidth, 1)};
APInt Q(BitWidth, 0);
unsigned i;
- for (i = 0; R[i^1] != 0; i ^= 1) {
+ for (i = 0; R[i ^ 1] != 0; i ^= 1) {
// An overview of the math without the confusing bit-flipping:
// Q = R[i-2] / R[i-1]
// R[i] = R[i-2] % R[i-1]
// T[i] = T[i-2] - T[i-1] * Q
- udivrem(R[i], R[i^1], Q, R[i]);
- T[i] -= T[i^1] * Q;
+ udivrem(R[i], R[i ^ 1], Q, R[i]);
+ T[i] -= T[i ^ 1] * Q;
}
// If this APInt and the modulus are not coprime, there is no multiplicative
>From ea08b21f795ab795acf87e546586c6750babc4da Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sun, 7 Apr 2024 16:18:04 -0700
Subject: [PATCH 4/4] Address more review comments: simplify tests, cover modulus 1
---
llvm/unittests/ADT/APIntTest.cpp | 24 ++++++++++++------------
1 file changed, 12 insertions(+), 12 deletions(-)
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index e15d3b9cdf3167..a366ca3601fe68 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -3267,16 +3267,14 @@ TEST(APIntTest, ModularMultiplicativeInverseSpecific) {
int NonInvertibleElements[14] = {0, 2, 4, 6, 8, 10, 12,
13, 14, 16, 18, 20, 22, 24};
- for (size_t i = 0; i < 12; ++i) {
- APInt V(BitWidth, Values[i]);
- APInt Inv = V.multiplicativeInverse(Modulus);
- EXPECT_EQ(Inv, Inverses[i]);
+ for (auto [Val, ExpectedInv] : zip(Values, Inverses)) {
+ APInt V(BitWidth, Val);
+ EXPECT_EQ(V.multiplicativeInverse(Modulus), ExpectedInv);
}
- for (size_t i = 0; i < 14; ++i) {
- APInt V(BitWidth, NonInvertibleElements[i]);
- APInt Inv = V.multiplicativeInverse(Modulus);
- EXPECT_EQ(Inv, 0);
+ for (auto Val : NonInvertibleElements) {
+ APInt V(BitWidth, Val);
+ EXPECT_EQ(V.multiplicativeInverse(Modulus), 0);
}
}
@@ -3284,15 +3282,17 @@ TEST(APIntTest, ModularMultiplicativeInverseExaustive) {
// Test all moduli and all values up to 8 bits using a gcd test to determine
// if a multiplicative inverse exists.
int BitWidth = 8;
+
+ APInt M(BitWidth, 1);
+ APInt Z(BitWidth, 0);
+ EXPECT_TRUE(Z.multiplicativeInverse(M).isZero());
+
for (unsigned Modulus = 2; Modulus < (1u << BitWidth); ++Modulus) {
+ APInt M(BitWidth, Modulus);
for (unsigned Value = 0; Value < Modulus; ++Value) {
- APInt M(BitWidth, Modulus);
APInt V(BitWidth, Value);
- EXPECT_TRUE(V.ult(M))
- << "Expected " << V << " ult " << M << ", but it was not";
APInt MulInv = V.multiplicativeInverse(M);
if (APIntOps::GreatestCommonDivisor(V, M).isOne()) {
- EXPECT_FALSE(MulInv.isZero());
// Multiplication verification must take place in a larger bit width
APInt Actual = (V.zext(2 * BitWidth) * MulInv.zext(2 * BitWidth))
.urem(M.zext(2 * BitWidth));
More information about the llvm-commits mailing list