[libc] [llvm] [libc] Refactor `BigInt` (PR #86137)

Nick Desaulniers via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 27 09:02:58 PDT 2024


================
@@ -191,200 +417,100 @@ struct BigInt {
   LIBC_INLINE constexpr cpp::enable_if_t<
       cpp::is_integral_v<T> && !cpp::is_same_v<T, bool>, T>
   to() const {
+    constexpr size_t T_SIZE = sizeof(T) * CHAR_BIT;
     T lo = static_cast<T>(val[0]);
-
-    constexpr size_t T_BITS = sizeof(T) * CHAR_BIT;
-
-    if constexpr (T_BITS <= WORD_SIZE)
+    if constexpr (T_SIZE <= WORD_SIZE)
       return lo;
-
     constexpr size_t MAX_COUNT =
-        T_BITS > Bits ? WORD_COUNT : T_BITS / WORD_SIZE;
+        T_SIZE > Bits ? WORD_COUNT : T_SIZE / WORD_SIZE;
     for (size_t i = 1; i < MAX_COUNT; ++i)
       lo += static_cast<T>(val[i]) << (WORD_SIZE * i);
-
-    if constexpr (Signed && (T_BITS > Bits)) {
+    if constexpr (Signed && (T_SIZE > Bits)) {
       // Extend sign for negative numbers.
       constexpr T MASK = (~T(0) << Bits);
       if (is_neg())
         lo |= MASK;
     }
-
     return lo;
   }
 
   LIBC_INLINE constexpr explicit operator bool() const { return !is_zero(); }
 
-  LIBC_INLINE constexpr BigInt &operator=(const BigInt &other) = default;
-
   LIBC_INLINE constexpr bool is_zero() const {
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      if (val[i] != 0)
+    for (auto part : val)
+      if (part != 0)
         return false;
-    }
     return true;
   }
 
-  // Add x to this number and store the result in this number.
+  // Add 'rhs' to this number and store the result in this number.
   // Returns the carry value produced by the addition operation.
-  LIBC_INLINE constexpr WordType add(const BigInt &x) {
-    SumCarry<WordType> s{0, 0};
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      s = add_with_carry(val[i], x.val[i], s.carry);
-      val[i] = s.sum;
-    }
-    return s.carry;
+  LIBC_INLINE constexpr WordType add_overflow(const BigInt &rhs) {
+    return multiword::add_with_carry(val, rhs.val);
   }
 
   LIBC_INLINE constexpr BigInt operator+(const BigInt &other) const {
-    BigInt result;
-    SumCarry<WordType> s{0, 0};
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      s = add_with_carry(val[i], other.val[i], s.carry);
-      result.val[i] = s.sum;
-    }
+    BigInt result = *this;
+    result.add_overflow(other);
     return result;
   }
 
   // This will only apply when initializing a variable from constant values, so
   // it will always use the constexpr version of add_with_carry.
   LIBC_INLINE constexpr BigInt operator+(BigInt &&other) const {
-    BigInt result;
-    SumCarry<WordType> s{0, 0};
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      s = add_with_carry(val[i], other.val[i], s.carry);
-      result.val[i] = s.sum;
-    }
-    return result;
+    // We use addition commutativity to reuse 'other' and prevent allocation.
+    other.add_overflow(*this); // Returned carry value is ignored.
+    return other;
   }
 
   LIBC_INLINE constexpr BigInt &operator+=(const BigInt &other) {
-    add(other); // Returned carry value is ignored.
+    add_overflow(other); // Returned carry value is ignored.
     return *this;
   }
 
-  // Subtract x to this number and store the result in this number.
+  // Subtract 'rhs' to this number and store the result in this number.
   // Returns the carry value produced by the subtraction operation.
-  LIBC_INLINE constexpr WordType sub(const BigInt &x) {
-    DiffBorrow<WordType> d{0, 0};
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      d = sub_with_borrow(val[i], x.val[i], d.borrow);
-      val[i] = d.diff;
-    }
-    return d.borrow;
+  LIBC_INLINE constexpr WordType sub_overflow(const BigInt &rhs) {
+    return multiword::sub_with_borrow(val, rhs.val);
   }
 
   LIBC_INLINE constexpr BigInt operator-(const BigInt &other) const {
-    BigInt result;
-    DiffBorrow<WordType> d{0, 0};
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      d = sub_with_borrow(val[i], other.val[i], d.borrow);
-      result.val[i] = d.diff;
-    }
+    BigInt result = *this;
+    result.sub_overflow(other); // Returned carry value is ignored.
     return result;
   }
 
   LIBC_INLINE constexpr BigInt operator-(BigInt &&other) const {
-    BigInt result;
-    DiffBorrow<WordType> d{0, 0};
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      d = sub_with_borrow(val[i], other.val[i], d.borrow);
-      result.val[i] = d.diff;
-    }
+    BigInt result = *this;
+    result.sub_overflow(other); // Returned carry value is ignored.
     return result;
   }
 
   LIBC_INLINE constexpr BigInt &operator-=(const BigInt &other) {
     // TODO(lntue): Set overflow flag / errno when carry is true.
-    sub(other);
+    sub_overflow(other); // Returned carry value is ignored.
     return *this;
   }
 
