[libc-commits] [libc] [libc][NFC] Simplify BigInt (PR #81992)

Guillaume Chatelet via libc-commits libc-commits at lists.llvm.org
Fri Feb 16 05:43:53 PST 2024


https://github.com/gchatelet created https://github.com/llvm/llvm-project/pull/81992

(No description was provided for this pull request.)

>From 5ef44fa2a4ada50e86f175a8c4ad4ddccaaa194f Mon Sep 17 00:00:00 2001
From: Guillaume Chatelet <gchatelet at google.com>
Date: Fri, 16 Feb 2024 13:43:28 +0000
Subject: [PATCH] [libc][NFC] Simplify BigInt

---
 libc/src/__support/CPP/array.h |   4 +-
 libc/src/__support/UInt.h      | 216 +++++++++++----------------------
 2 files changed, 74 insertions(+), 146 deletions(-)

diff --git a/libc/src/__support/CPP/array.h b/libc/src/__support/CPP/array.h
index 1897066514092c..fb5a79225beb7d 100644
--- a/libc/src/__support/CPP/array.h
+++ b/libc/src/__support/CPP/array.h
@@ -28,10 +28,10 @@ template <class T, size_t N> struct array {
   LIBC_INLINE constexpr const T *data() const { return Data; }
 
   LIBC_INLINE constexpr T &front() { return Data[0]; }
-  LIBC_INLINE constexpr T &front() const { return Data[0]; }
+  LIBC_INLINE constexpr const T &front() const { return Data[0]; }
 
   LIBC_INLINE constexpr T &back() { return Data[N - 1]; }
-  LIBC_INLINE constexpr T &back() const { return Data[N - 1]; }
+  LIBC_INLINE constexpr const T &back() const { return Data[N - 1]; }
 
   LIBC_INLINE constexpr T &operator[](size_t Index) { return Data[Index]; }
 
diff --git a/libc/src/__support/UInt.h b/libc/src/__support/UInt.h
index b90275035a23ea..f2084300d14208 100644
--- a/libc/src/__support/UInt.h
+++ b/libc/src/__support/UInt.h
@@ -68,8 +68,8 @@ struct BigInt {
         val[i] = other[i];
       WordType sign = 0;
       if constexpr (Signed && OtherSigned) {
-        sign = static_cast<WordType>(-static_cast<make_signed_t<WordType>>(
-            other[OtherBits / WORD_SIZE - 1] >> (WORD_SIZE - 1)));
+        sign = static_cast<WordType>(
+            -static_cast<make_signed_t<WordType>>(other.is_neg()));
       }
       for (; i < WORD_COUNT; ++i)
         val[i] = sign;
@@ -125,6 +125,11 @@ struct BigInt {
       val[i] = words[i];
   }
 
+  // TODO: Reuse the Sign type.
+  LIBC_INLINE constexpr bool is_neg() const {
+    return val.back() >> (WORD_SIZE - 1);
+  }
+
   template <typename T> LIBC_INLINE constexpr explicit operator T() const {
     return to<T>();
   }
@@ -148,7 +153,7 @@ struct BigInt {
     if constexpr (Signed && (T_BITS > Bits)) {
       // Extend sign for negative numbers.
       constexpr T MASK = (~T(0) << Bits);
-      if (val[WORD_COUNT - 1] >> (WORD_SIZE - 1))
+      if (is_neg())
         lo |= MASK;
     }
 
@@ -267,8 +272,8 @@ struct BigInt {
     if constexpr (Signed) {
       BigInt<Bits, false, WordType> a(*this);
       BigInt<Bits, false, WordType> b(other);
-      bool a_neg = (a.val[WORD_COUNT - 1] >> (WORD_SIZE - 1));
-      bool b_neg = (b.val[WORD_COUNT - 1] >> (WORD_SIZE - 1));
+      const bool a_neg = a.is_neg();
+      const bool b_neg = b.is_neg();
       if (a_neg)
         a = -a;
       if (b_neg)
@@ -278,7 +283,6 @@ struct BigInt {
         prod = -prod;
       return static_cast<BigInt<Bits, true, WordType>>(prod);
     } else {
-
       if constexpr (WORD_COUNT == 1) {
         return {val[0] * other.val[0]};
       } else {
@@ -383,10 +387,9 @@ struct BigInt {
     BigInt cur_power = *this;
 
     while (power > 0) {
-      if ((power % 2) > 0) {
-        result = result * cur_power;
-      }
-      power = power >> 1;
+      if ((power % 2) > 0)
+        result *= cur_power;
+      power >>= 1;
       cur_power *= cur_power;
     }
     *this = result;
@@ -709,7 +712,7 @@ struct BigInt {
     const size_t shift = s % WORD_SIZE; // Bit shift in the remaining words.
 
     size_t i = 0;
-    WordType sign = Signed ? (val[WORD_COUNT - 1] >> (WORD_SIZE - 1)) : 0;
+    WordType sign = Signed ? is_neg() : 0;
 
     if (drop < WORD_COUNT) {
       if (shift > 0) {
@@ -747,49 +750,31 @@ struct BigInt {
     return *this;
   }
 
-  LIBC_INLINE constexpr BigInt operator&(const BigInt &other) const {
-    BigInt result;
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      result.val[i] = val[i] & other.val[i];
-    return result;
+#define DEFINE_BINOP(OP)                                                       \
+  LIBC_INLINE friend constexpr BigInt operator OP(const BigInt &lhs,           \
+                                                  const BigInt &rhs) {         \
+    BigInt result;                                                             \
+    for (size_t i = 0; i < WORD_COUNT; ++i)                                    \
+      result[i] = lhs[i] OP rhs[i];                                            \
+    return result;                                                             \
+  }                                                                            \
+  LIBC_INLINE friend constexpr BigInt operator OP##=(BigInt &lhs,              \
+                                                     const BigInt &rhs) {      \
+    for (size_t i = 0; i < WORD_COUNT; ++i)                                    \
+      lhs[i] OP## = rhs[i];                                                    \
+    return lhs;                                                                \
   }
 
-  LIBC_INLINE constexpr BigInt &operator&=(const BigInt &other) {
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      val[i] &= other.val[i];
-    return *this;
-  }
+  DEFINE_BINOP(&)
+  DEFINE_BINOP(|)
+  DEFINE_BINOP(^)
 
-  LIBC_INLINE constexpr BigInt operator|(const BigInt &other) const {
-    BigInt result;
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      result.val[i] = val[i] | other.val[i];
-    return result;
-  }
-
-  LIBC_INLINE constexpr BigInt &operator|=(const BigInt &other) {
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      val[i] |= other.val[i];
-    return *this;
-  }
-
-  LIBC_INLINE constexpr BigInt operator^(const BigInt &other) const {
-    BigInt result;
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      result.val[i] = val[i] ^ other.val[i];
-    return result;
-  }
-
-  LIBC_INLINE constexpr BigInt &operator^=(const BigInt &other) {
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      val[i] ^= other.val[i];
-    return *this;
-  }
+#undef DEFINE_BINOP
 
   LIBC_INLINE constexpr BigInt operator~() const {
     BigInt result;
     for (size_t i = 0; i < WORD_COUNT; ++i)
-      result.val[i] = ~val[i];
+      result[i] = ~val[i];
     return result;
   }
 
@@ -799,139 +784,82 @@ struct BigInt {
     return result;
   }
 
-  LIBC_INLINE constexpr bool operator==(const BigInt &other) const {
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      if (val[i] != other.val[i])
+  LIBC_INLINE friend constexpr bool operator==(const BigInt &lhs,
+                                               const BigInt &rhs) {
+    for (size_t i = 0; i < WORD_COUNT; ++i)
+      if (lhs.val[i] != rhs.val[i])
         return false;
-    }
     return true;
   }
 
-  LIBC_INLINE constexpr bool operator!=(const BigInt &other) const {
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      if (val[i] != other.val[i])
-        return true;
-    }
-    return false;
+  LIBC_INLINE friend constexpr bool operator!=(const BigInt &lhs,
+                                               const BigInt &rhs) {
+    return !(lhs == rhs);
   }
 
-  LIBC_INLINE constexpr bool operator>(const BigInt &other) const {
+private:
+  LIBC_INLINE friend constexpr int cmp(const BigInt &lhs, const BigInt &rhs) {
+    const auto compare = [](WordType a, WordType b) {
+      return a == b ? 0 : a > b ? 1 : -1;
+    };
     if constexpr (Signed) {
-      // Check for different signs;
-      bool a_sign = val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      bool b_sign = other.val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      if (a_sign != b_sign) {
-        return b_sign;
-      }
-    }
-    for (size_t i = WORD_COUNT; i > 0; --i) {
-      WordType word = val[i - 1];
-      WordType other_word = other.val[i - 1];
-      if (word > other_word)
-        return true;
-      else if (word < other_word)
-        return false;
+      const bool lhs_is_neg = lhs.is_neg();
+      const bool rhs_is_neg = rhs.is_neg();
+      if (lhs_is_neg != rhs_is_neg)
+        return rhs_is_neg ? 1 : -1;
     }
-    // Equal
-    return false;
+    for (size_t i = WORD_COUNT; i-- > 0;)
+      if (auto cmp = compare(lhs[i], rhs[i]); cmp != 0)
+        return cmp;
+    return 0;
   }
 
-  LIBC_INLINE constexpr bool operator>=(const BigInt &other) const {
-    if constexpr (Signed) {
-      // Check for different signs;
-      bool a_sign = val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      bool b_sign = other.val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      if (a_sign != b_sign) {
-        return b_sign;
-      }
-    }
-    for (size_t i = WORD_COUNT; i > 0; --i) {
-      WordType word = val[i - 1];
-      WordType other_word = other.val[i - 1];
-      if (word > other_word)
-        return true;
-      else if (word < other_word)
-        return false;
-    }
-    // Equal
-    return true;
+public:
+  LIBC_INLINE friend constexpr bool operator>(const BigInt &lhs,
+                                              const BigInt &rhs) {
+    return cmp(lhs, rhs) > 0;
   }
-
-  LIBC_INLINE constexpr bool operator<(const BigInt &other) const {
-    if constexpr (Signed) {
-      // Check for different signs;
-      bool a_sign = val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      bool b_sign = other.val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      if (a_sign != b_sign) {
-        return a_sign;
-      }
-    }
-
-    for (size_t i = WORD_COUNT; i > 0; --i) {
-      WordType word = val[i - 1];
-      WordType other_word = other.val[i - 1];
-      if (word > other_word)
-        return false;
-      else if (word < other_word)
-        return true;
-    }
-    // Equal
-    return false;
+  LIBC_INLINE friend constexpr bool operator>=(const BigInt &lhs,
+                                               const BigInt &rhs) {
+    return cmp(lhs, rhs) >= 0;
   }
-
-  LIBC_INLINE constexpr bool operator<=(const BigInt &other) const {
-    if constexpr (Signed) {
-      // Check for different signs;
-      bool a_sign = val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      bool b_sign = other.val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      if (a_sign != b_sign) {
-        return a_sign;
-      }
-    }
-    for (size_t i = WORD_COUNT; i > 0; --i) {
-      WordType word = val[i - 1];
-      WordType other_word = other.val[i - 1];
-      if (word > other_word)
-        return false;
-      else if (word < other_word)
-        return true;
-    }
-    // Equal
-    return true;
+  LIBC_INLINE friend constexpr bool operator<(const BigInt &lhs,
+                                              const BigInt &rhs) {
+    return cmp(lhs, rhs) < 0;
+  }
+  LIBC_INLINE friend constexpr bool operator<=(const BigInt &lhs,
+                                               const BigInt &rhs) {
+    return cmp(lhs, rhs) <= 0;
   }
 
   LIBC_INLINE constexpr BigInt &operator++() {
-    BigInt one(1);
-    add(one);
+    add(BigInt(1));
     return *this;
   }
 
   LIBC_INLINE constexpr BigInt operator++(int) {
     BigInt oldval(*this);
-    BigInt one(1);
-    add(one);
+    add(BigInt(1));
     return oldval;
   }
 
   LIBC_INLINE constexpr BigInt &operator--() {
-    BigInt one(1);
-    sub(one);
+    sub(BigInt(1));
     return *this;
   }
 
   LIBC_INLINE constexpr BigInt operator--(int) {
     BigInt oldval(*this);
-    BigInt one(1);
-    sub(one);
+    sub(BigInt(1));
     return oldval;
   }
 
-  // Return the i-th 64-bit word of the number.
+  // Return the i-th word of the number.
   LIBC_INLINE constexpr const WordType &operator[](size_t i) const {
     return val[i];
   }
 
-  // Return the i-th 64-bit word of the number.
+  // Return the i-th word of the number.
   LIBC_INLINE constexpr WordType &operator[](size_t i) { return val[i]; }
 
   LIBC_INLINE WordType *data() { return val; }



More information about the libc-commits mailing list