-  // Multiply this number with x and store the result in this number. It is
-  // implemented using the long multiplication algorithm by splitting the
-  // 64-bit words of this number and |x| in to 32-bit halves but peforming
-  // the operations using 64-bit numbers. This ensures that we don't lose the
-  // carry bits.
-  // Returns the carry value produced by the multiplication operation.
+  // Multiply this number with x and store the result in this number.
   LIBC_INLINE constexpr WordType mul(WordType x) {
-    BigInt<2 * WORD_SIZE, Signed, WordType> partial_sum(0);
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      NumberPair<WordType> prod = internal::full_mul(val[i], x);
-      BigInt<2 * WORD_SIZE, Signed, WordType> tmp({prod.lo, prod.hi});
-      const WordType carry = partial_sum.add(tmp);
-      val[i] = partial_sum.val[0];
-      partial_sum.val[0] = partial_sum.val[1];
-      partial_sum.val[1] = carry;
-    }
-    return partial_sum.val[1];
+    return multiword::scalar_multiply_with_carry(val, x);
   }
 
-  LIBC_INLINE constexpr BigInt operator*(const BigInt &other) const {
-    if constexpr (Signed) {
-      BigInt<Bits, false, WordType> a(*this);
-      BigInt<Bits, false, WordType> b(other);
-      const bool a_neg = a.is_neg();
-      const bool b_neg = b.is_neg();
-      if (a_neg)
-        a = -a;
-      if (b_neg)
-        b = -b;
-      BigInt<Bits, false, WordType> prod = a * b;
-      if (a_neg != b_neg)
-        prod = -prod;
-      return static_cast<BigInt<Bits, true, WordType>>(prod);
-    } else {
-      if constexpr (WORD_COUNT == 1) {
-        return {val[0] * other.val[0]};
-      } else {
-        BigInt result(0);
-        BigInt<2 * WORD_SIZE, Signed, WordType> partial_sum(0);
-        WordType carry = 0;
-        for (size_t i = 0; i < WORD_COUNT; ++i) {
-          for (size_t j = 0; j <= i; j++) {
-            NumberPair<WordType> prod =
-                internal::full_mul(val[j], other.val[i - j]);
-            BigInt<2 * WORD_SIZE, Signed, WordType> tmp({prod.lo, prod.hi});
-            carry += partial_sum.add(tmp);
-          }
-          result.val[i] = partial_sum.val[0];
-          partial_sum.val[0] = partial_sum.val[1];
-          partial_sum.val[1] = carry;
-          carry = 0;
-        }
-        return result;
-      }
-    }
-  }
-
-  // Return the full product, only unsigned for now.
+  // Return the full product.
   template <size_t OtherBits>
-  LIBC_INLINE constexpr BigInt<Bits + OtherBits, Signed, WordType>
+  LIBC_INLINE constexpr auto
   ful_mul(const BigInt<OtherBits, Signed, WordType> &other) const {
-    BigInt<Bits + OtherBits, Signed, WordType> result(0);
-    BigInt<2 * WORD_SIZE, Signed, WordType> partial_sum(0);
-    WordType carry = 0;
-    constexpr size_t OTHER_WORDCOUNT =
-        BigInt<OtherBits, Signed, WordType>::WORD_COUNT;
-    for (size_t i = 0; i <= WORD_COUNT + OTHER_WORDCOUNT - 2; ++i) {
-      const size_t lower_idx =
-          i < OTHER_WORDCOUNT ? 0 : i - OTHER_WORDCOUNT + 1;
-      const size_t upper_idx = i < WORD_COUNT ? i : WORD_COUNT - 1;
-      for (size_t j = lower_idx; j <= upper_idx; ++j) {
-        NumberPair<WordType> prod =
-            internal::full_mul(val[j], other.val[i - j]);
-        BigInt<2 * WORD_SIZE, Signed, WordType> tmp({prod.lo, prod.hi});
-        carry += partial_sum.add(tmp);
-      }
-      result.val[i] = partial_sum.val[0];
-      partial_sum.val[0] = partial_sum.val[1];
-      partial_sum.val[1] = carry;
-      carry = 0;
-    }
-    result.val[WORD_COUNT + OTHER_WORDCOUNT - 1] = partial_sum.val[0];
+    BigInt<Bits + OtherBits, Signed, WordType> result;
----------------
nickdesaulniers wrote:

Perhaps spelling out the return type with `decltype` here, instead of a bare `auto`, would be more readable?

https://github.com/llvm/llvm-project/pull/86137


More information about the llvm-commits mailing list