[libc-commits] [libc] [libc] Allow BigInt class to use base word types other than uint64_t. (PR #81634)

via libc-commits libc-commits at lists.llvm.org
Tue Feb 13 12:31:17 PST 2024


https://github.com/lntue updated https://github.com/llvm/llvm-project/pull/81634

>From 94d7336b398d24c7d492dbf02f6a437bacf1fb09 Mon Sep 17 00:00:00 2001
From: Tue Ly <lntue at google.com>
Date: Tue, 13 Feb 2024 12:04:13 -0500
Subject: [PATCH 1/2] [libc] Allow BigInt class to use base word types other
 than uint64_t.

---
 libc/src/__support/FPUtil/dyadic_float.h |   6 +-
 libc/src/__support/UInt.h                | 733 ++++++++++++-----------
 libc/src/__support/float_to_string.h     |  22 +-
 libc/test/src/__support/uint_test.cpp    |  17 +-
 4 files changed, 420 insertions(+), 358 deletions(-)

diff --git a/libc/src/__support/FPUtil/dyadic_float.h b/libc/src/__support/FPUtil/dyadic_float.h
index 888d7ffec241ea..a8b3ad7a16d3bb 100644
--- a/libc/src/__support/FPUtil/dyadic_float.h
+++ b/libc/src/__support/FPUtil/dyadic_float.h
@@ -216,7 +216,7 @@ constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
     if (result.mantissa.add(b.mantissa)) {
       // Mantissa addition overflow.
       result.shift_right(1);
-      result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORDCOUNT - 1] |=
+      result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] |=
           (uint64_t(1) << 63);
     }
     // Result is already normalized.
@@ -243,7 +243,7 @@ constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
 //   result.mantissa = quick_mul_hi(a.mantissa + b.mantissa)
 //                   ~ (full product a.mantissa * b.mantissa) >> Bits.
 // The errors compared to the mathematical product is bounded by:
-//   2 * errors of quick_mul_hi = 2 * (UInt<Bits>::WORDCOUNT - 1) in ULPs.
+//   2 * errors of quick_mul_hi = 2 * (UInt<Bits>::WORD_COUNT - 1) in ULPs.
 // Assume inputs are normalized (by constructors or other functions) so that we
 // don't need to normalize the inputs again in this function.  If the inputs are
 // not normalized, the results might lose precision significantly.
@@ -258,7 +258,7 @@ constexpr DyadicFloat<Bits> quick_mul(DyadicFloat<Bits> a,
     result.mantissa = a.mantissa.quick_mul_hi(b.mantissa);
     // Check the leading bit directly, should be faster than using clz in
     // normalize().
-    if (result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORDCOUNT - 1] >>
+    if (result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] >>
             63 ==
         0)
       result.shift_left(1);
diff --git a/libc/src/__support/UInt.h b/libc/src/__support/UInt.h
index 7726b6d88f0d21..5a60ea0e6d8135 100644
--- a/libc/src/__support/UInt.h
+++ b/libc/src/__support/UInt.h
@@ -25,35 +25,30 @@
 
 namespace LIBC_NAMESPACE::cpp {
 
-template <size_t Bits, bool Signed> struct BigInt {
+template <size_t Bits, bool Signed, typename WordType = uint64_t>
+struct BigInt {
+  static_assert(is_integral_v<WordType> && is_unsigned_v<WordType>,
+                "WordType must be unsigned integer.");
 
-  // This being hardcoded as 64 is okay because we're using uint64_t as our
-  // internal type which will always be 64 bits.
-  using word_type = uint64_t;
-  LIBC_INLINE_VAR static constexpr size_t WORD_SIZE =
-      sizeof(word_type) * CHAR_BIT;
+  LIBC_INLINE_VAR
+  static constexpr size_t WORD_SIZE = sizeof(WordType) * CHAR_BIT;
 
-  // TODO: Replace references to 64 with WORD_SIZE, and uint64_t with word_type.
-  static_assert(Bits > 0 && Bits % 64 == 0,
-                "Number of bits in BigInt should be a multiple of 64.");
-  LIBC_INLINE_VAR static constexpr size_t WORDCOUNT = Bits / 64;
-  cpp::array<word_type, WORDCOUNT> val{};
+  static_assert(Bits > 0 && Bits % WORD_SIZE == 0,
+                "Number of bits in BigInt should be a multiple of WORD_SIZE.");
 
-  LIBC_INLINE_VAR static constexpr uint64_t MASK32 = 0xFFFFFFFFu;
-
-  LIBC_INLINE static constexpr uint64_t low(uint64_t v) { return v & MASK32; }
-  LIBC_INLINE static constexpr uint64_t high(uint64_t v) {
-    return (v >> 32) & MASK32;
-  }
+  LIBC_INLINE_VAR static constexpr size_t WORD_COUNT = Bits / WORD_SIZE;
+  cpp::array<WordType, WORD_COUNT> val{};
 
   LIBC_INLINE constexpr BigInt() = default;
 
-  LIBC_INLINE constexpr BigInt(const BigInt<Bits, Signed> &other) = default;
+  LIBC_INLINE constexpr BigInt(const BigInt<Bits, Signed, WordType> &other) =
+      default;
 
   template <size_t OtherBits, bool OtherSigned>
-  LIBC_INLINE constexpr BigInt(const BigInt<OtherBits, OtherSigned> &other) {
+  LIBC_INLINE constexpr BigInt(
+      const BigInt<OtherBits, OtherSigned, WordType> &other) {
     if (OtherBits >= Bits) {
-      for (size_t i = 0; i < WORDCOUNT; ++i)
+      for (size_t i = 0; i < WORD_COUNT; ++i)
         val[i] = other[i];
     } else {
       size_t i = 0;
@@ -64,49 +59,57 @@ template <size_t Bits, bool Signed> struct BigInt {
         sign = static_cast<uint64_t>(
             -static_cast<int64_t>(other[OtherBits / 64 - 1] >> 63));
       }
-      for (; i < WORDCOUNT; ++i)
+      for (; i < WORD_COUNT; ++i)
         val[i] = sign;
     }
   }
 
   // Construct a BigInt from a C array.
-  template <size_t N, enable_if_t<N <= WORDCOUNT, int> = 0>
-  LIBC_INLINE constexpr BigInt(const uint64_t (&nums)[N]) {
-    size_t min_wordcount = N < WORDCOUNT ? N : WORDCOUNT;
+  template <size_t N, enable_if_t<N <= WORD_COUNT, int> = 0>
+  LIBC_INLINE constexpr BigInt(const WordType (&nums)[N]) {
+    size_t min_wordcount = N < WORD_COUNT ? N : WORD_COUNT;
     size_t i = 0;
     for (; i < min_wordcount; ++i)
       val[i] = nums[i];
 
     // If nums doesn't completely fill val, then fill the rest with zeroes.
-    for (; i < WORDCOUNT; ++i)
+    for (; i < WORD_COUNT; ++i)
       val[i] = 0;
   }
 
-  // Initialize the first word to |v| and the rest to 0.
-  template <typename T,
-            typename = cpp::enable_if_t<is_integral_v<T> && sizeof(T) <= 16>>
+  // Copy |v| into the low words; the rest are ~0 if Signed && v < 0, else 0.
+  template <typename T, typename = cpp::enable_if_t<is_integral_v<T>>>
   LIBC_INLINE constexpr BigInt(T v) {
-    val[0] = static_cast<uint64_t>(v);
+    val[0] = static_cast<WordType>(v);
 
-    if constexpr (Bits == 64)
+    if constexpr (WORD_COUNT == 1)
       return;
 
-    // Bits is at least 128.
-    size_t i = 1;
-    if constexpr (sizeof(T) == 16) {
-      val[1] = static_cast<uint64_t>(v >> 64);
-      i = 2;
+    if constexpr (Bits < sizeof(T) * CHAR_BIT) {
+      for (size_t i = 1; i < WORD_COUNT; ++i) {
+        v >>= WORD_SIZE;
+        val[i] = static_cast<WordType>(v);
+      }
+      return;
     }
 
-    uint64_t sign = (Signed && (v < 0)) ? 0xffff'ffff'ffff'ffff : 0;
-    for (; i < WORDCOUNT; ++i) {
+    size_t i = 1;
+
+    if constexpr (WORD_SIZE < sizeof(T) * CHAR_BIT)
+      for (; i < sizeof(T) * CHAR_BIT / WORD_SIZE; ++i) {
+        v >>= WORD_SIZE;
+        val[i] = static_cast<WordType>(v);
+      }
+
+    WordType sign = (Signed && (v < 0)) ? ~WordType(0) : WordType(0);
+    for (; i < WORD_COUNT; ++i) {
       val[i] = sign;
     }
   }
 
   LIBC_INLINE constexpr explicit BigInt(
-      const cpp::array<uint64_t, WORDCOUNT> &words) {
-    for (size_t i = 0; i < WORDCOUNT; ++i)
+      const cpp::array<WordType, WORD_COUNT> &words) {
+    for (size_t i = 0; i < WORD_COUNT; ++i)
       val[i] = words[i];
   }
 
@@ -116,36 +119,37 @@ template <size_t Bits, bool Signed> struct BigInt {
 
   template <typename T>
   LIBC_INLINE constexpr cpp::enable_if_t<
-      cpp::is_integral_v<T> && sizeof(T) <= 8 && !cpp::is_same_v<T, bool>, T>
+      cpp::is_integral_v<T> && !cpp::is_same_v<T, bool>, T>
   to() const {
-    return static_cast<T>(val[0]);
-  }
-  template <typename T>
-  LIBC_INLINE constexpr cpp::enable_if_t<
-      cpp::is_integral_v<T> && sizeof(T) == 16, T>
-  to() const {
-    // T is 128-bit.
     T lo = static_cast<T>(val[0]);
 
-    if constexpr (Bits == 64) {
-      if constexpr (Signed) {
-        // Extend sign for negative numbers.
-        return (val[0] >> 63) ? ((T(-1) << 64) + lo) : lo;
-      } else {
-        return lo;
-      }
-    } else {
-      return static_cast<T>((static_cast<T>(val[1]) << 64) + lo);
+    constexpr size_t T_BITS = sizeof(T) * CHAR_BIT;
+
+    if constexpr (T_BITS <= WORD_SIZE)
+      return lo;
+
+    constexpr size_t MAX_COUNT =
+        T_BITS > Bits ? WORD_COUNT : T_BITS / WORD_SIZE;
+    for (size_t i = 1; i < MAX_COUNT; ++i)
+      lo += static_cast<T>(val[i]) << (WORD_SIZE * i);
+
+    if constexpr (Signed && (T_BITS > Bits)) {
+      // Extend sign; -(1 << Bits) avoids left-shifting a negative value (UB).
+      constexpr T MASK = static_cast<T>(-(T(1) << Bits));
+      if (val[WORD_COUNT - 1] >> (WORD_SIZE - 1))
+        lo |= MASK;
     }
+
+    return lo;
   }
 
   LIBC_INLINE constexpr explicit operator bool() const { return !is_zero(); }
 
-  LIBC_INLINE BigInt<Bits, Signed> &
-  operator=(const BigInt<Bits, Signed> &other) = default;
+  LIBC_INLINE BigInt<Bits, Signed, WordType> &
+  operator=(const BigInt<Bits, Signed, WordType> &other) = default;
 
   LIBC_INLINE constexpr bool is_zero() const {
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       if (val[i] != 0)
         return false;
     }
@@ -154,20 +158,20 @@ template <size_t Bits, bool Signed> struct BigInt {
 
   // Add x to this number and store the result in this number.
   // Returns the carry value produced by the addition operation.
-  LIBC_INLINE constexpr uint64_t add(const BigInt<Bits, Signed> &x) {
-    SumCarry<uint64_t> s{0, 0};
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  LIBC_INLINE constexpr WordType add(const BigInt<Bits, Signed, WordType> &x) {
+    SumCarry<WordType> s{0, 0};
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       s = add_with_carry_const(val[i], x.val[i], s.carry);
       val[i] = s.sum;
     }
     return s.carry;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator+(const BigInt<Bits, Signed> &other) const {
-    BigInt<Bits, Signed> result;
-    SumCarry<uint64_t> s{0, 0};
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator+(const BigInt<Bits, Signed, WordType> &other) const {
+    BigInt<Bits, Signed, WordType> result;
+    SumCarry<WordType> s{0, 0};
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       s = add_with_carry(val[i], other.val[i], s.carry);
       result.val[i] = s.sum;
     }
@@ -176,58 +180,58 @@ template <size_t Bits, bool Signed> struct BigInt {
 
   // This will only apply when initializing a variable from constant values, so
   // it will always use the constexpr version of add_with_carry.
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator+(BigInt<Bits, Signed> &&other) const {
-    BigInt<Bits, Signed> result;
-    SumCarry<uint64_t> s{0, 0};
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator+(BigInt<Bits, Signed, WordType> &&other) const {
+    BigInt<Bits, Signed, WordType> result;
+    SumCarry<WordType> s{0, 0};
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       s = add_with_carry_const(val[i], other.val[i], s.carry);
       result.val[i] = s.sum;
     }
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &
-  operator+=(const BigInt<Bits, Signed> &other) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &
+  operator+=(const BigInt<Bits, Signed, WordType> &other) {
     add(other); // Returned carry value is ignored.
     return *this;
   }
 
   // Subtract x to this number and store the result in this number.
   // Returns the carry value produced by the subtraction operation.
-  LIBC_INLINE constexpr uint64_t sub(const BigInt<Bits, Signed> &x) {
-    DiffBorrow<uint64_t> d{0, 0};
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  LIBC_INLINE constexpr WordType sub(const BigInt<Bits, Signed, WordType> &x) {
+    DiffBorrow<WordType> d{0, 0};
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       d = sub_with_borrow_const(val[i], x.val[i], d.borrow);
       val[i] = d.diff;
     }
     return d.borrow;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator-(const BigInt<Bits, Signed> &other) const {
-    BigInt<Bits, Signed> result;
-    DiffBorrow<uint64_t> d{0, 0};
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator-(const BigInt<Bits, Signed, WordType> &other) const {
+    BigInt<Bits, Signed, WordType> result;
+    DiffBorrow<WordType> d{0, 0};
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       d = sub_with_borrow(val[i], other.val[i], d.borrow);
       result.val[i] = d.diff;
     }
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator-(BigInt<Bits, Signed> &&other) const {
-    BigInt<Bits, Signed> result;
-    DiffBorrow<uint64_t> d{0, 0};
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator-(BigInt<Bits, Signed, WordType> &&other) const {
+    BigInt<Bits, Signed, WordType> result;
+    DiffBorrow<WordType> d{0, 0};
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       d = sub_with_borrow_const(val[i], other.val[i], d.borrow);
       result.val[i] = d.diff;
     }
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &
-  operator-=(const BigInt<Bits, Signed> &other) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &
+  operator-=(const BigInt<Bits, Signed, WordType> &other) {
     // TODO(lntue): Set overflow flag / errno when carry is true.
     sub(other);
     return *this;
@@ -239,12 +243,12 @@ template <size_t Bits, bool Signed> struct BigInt {
   // the operations using 64-bit numbers. This ensures that we don't lose the
   // carry bits.
   // Returns the carry value produced by the multiplication operation.
-  LIBC_INLINE constexpr uint64_t mul(uint64_t x) {
-    BigInt<128, Signed> partial_sum(0);
-    uint64_t carry = 0;
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
-      NumberPair<uint64_t> prod = full_mul(val[i], x);
-      BigInt<128, Signed> tmp({prod.lo, prod.hi});
+  LIBC_INLINE constexpr WordType mul(WordType x) {
+    BigInt<2 * WORD_SIZE, Signed, WordType> partial_sum(0);
+    WordType carry = 0;
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
+      NumberPair<WordType> prod = full_mul(val[i], x);
+      BigInt<2 * WORD_SIZE, Signed, WordType> tmp({prod.lo, prod.hi});
       carry += partial_sum.add(tmp);
       val[i] = partial_sum.val[0];
       partial_sum.val[0] = partial_sum.val[1];
@@ -254,33 +258,33 @@ template <size_t Bits, bool Signed> struct BigInt {
     return partial_sum.val[1];
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator*(const BigInt<Bits, Signed> &other) const {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator*(const BigInt<Bits, Signed, WordType> &other) const {
     if constexpr (Signed) {
-      BigInt<Bits, false> a(*this);
-      BigInt<Bits, false> b(other);
-      bool a_neg = (a.val[WORDCOUNT - 1] >> 63);
-      bool b_neg = (b.val[WORDCOUNT - 1] >> 63);
+      BigInt<Bits, false, WordType> a(*this);
+      BigInt<Bits, false, WordType> b(other);
+      bool a_neg = (a.val[WORD_COUNT - 1] >> (WORD_SIZE - 1));
+      bool b_neg = (b.val[WORD_COUNT - 1] >> (WORD_SIZE - 1));
       if (a_neg)
         a = -a;
       if (b_neg)
         b = -b;
-      BigInt<Bits, false> prod = a * b;
+      BigInt<Bits, false, WordType> prod = a * b;
       if (a_neg != b_neg)
         prod = -prod;
-      return static_cast<BigInt<Bits, true>>(prod);
+      return static_cast<BigInt<Bits, true, WordType>>(prod);
     } else {
 
-      if constexpr (WORDCOUNT == 1) {
+      if constexpr (WORD_COUNT == 1) {
         return {val[0] * other.val[0]};
       } else {
-        BigInt<Bits, Signed> result(0);
-        BigInt<128, Signed> partial_sum(0);
-        uint64_t carry = 0;
-        for (size_t i = 0; i < WORDCOUNT; ++i) {
+        BigInt<Bits, Signed, WordType> result(0);
+        BigInt<2 * WORD_SIZE, Signed, WordType> partial_sum(0);
+        WordType carry = 0;
+        for (size_t i = 0; i < WORD_COUNT; ++i) {
           for (size_t j = 0; j <= i; j++) {
-            NumberPair<uint64_t> prod = full_mul(val[j], other.val[i - j]);
-            BigInt<128, Signed> tmp({prod.lo, prod.hi});
+            NumberPair<WordType> prod = full_mul(val[j], other.val[i - j]);
+            BigInt<2 * WORD_SIZE, Signed, WordType> tmp({prod.lo, prod.hi});
             carry += partial_sum.add(tmp);
           }
           result.val[i] = partial_sum.val[0];
@@ -295,19 +299,20 @@ template <size_t Bits, bool Signed> struct BigInt {
 
   // Return the full product, only unsigned for now.
   template <size_t OtherBits>
-  LIBC_INLINE constexpr BigInt<Bits + OtherBits, Signed>
-  ful_mul(const BigInt<OtherBits, Signed> &other) const {
-    BigInt<Bits + OtherBits, Signed> result(0);
-    BigInt<128, Signed> partial_sum(0);
-    uint64_t carry = 0;
-    constexpr size_t OTHER_WORDCOUNT = BigInt<OtherBits, Signed>::WORDCOUNT;
-    for (size_t i = 0; i <= WORDCOUNT + OTHER_WORDCOUNT - 2; ++i) {
+  LIBC_INLINE constexpr BigInt<Bits + OtherBits, Signed, WordType>
+  ful_mul(const BigInt<OtherBits, Signed, WordType> &other) const {
+    BigInt<Bits + OtherBits, Signed, WordType> result(0);
+    BigInt<2 * WORD_SIZE, Signed, WordType> partial_sum(0);
+    WordType carry = 0;
+    constexpr size_t OTHER_WORDCOUNT =
+        BigInt<OtherBits, Signed, WordType>::WORD_COUNT;
+    for (size_t i = 0; i <= WORD_COUNT + OTHER_WORDCOUNT - 2; ++i) {
       const size_t lower_idx =
           i < OTHER_WORDCOUNT ? 0 : i - OTHER_WORDCOUNT + 1;
-      const size_t upper_idx = i < WORDCOUNT ? i : WORDCOUNT - 1;
+      const size_t upper_idx = i < WORD_COUNT ? i : WORD_COUNT - 1;
       for (size_t j = lower_idx; j <= upper_idx; ++j) {
-        NumberPair<uint64_t> prod = full_mul(val[j], other.val[i - j]);
-        BigInt<128, Signed> tmp({prod.lo, prod.hi});
+        NumberPair<WordType> prod = full_mul(val[j], other.val[i - j]);
+        BigInt<2 * WORD_SIZE, Signed, WordType> tmp({prod.lo, prod.hi});
         carry += partial_sum.add(tmp);
       }
       result.val[i] = partial_sum.val[0];
@@ -315,7 +320,7 @@ template <size_t Bits, bool Signed> struct BigInt {
       partial_sum.val[1] = carry;
       carry = 0;
     }
-    result.val[WORDCOUNT + OTHER_WORDCOUNT - 1] = partial_sum.val[0];
+    result.val[WORD_COUNT + OTHER_WORDCOUNT - 1] = partial_sum.val[0];
     return result;
   }
 
@@ -323,7 +328,7 @@ template <size_t Bits, bool Signed> struct BigInt {
   // `Bits` least significant bits of the full product, while this function will
   // approximate `Bits` most significant bits of the full product with errors
   // bounded by:
-  //   0 <= (a.full_mul(b) >> Bits) - a.quick_mul_hi(b)) <= WORDCOUNT - 1.
+  //   0 <= (a.ful_mul(b) >> Bits) - a.quick_mul_hi(b) <= WORD_COUNT - 1.
   //
   // An example usage of this is to quickly (but less accurately) compute the
   // product of (normalized) mantissas of floating point numbers:
@@ -335,44 +340,44 @@ template <size_t Bits, bool Signed> struct BigInt {
   //
   // Performance summary:
   //   Number of 64-bit x 64-bit -> 128-bit multiplications performed.
-  //   Bits  WORDCOUNT  ful_mul  quick_mul_hi  Error bound
+  //   Bits  WORD_COUNT  ful_mul  quick_mul_hi  Error bound
   //    128      2         4           3            1
   //    196      3         9           6            2
   //    256      4        16          10            3
   //    512      8        64          36            7
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  quick_mul_hi(const BigInt<Bits, Signed> &other) const {
-    BigInt<Bits, Signed> result(0);
-    BigInt<128, Signed> partial_sum(0);
-    uint64_t carry = 0;
-    // First round of accumulation for those at WORDCOUNT - 1 in the full
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  quick_mul_hi(const BigInt<Bits, Signed, WordType> &other) const {
+    BigInt<Bits, Signed, WordType> result(0);
+    BigInt<2 * WORD_SIZE, Signed, WordType> partial_sum(0);
+    WordType carry = 0;
+    // First round of accumulation for those at WORD_COUNT - 1 in the full
     // product.
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
-      NumberPair<uint64_t> prod =
-          full_mul(val[i], other.val[WORDCOUNT - 1 - i]);
-      BigInt<128, Signed> tmp({prod.lo, prod.hi});
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
+      NumberPair<WordType> prod =
+          full_mul(val[i], other.val[WORD_COUNT - 1 - i]);
+      BigInt<2 * WORD_SIZE, Signed, WordType> tmp({prod.lo, prod.hi});
       carry += partial_sum.add(tmp);
     }
-    for (size_t i = WORDCOUNT; i < 2 * WORDCOUNT - 1; ++i) {
+    for (size_t i = WORD_COUNT; i < 2 * WORD_COUNT - 1; ++i) {
       partial_sum.val[0] = partial_sum.val[1];
       partial_sum.val[1] = carry;
       carry = 0;
-      for (size_t j = i - WORDCOUNT + 1; j < WORDCOUNT; ++j) {
-        NumberPair<uint64_t> prod = full_mul(val[j], other.val[i - j]);
-        BigInt<128, Signed> tmp({prod.lo, prod.hi});
+      for (size_t j = i - WORD_COUNT + 1; j < WORD_COUNT; ++j) {
+        NumberPair<WordType> prod = full_mul(val[j], other.val[i - j]);
+        BigInt<2 * WORD_SIZE, Signed, WordType> tmp({prod.lo, prod.hi});
         carry += partial_sum.add(tmp);
       }
-      result.val[i - WORDCOUNT] = partial_sum.val[0];
+      result.val[i - WORD_COUNT] = partial_sum.val[0];
     }
-    result.val[WORDCOUNT - 1] = partial_sum.val[1];
+    result.val[WORD_COUNT - 1] = partial_sum.val[1];
     return result;
   }
 
   // pow takes a power and sets this to its starting value to that power. Zero
   // to the zeroth power returns 1.
   LIBC_INLINE constexpr void pow_n(uint64_t power) {
-    BigInt<Bits, Signed> result = 1;
-    BigInt<Bits, Signed> cur_power = *this;
+    BigInt<Bits, Signed, WordType> result = 1;
+    BigInt<Bits, Signed, WordType> cur_power = *this;
 
     while (power > 0) {
       if ((power % 2) > 0) {
@@ -388,12 +393,12 @@ template <size_t Bits, bool Signed> struct BigInt {
 
   // div takes another BigInt of the same size and divides this by it. The value
   // of this will be set to the quotient, and the return value is the remainder.
-  LIBC_INLINE constexpr optional<BigInt<Bits, Signed>>
-  div(const BigInt<Bits, Signed> &other) {
-    BigInt<Bits, Signed> remainder(0);
+  LIBC_INLINE constexpr optional<BigInt<Bits, Signed, WordType>>
+  div(const BigInt<Bits, Signed, WordType> &other) {
+    BigInt<Bits, Signed, WordType> remainder(0);
     if (*this < other) {
       remainder = *this;
-      *this = BigInt<Bits, Signed>(0);
+      *this = BigInt<Bits, Signed, WordType>(0);
       return remainder;
     }
     if (other == 1) {
@@ -403,15 +408,15 @@ template <size_t Bits, bool Signed> struct BigInt {
       return nullopt;
     }
 
-    BigInt<Bits, Signed> quotient(0);
-    BigInt<Bits, Signed> subtractor = other;
+    BigInt<Bits, Signed, WordType> quotient(0);
+    BigInt<Bits, Signed, WordType> subtractor = other;
     int cur_bit = static_cast<int>(subtractor.clz() - this->clz());
     subtractor.shift_left(cur_bit);
 
     for (; cur_bit >= 0 && *this > 0; --cur_bit, subtractor.shift_right(1)) {
       if (*this >= subtractor) {
         this->sub(subtractor);
-        quotient = quotient | (BigInt<Bits, Signed>(1) << cur_bit);
+        quotient = quotient | (BigInt<Bits, Signed, WordType>(1) << cur_bit);
       }
     }
     remainder = *this;
@@ -419,8 +424,8 @@ template <size_t Bits, bool Signed> struct BigInt {
     return remainder;
   }
 
-  // Efficiently perform BigInt / (x * 2^e), where x is a 32-bit unsigned
-  // integer, and return the remainder. The main idea is as follow:
+  // Efficiently perform BigInt / (x * 2^e), where x is a half-word-size
+  // unsigned integer, and return the remainder. The main idea is as follows:
   //   Let q = y / (x * 2^e) be the quotient, and
   //       r = y % (x * 2^e) be the remainder.
   //   First, notice that:
@@ -428,102 +433,114 @@ template <size_t Bits, bool Signed> struct BigInt {
   // so we just need to focus on all the bits of y that is >= 2^e.
   //   To speed up the shift-and-add steps, we only use x as the divisor, and
   // performing 32-bit shiftings instead of bit-by-bit shiftings.
-  //   Since the remainder of each division step < x < 2^32, the computation of
-  // each step is now properly contained within uint64_t.
+  //   Since the remainder of each division step < x < 2^(WORD_SIZE / 2), the
+  // computation of each step is now properly contained within WordType.
   //   And finally we perform some extra alignment steps for the remaining bits.
-  LIBC_INLINE constexpr optional<BigInt<Bits, Signed>>
-  div_uint32_times_pow_2(uint32_t x, size_t e) {
-    BigInt<Bits, Signed> remainder(0);
+  // template <typename T>
+  // LIBC_INLINE constexpr cpp::enable_if_t<
+  //     cpp::is_integral_v<T> && cpp::is_unsigned_v<T> &&
+  //         (sizeof(T) * CHAR_BIT == WORD_SIZE / 2),
+  //     optional<BigInt<Bits, Signed, WordType>>>
+  LIBC_INLINE constexpr optional<BigInt<Bits, Signed, WordType>>
+  div_uint_half_times_pow_2(uint32_t x, size_t e) {
+    BigInt<Bits, Signed, WordType> remainder(0);
 
     if (x == 0) {
       return nullopt;
     }
     if (e >= Bits) {
       remainder = *this;
-      *this = BigInt<Bits, false>(0);
+      *this = BigInt<Bits, false, WordType>(0);
       return remainder;
     }
 
-    BigInt<Bits, Signed> quotient(0);
-    uint64_t x64 = static_cast<uint64_t>(x);
-    // lower64 = smallest multiple of 64 that is >= e.
-    size_t lower64 = ((e >> 6) + ((e & 63) != 0)) << 6;
-    // lower_pos is the index of the closest 64-bit chunk >= 2^e.
-    size_t lower_pos = lower64 / 64;
+    BigInt<Bits, Signed, WordType> quotient(0);
+    WordType x64 = static_cast<WordType>(x);
+    constexpr size_t LOG2_WORD_SIZE = bit_width(WORD_SIZE) - 1;
+    constexpr size_t HALF_WORD_SIZE = WORD_SIZE >> 1;
+    constexpr WordType HALF_MASK = ((WordType(1) << HALF_WORD_SIZE) - 1);
+    // lower = smallest multiple of WORD_SIZE that is >= e.
+    size_t lower = ((e >> LOG2_WORD_SIZE) + ((e & (WORD_SIZE - 1)) != 0))
+                   << LOG2_WORD_SIZE;
+    // lower_pos is the index of the closest WORD_SIZE-bit chunk >= 2^e.
+    size_t lower_pos = lower / WORD_SIZE;
     // Keep track of current remainder mod x * 2^(32*i)
     uint64_t rem = 0;
     // pos is the index of the current 64-bit chunk that we are processing.
-    size_t pos = WORDCOUNT;
+    size_t pos = WORD_COUNT;
 
     // TODO: look into if constexpr(Bits > 256) skip leading zeroes.
 
-    for (size_t q_pos = WORDCOUNT - lower_pos; q_pos > 0; --q_pos) {
-      // q_pos is 1 + the index of the current 64-bit chunk of the quotient
-      // being processed.
-      // Performing the division / modulus with divisor:
-      //   x * 2^(64*q_pos - 32),
-      // i.e. using the upper 32-bit of the current 64-bit chunk.
-      rem <<= 32;
-      rem += val[--pos] >> 32;
+    for (size_t q_pos = WORD_COUNT - lower_pos; q_pos > 0; --q_pos) {
+      // q_pos is 1 + the index of the current WORD_SIZE-bit chunk of the
+      // quotient being processed. Performing the division / modulus with
+      // divisor:
+      //   x * 2^(WORD_SIZE*q_pos - WORD_SIZE/2),
+      // i.e. using the upper (WORD_SIZE/2)-bit of the current WORD_SIZE-bit
+      // chunk.
+      rem <<= HALF_WORD_SIZE;
+      rem += val[--pos] >> HALF_WORD_SIZE;
       uint64_t q_tmp = rem / x64;
       rem %= x64;
 
       // Performing the division / modulus with divisor:
-      //   x * 2^(64*(q_pos - 1)),
-      // i.e. using the lower 32-bit of the current 64-bit chunk.
-      rem <<= 32;
-      rem += val[pos] & MASK32;
-      quotient.val[q_pos - 1] = (q_tmp << 32) + rem / x64;
+      //   x * 2^(WORD_SIZE*(q_pos - 1)),
+      // i.e. using the lower (WORD_SIZE/2)-bit of the current WORD_SIZE-bit
+      // chunk.
+      rem <<= HALF_WORD_SIZE;
+      rem += val[pos] & HALF_MASK;
+      quotient.val[q_pos - 1] = (q_tmp << HALF_WORD_SIZE) + rem / x64;
       rem %= x64;
     }
 
     // So far, what we have is:
-    //   quotient = y / (x * 2^lower64), and
-    //        rem = (y % (x * 2^lower64)) / 2^lower64.
-    // If (lower64 > e), we will need to perform an extra adjustment of the
+    //   quotient = y / (x * 2^lower), and
+    //        rem = (y % (x * 2^lower)) / 2^lower.
+    // If (lower > e), we will need to perform an extra adjustment of the
     // quotient and remainder, namely:
-    //   y / (x * 2^e) = [ y / (x * 2^lower64) ] * 2^(lower64 - e) +
-    //                   + (rem * 2^(lower64 - e)) / x
-    //   (y % (x * 2^e)) / 2^e = (rem * 2^(lower64 - e)) % x
-    size_t last_shift = lower64 - e;
+    //   y / (x * 2^e) = [ y / (x * 2^lower) ] * 2^(lower - e) +
+    //                   + (rem * 2^(lower - e)) / x
+    //   (y % (x * 2^e)) / 2^e = (rem * 2^(lower - e)) % x
+    size_t last_shift = lower - e;
 
     if (last_shift > 0) {
-      // quotient * 2^(lower64 - e)
+      // quotient * 2^(lower - e)
       quotient <<= last_shift;
-      uint64_t q_tmp = 0;
-      uint64_t d = val[--pos];
-      if (last_shift >= 32) {
-        // The shifting (rem * 2^(lower64 - e)) might overflow uint64_t, so we
-        // perform a 32-bit shift first.
-        rem <<= 32;
-        rem += d >> 32;
-        d &= MASK32;
+      WordType q_tmp = 0;
+      WordType d = val[--pos];
+      if (last_shift >= HALF_WORD_SIZE) {
+        // The shifting (rem * 2^(lower - e)) might overflow WordType, so we
+        // perform a HALF_WORD_SIZE-bit shift first.
+        rem <<= HALF_WORD_SIZE;
+        rem += d >> HALF_WORD_SIZE;
+        d &= HALF_MASK;
         q_tmp = rem / x64;
         rem %= x64;
-        last_shift -= 32;
+        last_shift -= HALF_WORD_SIZE;
       } else {
-        // Only use the upper 32-bit of the current 64-bit chunk.
-        d >>= 32;
+        // Only use the upper HALF_WORD_SIZE-bit of the current WORD_SIZE-bit
+        // chunk.
+        d >>= HALF_WORD_SIZE;
       }
 
       if (last_shift > 0) {
-        rem <<= 32;
+        rem <<= HALF_WORD_SIZE;
         rem += d;
         q_tmp <<= last_shift;
-        x64 <<= 32 - last_shift;
+        x64 <<= HALF_WORD_SIZE - last_shift;
         q_tmp += rem / x64;
         rem %= x64;
       }
 
       quotient.val[0] += q_tmp;
 
-      if (lower64 - e <= 32) {
-        // The remainder rem * 2^(lower64 - e) might overflow to the higher
-        // 64-bit chunk.
-        if (pos < WORDCOUNT - 1) {
-          remainder[pos + 1] = rem >> 32;
+      if (lower - e <= HALF_WORD_SIZE) {
+        // The remainder rem * 2^(lower - e) might overflow to the higher
+        // WORD_SIZE-bit chunk.
+        if (pos < WORD_COUNT - 1) {
+          remainder[pos + 1] = rem >> HALF_WORD_SIZE;
         }
-        remainder[pos] = (rem << 32) + (val[pos] & MASK32);
+        remainder[pos] = (rem << HALF_WORD_SIZE) + (val[pos] & HALF_MASK);
       } else {
         remainder[pos] = rem;
       }
@@ -541,36 +558,36 @@ template <size_t Bits, bool Signed> struct BigInt {
     return remainder;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator/(const BigInt<Bits, Signed> &other) const {
-    BigInt<Bits, Signed> result(*this);
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator/(const BigInt<Bits, Signed, WordType> &other) const {
+    BigInt<Bits, Signed, WordType> result(*this);
     result.div(other);
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &
-  operator/=(const BigInt<Bits, Signed> &other) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &
+  operator/=(const BigInt<Bits, Signed, WordType> &other) {
     div(other);
     return *this;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator%(const BigInt<Bits, Signed> &other) const {
-    BigInt<Bits, Signed> result(*this);
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator%(const BigInt<Bits, Signed, WordType> &other) const {
+    BigInt<Bits, Signed, WordType> result(*this);
     return *result.div(other);
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &
-  operator*=(const BigInt<Bits, Signed> &other) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &
+  operator*=(const BigInt<Bits, Signed, WordType> &other) {
     *this = *this * other;
     return *this;
   }
 
   LIBC_INLINE constexpr uint64_t clz() {
     uint64_t leading_zeroes = 0;
-    for (size_t i = WORDCOUNT; i > 0; --i) {
+    for (size_t i = WORD_COUNT; i > 0; --i) {
       if (val[i - 1] == 0) {
-        leading_zeroes += sizeof(uint64_t) * 8;
+        leading_zeroes += WORD_SIZE;
       } else {
         leading_zeroes += countl_zero(val[i - 1]);
         break;
@@ -580,8 +597,30 @@ template <size_t Bits, bool Signed> struct BigInt {
   }
 
   LIBC_INLINE constexpr void shift_left(size_t s) {
+    if constexpr (Bits == WORD_SIZE) {
+      // Use native types if possible.
+      if (s >= WORD_SIZE) {
+        val[0] = 0;
+        return;
+      }
+      val[0] <<= s;
+      return;
+    }
+    if constexpr ((Bits == 64) && (WORD_SIZE == 32)) {
+      // Shift via a native uint64_t packed from the two 32-bit words.
+      if (s >= 64) {
+        val[0] = 0;
+        val[1] = 0;
+        return;
+      }
+      uint64_t tmp = uint64_t(val[0]) + (uint64_t(val[1]) << 32);
+      tmp <<= s;
+      val[0] = uint32_t(tmp);
+      val[1] = uint32_t(tmp >> 32);
+      return;
+    }
 #ifdef __SIZEOF_INT128__
-    if constexpr (Bits == 128) {
+    if constexpr ((Bits == 128) && (WORD_SIZE == 64)) {
       // Use builtin 128 bits if available;
       if (s >= 128) {
         val[0] = 0;
@@ -598,19 +637,19 @@ template <size_t Bits, bool Signed> struct BigInt {
     if (LIBC_UNLIKELY(s == 0))
       return;
 
-    const size_t drop = s / 64;  // Number of words to drop
-    const size_t shift = s % 64; // Bits to shift in the remaining words.
-    size_t i = WORDCOUNT;
+    const size_t drop = s / WORD_SIZE;  // Number of words to drop
+    const size_t shift = s % WORD_SIZE; // Bits to shift in the remaining words.
+    size_t i = WORD_COUNT;
 
-    if (drop < WORDCOUNT) {
-      i = WORDCOUNT - 1;
+    if (drop < WORD_COUNT) {
+      i = WORD_COUNT - 1;
       if (shift > 0) {
-        for (size_t j = WORDCOUNT - 1 - drop; j > 0; --i, --j) {
-          val[i] = (val[j] << shift) | (val[j - 1] >> (64 - shift));
+        for (size_t j = WORD_COUNT - 1 - drop; j > 0; --i, --j) {
+          val[i] = (val[j] << shift) | (val[j - 1] >> (WORD_SIZE - shift));
         }
         val[i] = val[0] << shift;
       } else {
-        for (size_t j = WORDCOUNT - 1 - drop; j > 0; --i, --j) {
+        for (size_t j = WORD_COUNT - 1 - drop; j > 0; --i, --j) {
           val[i] = val[j];
         }
         val[i] = val[0];
@@ -622,20 +661,38 @@ template <size_t Bits, bool Signed> struct BigInt {
     }
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> operator<<(size_t s) const {
-    BigInt<Bits, Signed> result(*this);
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator<<(size_t s) const {
+    BigInt<Bits, Signed, WordType> result(*this);
     result.shift_left(s);
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &operator<<=(size_t s) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &operator<<=(size_t s) {
     shift_left(s);
     return *this;
   }
 
   LIBC_INLINE constexpr void shift_right(size_t s) {
+    if constexpr ((Bits == 64) && (WORD_SIZE == 32)) {
+      // Shift via a native 64-bit integer when the base word is 32 bits.
+      if (s >= 64) {
+        val[0] = 0;
+        val[1] = 0;
+        return;
+      }
+      uint64_t tmp = uint64_t(val[0]) + (uint64_t(val[1]) << 32);
+      if constexpr (Signed) {
+        tmp = static_cast<uint64_t>(static_cast<int64_t>(tmp) >> s);
+      } else {
+        tmp >>= s;
+      }
+      val[0] = uint32_t(tmp);
+      val[1] = uint32_t(tmp >> 32);
+      return;
+    }
 #ifdef __SIZEOF_INT128__
-    if constexpr (Bits == 128) {
+    if constexpr ((Bits == 128) && (WORD_SIZE == 64)) {
       // Use builtin 128 bits if available;
       if (s >= 128) {
         val[0] = 0;
@@ -656,108 +713,110 @@ template <size_t Bits, bool Signed> struct BigInt {
 
     if (LIBC_UNLIKELY(s == 0))
       return;
-    const size_t drop = s / 64;  // Number of words to drop
-    const size_t shift = s % 64; // Bit shift in the remaining words.
+    const size_t drop = s / WORD_SIZE;  // Number of words to drop
+    const size_t shift = s % WORD_SIZE; // Bit shift in the remaining words.
 
     size_t i = 0;
-    uint64_t sign = Signed ? (val[WORDCOUNT - 1] >> 63) : 0;
+    uint64_t sign = Signed ? (val[WORD_COUNT - 1] >> (WORD_SIZE - 1)) : 0;
 
-    if (drop < WORDCOUNT) {
+    if (drop < WORD_COUNT) {
       if (shift > 0) {
-        for (size_t j = drop; j < WORDCOUNT - 1; ++i, ++j) {
-          val[i] = (val[j] >> shift) | (val[j + 1] << (64 - shift));
+        for (size_t j = drop; j < WORD_COUNT - 1; ++i, ++j) {
+          val[i] = (val[j] >> shift) | (val[j + 1] << (WORD_SIZE - shift));
         }
         if constexpr (Signed) {
-          val[i] = static_cast<uint64_t>(
-              static_cast<int64_t>(val[WORDCOUNT - 1]) >> shift);
+          val[i] = static_cast<WordType>(
+              static_cast<cpp::make_signed_t<WordType>>(val[WORD_COUNT - 1]) >>
+              shift);
         } else {
-          val[i] = val[WORDCOUNT - 1] >> shift;
+          val[i] = val[WORD_COUNT - 1] >> shift;
         }
         ++i;
       } else {
-        for (size_t j = drop; j < WORDCOUNT; ++i, ++j) {
+        for (size_t j = drop; j < WORD_COUNT; ++i, ++j) {
           val[i] = val[j];
         }
       }
     }
 
-    for (; i < WORDCOUNT; ++i) {
+    for (; i < WORD_COUNT; ++i) {
       val[i] = sign;
     }
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> operator>>(size_t s) const {
-    BigInt<Bits, Signed> result(*this);
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator>>(size_t s) const {
+    BigInt<Bits, Signed, WordType> result(*this);
     result.shift_right(s);
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &operator>>=(size_t s) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &operator>>=(size_t s) {
     shift_right(s);
     return *this;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator&(const BigInt<Bits, Signed> &other) const {
-    BigInt<Bits, Signed> result;
-    for (size_t i = 0; i < WORDCOUNT; ++i)
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator&(const BigInt<Bits, Signed, WordType> &other) const {
+    BigInt<Bits, Signed, WordType> result;
+    for (size_t i = 0; i < WORD_COUNT; ++i)
       result.val[i] = val[i] & other.val[i];
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &
-  operator&=(const BigInt<Bits, Signed> &other) {
-    for (size_t i = 0; i < WORDCOUNT; ++i)
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &
+  operator&=(const BigInt<Bits, Signed, WordType> &other) {
+    for (size_t i = 0; i < WORD_COUNT; ++i)
       val[i] &= other.val[i];
     return *this;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator|(const BigInt<Bits, Signed> &other) const {
-    BigInt<Bits, Signed> result;
-    for (size_t i = 0; i < WORDCOUNT; ++i)
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator|(const BigInt<Bits, Signed, WordType> &other) const {
+    BigInt<Bits, Signed, WordType> result;
+    for (size_t i = 0; i < WORD_COUNT; ++i)
       result.val[i] = val[i] | other.val[i];
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &
-  operator|=(const BigInt<Bits, Signed> &other) {
-    for (size_t i = 0; i < WORDCOUNT; ++i)
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &
+  operator|=(const BigInt<Bits, Signed, WordType> &other) {
+    for (size_t i = 0; i < WORD_COUNT; ++i)
       val[i] |= other.val[i];
     return *this;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator^(const BigInt<Bits, Signed> &other) const {
-    BigInt<Bits, Signed> result;
-    for (size_t i = 0; i < WORDCOUNT; ++i)
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator^(const BigInt<Bits, Signed, WordType> &other) const {
+    BigInt<Bits, Signed, WordType> result;
+    for (size_t i = 0; i < WORD_COUNT; ++i)
       result.val[i] = val[i] ^ other.val[i];
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &
-  operator^=(const BigInt<Bits, Signed> &other) {
-    for (size_t i = 0; i < WORDCOUNT; ++i)
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &
+  operator^=(const BigInt<Bits, Signed, WordType> &other) {
+    for (size_t i = 0; i < WORD_COUNT; ++i)
       val[i] ^= other.val[i];
     return *this;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> operator~() const {
-    BigInt<Bits, Signed> result;
-    for (size_t i = 0; i < WORDCOUNT; ++i)
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> operator~() const {
+    BigInt<Bits, Signed, WordType> result;
+    for (size_t i = 0; i < WORD_COUNT; ++i)
       result.val[i] = ~val[i];
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> operator-() const {
-    BigInt<Bits, Signed> result = ~(*this);
-    result.add(BigInt<Bits, Signed>(1));
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> operator-() const {
+    BigInt<Bits, Signed, WordType> result = ~(*this);
+    result.add(BigInt<Bits, Signed, WordType>(1));
     return result;
   }
 
   LIBC_INLINE constexpr bool
-  operator==(const BigInt<Bits, Signed> &other) const {
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  operator==(const BigInt<Bits, Signed, WordType> &other) const {
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       if (val[i] != other.val[i])
         return false;
     }
@@ -765,8 +824,8 @@ template <size_t Bits, bool Signed> struct BigInt {
   }
 
   LIBC_INLINE constexpr bool
-  operator!=(const BigInt<Bits, Signed> &other) const {
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  operator!=(const BigInt<Bits, Signed, WordType> &other) const {
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       if (val[i] != other.val[i])
         return true;
     }
@@ -774,18 +833,18 @@ template <size_t Bits, bool Signed> struct BigInt {
   }
 
   LIBC_INLINE constexpr bool
-  operator>(const BigInt<Bits, Signed> &other) const {
+  operator>(const BigInt<Bits, Signed, WordType> &other) const {
     if constexpr (Signed) {
       // Check for different signs;
-      bool a_sign = val[WORDCOUNT - 1] >> 63;
-      bool b_sign = other.val[WORDCOUNT - 1] >> 63;
+      bool a_sign = val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
+      bool b_sign = other.val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
       if (a_sign != b_sign) {
         return b_sign;
       }
     }
-    for (size_t i = WORDCOUNT; i > 0; --i) {
-      uint64_t word = val[i - 1];
-      uint64_t other_word = other.val[i - 1];
+    for (size_t i = WORD_COUNT; i > 0; --i) {
+      WordType word = val[i - 1];
+      WordType other_word = other.val[i - 1];
       if (word > other_word)
         return true;
       else if (word < other_word)
@@ -796,18 +855,18 @@ template <size_t Bits, bool Signed> struct BigInt {
   }
 
   LIBC_INLINE constexpr bool
-  operator>=(const BigInt<Bits, Signed> &other) const {
+  operator>=(const BigInt<Bits, Signed, WordType> &other) const {
     if constexpr (Signed) {
       // Check for different signs;
-      bool a_sign = val[WORDCOUNT - 1] >> 63;
-      bool b_sign = other.val[WORDCOUNT - 1] >> 63;
+      bool a_sign = val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
+      bool b_sign = other.val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
       if (a_sign != b_sign) {
         return b_sign;
       }
     }
-    for (size_t i = WORDCOUNT; i > 0; --i) {
-      uint64_t word = val[i - 1];
-      uint64_t other_word = other.val[i - 1];
+    for (size_t i = WORD_COUNT; i > 0; --i) {
+      WordType word = val[i - 1];
+      WordType other_word = other.val[i - 1];
       if (word > other_word)
         return true;
       else if (word < other_word)
@@ -818,19 +877,19 @@ template <size_t Bits, bool Signed> struct BigInt {
   }
 
   LIBC_INLINE constexpr bool
-  operator<(const BigInt<Bits, Signed> &other) const {
+  operator<(const BigInt<Bits, Signed, WordType> &other) const {
     if constexpr (Signed) {
       // Check for different signs;
-      bool a_sign = val[WORDCOUNT - 1] >> 63;
-      bool b_sign = other.val[WORDCOUNT - 1] >> 63;
+      bool a_sign = val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
+      bool b_sign = other.val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
       if (a_sign != b_sign) {
         return a_sign;
       }
     }
 
-    for (size_t i = WORDCOUNT; i > 0; --i) {
-      uint64_t word = val[i - 1];
-      uint64_t other_word = other.val[i - 1];
+    for (size_t i = WORD_COUNT; i > 0; --i) {
+      WordType word = val[i - 1];
+      WordType other_word = other.val[i - 1];
       if (word > other_word)
         return false;
       else if (word < other_word)
@@ -841,18 +900,18 @@ template <size_t Bits, bool Signed> struct BigInt {
   }
 
   LIBC_INLINE constexpr bool
-  operator<=(const BigInt<Bits, Signed> &other) const {
+  operator<=(const BigInt<Bits, Signed, WordType> &other) const {
     if constexpr (Signed) {
       // Check for different signs;
-      bool a_sign = val[WORDCOUNT - 1] >> 63;
-      bool b_sign = other.val[WORDCOUNT - 1] >> 63;
+      bool a_sign = val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
+      bool b_sign = other.val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
       if (a_sign != b_sign) {
         return a_sign;
       }
     }
-    for (size_t i = WORDCOUNT; i > 0; --i) {
-      uint64_t word = val[i - 1];
-      uint64_t other_word = other.val[i - 1];
+    for (size_t i = WORD_COUNT; i > 0; --i) {
+      WordType word = val[i - 1];
+      WordType other_word = other.val[i - 1];
       if (word > other_word)
         return false;
       else if (word < other_word)
@@ -862,48 +921,53 @@ template <size_t Bits, bool Signed> struct BigInt {
     return true;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &operator++() {
-    BigInt<Bits, Signed> one(1);
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &operator++() {
+    BigInt<Bits, Signed, WordType> one(1);
     add(one);
     return *this;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> operator++(int) {
-    BigInt<Bits, Signed> oldval(*this);
-    BigInt<Bits, Signed> one(1);
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> operator++(int) {
+    BigInt<Bits, Signed, WordType> oldval(*this);
+    BigInt<Bits, Signed, WordType> one(1);
     add(one);
     return oldval;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &operator--() {
-    BigInt<Bits, Signed> one(1);
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &operator--() {
+    BigInt<Bits, Signed, WordType> one(1);
     sub(one);
     return *this;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> operator--(int) {
-    BigInt<Bits, Signed> oldval(*this);
-    BigInt<Bits, Signed> one(1);
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> operator--(int) {
+    BigInt<Bits, Signed, WordType> oldval(*this);
+    BigInt<Bits, Signed, WordType> one(1);
     sub(one);
     return oldval;
   }
 
   // Return the i-th 64-bit word of the number.
-  LIBC_INLINE constexpr const uint64_t &operator[](size_t i) const {
+  LIBC_INLINE constexpr const WordType &operator[](size_t i) const {
     return val[i];
   }
 
   // Return the i-th 64-bit word of the number.
-  LIBC_INLINE constexpr uint64_t &operator[](size_t i) { return val[i]; }
+  LIBC_INLINE constexpr WordType &operator[](size_t i) { return val[i]; }
 
-  LIBC_INLINE uint64_t *data() { return val; }
+  LIBC_INLINE WordType *data() { return val; }
 
-  LIBC_INLINE const uint64_t *data() const { return val; }
+  LIBC_INLINE const WordType *data() const { return val; }
 };
 
-template <size_t Bits> using UInt = BigInt<Bits, false>;
+template <size_t Bits>
+using UInt =
+    typename cpp::conditional_t<Bits == 32, BigInt<32, false, uint32_t>,
+                                BigInt<Bits, false, uint64_t>>;
 
-template <size_t Bits> using Int = BigInt<Bits, true>;
+template <size_t Bits>
+using Int = typename cpp::conditional_t<Bits == 32, BigInt<32, true, uint32_t>,
+                                        BigInt<Bits, true, uint64_t>>;
 
 // Provides limits of U/Int<128>.
 template <> class numeric_limits<UInt<128>> {
@@ -927,45 +991,26 @@ template <> class numeric_limits<Int<128>> {
 };
 
 // Provides is_integral of U/Int<128>, U/Int<192>, U/Int<256>.
-template <size_t Bits, bool Signed>
-struct is_integral<BigInt<Bits, Signed>> : cpp::true_type {
-  static_assert(Bits > 0 && Bits % 64 == 0,
-                "Number of bits in BigInt should be a multiple of 64.");
-};
+template <size_t Bits, bool Signed, typename T>
+struct is_integral<BigInt<Bits, Signed, T>> : cpp::true_type {};
 
 // Provides is_unsigned of UInt<128>, UInt<192>, UInt<256>.
-template <size_t Bits> struct is_unsigned<UInt<Bits>> : public cpp::true_type {
-  static_assert(Bits > 0 && Bits % 64 == 0,
-                "Number of bits in UInt should be a multiple of 64.");
-};
-
-template <size_t Bits>
-struct make_unsigned<Int<Bits>> : type_identity<UInt<Bits>> {
-  static_assert(Bits > 0 && Bits % 64 == 0,
-                "Number of bits in Int should be a multiple of 64.");
-};
-
-template <size_t Bits>
-struct make_unsigned<UInt<Bits>> : type_identity<UInt<Bits>> {
-  static_assert(Bits > 0 && Bits % 64 == 0,
-                "Number of bits in Int should be a multiple of 64.");
-};
+template <size_t Bits, bool Signed, typename T>
+struct is_unsigned<BigInt<Bits, Signed, T>> : cpp::bool_constant<!Signed> {};
 
-template <size_t Bits>
-struct make_signed<Int<Bits>> : type_identity<Int<Bits>> {
-  static_assert(Bits > 0 && Bits % 64 == 0,
-                "Number of bits in Int should be a multiple of 64.");
-};
+template <size_t Bits, bool Signed, typename T>
+struct make_unsigned<BigInt<Bits, Signed, T>>
+    : type_identity<BigInt<Bits, false, T>> {};
 
-template <size_t Bits>
-struct make_signed<UInt<Bits>> : type_identity<Int<Bits>> {
-  static_assert(Bits > 0 && Bits % 64 == 0,
-                "Number of bits in Int should be a multiple of 64.");
-};
+template <size_t Bits, bool Signed, typename T>
+struct make_signed<BigInt<Bits, Signed, T>>
+    : type_identity<BigInt<Bits, true, T>> {};
 
 namespace internal {
 template <typename T> struct is_custom_uint : cpp::false_type {};
-template <size_t Bits> struct is_custom_uint<UInt<Bits>> : cpp::true_type {};
+
+template <size_t Bits, bool Signed, typename T>
+struct is_custom_uint<BigInt<Bits, Signed, T>> : cpp::true_type {};
 } // namespace internal
 
 // bit_cast to UInt
diff --git a/libc/src/__support/float_to_string.h b/libc/src/__support/float_to_string.h
index f30110d47b2192..83b68c936b27a9 100644
--- a/libc/src/__support/float_to_string.h
+++ b/libc/src/__support/float_to_string.h
@@ -208,7 +208,7 @@ LIBC_INLINE constexpr cpp::UInt<MID_INT_SIZE> get_table_positive(int exponent,
 
   num = num + 1;
   if (num > MOD_SIZE) {
-    auto rem = num.div_uint32_times_pow_2(
+    auto rem = num.div_uint_half_times_pow_2(
                       EXP10_9, CALC_SHIFT_CONST + (IDX_SIZE > 1 ? IDX_SIZE : 0))
                    .value();
     num = rem;
@@ -255,8 +255,8 @@ LIBC_INLINE cpp::UInt<MID_INT_SIZE> get_table_positive_df(int exponent,
   if (int_num > MOD_SIZE) {
     auto rem =
         int_num
-            .div_uint32_times_pow_2(EXP10_9, CALC_SHIFT_CONST +
-                                                 (IDX_SIZE > 1 ? IDX_SIZE : 0))
+            .div_uint_half_times_pow_2(
+                EXP10_9, CALC_SHIFT_CONST + (IDX_SIZE > 1 ? IDX_SIZE : 0))
             .value();
     int_num = rem;
   }
@@ -318,7 +318,7 @@ LIBC_INLINE cpp::UInt<MID_INT_SIZE> get_table_negative(int exponent, size_t i) {
     num = num >> (-shift_amount);
   }
   if (num > MOD_SIZE) {
-    auto rem = num.div_uint32_times_pow_2(
+    auto rem = num.div_uint_half_times_pow_2(
                       EXP10_9, CALC_SHIFT_CONST + (IDX_SIZE > 1 ? IDX_SIZE : 0))
                    .value();
     num = rem;
@@ -360,8 +360,8 @@ LIBC_INLINE cpp::UInt<MID_INT_SIZE> get_table_negative_df(int exponent,
   if (int_num > MOD_SIZE) {
     auto rem =
         int_num
-            .div_uint32_times_pow_2(EXP10_9, CALC_SHIFT_CONST +
-                                                 (IDX_SIZE > 1 ? IDX_SIZE : 0))
+            .div_uint_half_times_pow_2(
+                EXP10_9, CALC_SHIFT_CONST + (IDX_SIZE > 1 ? IDX_SIZE : 0))
             .value();
     int_num = rem;
   }
@@ -389,7 +389,8 @@ LIBC_INLINE uint32_t mul_shift_mod_1e9(const FPBits::StorageType mantissa,
                                        const int32_t shift_amount) {
   cpp::UInt<MID_INT_SIZE + FPBits::STORAGE_LEN> val(large);
   val = (val * mantissa) >> shift_amount;
-  return static_cast<uint32_t>(val.div_uint32_times_pow_2(EXP10_9, 0).value());
+  return static_cast<uint32_t>(
+      val.div_uint_half_times_pow_2(static_cast<uint32_t>(EXP10_9), 0).value());
 }
 
 } // namespace internal
@@ -658,7 +659,7 @@ template <> class FloatToString<long double> {
 
   template <size_t Bits>
   LIBC_INLINE static constexpr BlockInt grab_digits(cpp::UInt<Bits> &int_num) {
-    auto wide_result = int_num.div_uint32_times_pow_2(EXP5_9, 9);
+    auto wide_result = int_num.div_uint_half_times_pow_2(EXP5_9, 9);
     // the optional only comes into effect when dividing by 0, which will
     // never happen here. Thus, we just assert that it has value.
     LIBC_ASSERT(wide_result.has_value());
@@ -695,7 +696,8 @@ template <> class FloatToString<long double> {
 
       while (float_as_int > 0) {
         LIBC_ASSERT(int_block_index < static_cast<int>(BLOCK_BUFFER_LEN));
-        block_buffer[int_block_index] = grab_digits(float_as_int);
+        block_buffer[int_block_index] =
+            grab_digits<FLOAT_AS_INT_WIDTH + EXTRA_INT_WIDTH>(float_as_int);
         ++int_block_index;
       }
       block_buffer_valid = int_block_index;
@@ -718,7 +720,7 @@ template <> class FloatToString<long double> {
         size_t positive_int_block_index = 0;
         while (above_decimal_point > 0) {
           block_buffer[positive_int_block_index] =
-              grab_digits(above_decimal_point);
+              grab_digits<EXTRA_INT_WIDTH>(above_decimal_point);
           ++positive_int_block_index;
         }
         block_buffer_valid = positive_int_block_index;
diff --git a/libc/test/src/__support/uint_test.cpp b/libc/test/src/__support/uint_test.cpp
index 0ad72c35645c4b..9d85e341475638 100644
--- a/libc/test/src/__support/uint_test.cpp
+++ b/libc/test/src/__support/uint_test.cpp
@@ -588,7 +588,7 @@ TEST(LlvmLibcUIntClassTest, ConstexprInitTests) {
     d <<= e;                                                                   \
     LL_UInt320 q1 = y / d;                                                     \
     LL_UInt320 r1 = y % d;                                                     \
-    LL_UInt320 r2 = *y.div_uint32_times_pow_2(x, e);                           \
+    LL_UInt320 r2 = *y.div_uint_half_times_pow_2(x, e);                        \
     EXPECT_EQ(q1, y);                                                          \
     EXPECT_EQ(r1, r2);                                                         \
   } while (0)
@@ -678,4 +678,19 @@ TEST(LlvmLibcUIntClassTest, ConstructorFromUInt128Tests) {
 
 #endif // __SIZEOF_INT128__
 
+TEST(LlvmLibcUIntClassTest, OtherWordTypeTests) {
+  using LL_UInt96 = cpp::BigInt<96, false, uint32_t>;
+
+  LL_UInt96 a(1);
+
+  ASSERT_EQ(static_cast<int>(a), 1);
+  a = (a << 32) + 2;
+  ASSERT_EQ(static_cast<int>(a), 2);
+  ASSERT_EQ(static_cast<uint64_t>(a), uint64_t(0x1'0000'0002));
+  a = (a << 32) + 3;
+  ASSERT_EQ(static_cast<int>(a), 3);
+  ASSERT_EQ(static_cast<int>(a >> 32), 2);
+  ASSERT_EQ(static_cast<int>(a >> 64), 1);
+}
+
 } // namespace LIBC_NAMESPACE

>From 4112654c12bbfad40a959d76a17008beeeef0ec0 Mon Sep 17 00:00:00 2001
From: Tue Ly <lntue at google.com>
Date: Tue, 13 Feb 2024 15:30:23 -0500
Subject: [PATCH 2/2] Make full_mul and div_uint_half_times_pow_2 work properly
 for various base types.

---
 libc/src/__support/UInt.h             | 50 ++++++++++++++++-----------
 libc/src/__support/integer_utils.h    | 46 ++++++++++++------------
 libc/test/src/__support/uint_test.cpp | 31 +++++++++++++++++
 3 files changed, 84 insertions(+), 43 deletions(-)

diff --git a/libc/src/__support/UInt.h b/libc/src/__support/UInt.h
index 5a60ea0e6d8135..0828a34ba1a934 100644
--- a/libc/src/__support/UInt.h
+++ b/libc/src/__support/UInt.h
@@ -25,6 +25,19 @@
 
 namespace LIBC_NAMESPACE::cpp {
 
+namespace internal {
+template <typename T> struct half_width;
+
+template <> struct half_width<uint64_t> : type_identity<uint32_t> {};
+template <> struct half_width<uint32_t> : type_identity<uint16_t> {};
+template <> struct half_width<uint16_t> : type_identity<uint8_t> {};
+#ifdef __SIZEOF_INT128__
+template <> struct half_width<__uint128_t> : type_identity<uint64_t> {};
+#endif // __SIZEOF_INT128__
+
+template <typename T> using half_width_t = typename half_width<T>::type;
+} // namespace internal
+
 template <size_t Bits, bool Signed, typename WordType = uint64_t>
 struct BigInt {
   static_assert(is_integral_v<WordType> && is_unsigned_v<WordType>,
@@ -54,10 +67,10 @@ struct BigInt {
       size_t i = 0;
       for (; i < OtherBits / 64; ++i)
         val[i] = other[i];
-      uint64_t sign = 0;
+      WordType sign = 0;
       if constexpr (Signed && OtherSigned) {
-        sign = static_cast<uint64_t>(
-            -static_cast<int64_t>(other[OtherBits / 64 - 1] >> 63));
+        sign = static_cast<WordType>(-static_cast<make_signed_t<WordType>>(
+            other[OtherBits / WORD_SIZE - 1] >> (WORD_SIZE - 1)));
       }
       for (; i < WORD_COUNT; ++i)
         val[i] = sign;
@@ -436,13 +449,8 @@ struct BigInt {
   //   Since the remainder of each division step < x < 2^(WORD_SIZE / 2), the
   // computation of each step is now properly contained within WordType.
   //   And finally we perform some extra alignment steps for the remaining bits.
-  // template <typename T>
-  // LIBC_INLINE constexpr cpp::enable_if_t<
-  //     cpp::is_integral_v<T> && cpp::is_unsigned_v<T> &&
-  //         (sizeof(T) * CHAR_BIT == WORD_SIZE / 2),
-  //     optional<BigInt<Bits, Signed, WordType>>>
   LIBC_INLINE constexpr optional<BigInt<Bits, Signed, WordType>>
-  div_uint_half_times_pow_2(uint32_t x, size_t e) {
+  div_uint_half_times_pow_2(internal::half_width_t<WordType> x, size_t e) {
     BigInt<Bits, Signed, WordType> remainder(0);
 
     if (x == 0) {
@@ -455,7 +463,7 @@ struct BigInt {
     }
 
     BigInt<Bits, Signed, WordType> quotient(0);
-    WordType x64 = static_cast<WordType>(x);
+    WordType x_word = static_cast<WordType>(x);
     constexpr size_t LOG2_WORD_SIZE = bit_width(WORD_SIZE) - 1;
     constexpr size_t HALF_WORD_SIZE = WORD_SIZE >> 1;
     constexpr WordType HALF_MASK = ((WordType(1) << HALF_WORD_SIZE) - 1);
@@ -465,7 +473,7 @@ struct BigInt {
     // lower_pos is the index of the closest WORD_SIZE-bit chunk >= 2^e.
     size_t lower_pos = lower / WORD_SIZE;
     // Keep track of current remainder mod x * 2^(32*i)
-    uint64_t rem = 0;
+    WordType rem = 0;
     // pos is the index of the current 64-bit chunk that we are processing.
     size_t pos = WORD_COUNT;
 
@@ -480,8 +488,8 @@ struct BigInt {
       // chunk.
       rem <<= HALF_WORD_SIZE;
       rem += val[--pos] >> HALF_WORD_SIZE;
-      uint64_t q_tmp = rem / x64;
-      rem %= x64;
+      WordType q_tmp = rem / x_word;
+      rem %= x_word;
 
       // Performing the division / modulus with divisor:
       //   x * 2^(WORD_SIZE*(q_pos - 1)),
@@ -489,8 +497,8 @@ struct BigInt {
       // chunk.
       rem <<= HALF_WORD_SIZE;
       rem += val[pos] & HALF_MASK;
-      quotient.val[q_pos - 1] = (q_tmp << HALF_WORD_SIZE) + rem / x64;
-      rem %= x64;
+      quotient.val[q_pos - 1] = (q_tmp << HALF_WORD_SIZE) + rem / x_word;
+      rem %= x_word;
     }
 
     // So far, what we have is:
@@ -514,8 +522,8 @@ struct BigInt {
         rem <<= HALF_WORD_SIZE;
         rem += d >> HALF_WORD_SIZE;
         d &= HALF_MASK;
-        q_tmp = rem / x64;
-        rem %= x64;
+        q_tmp = rem / x_word;
+        rem %= x_word;
         last_shift -= HALF_WORD_SIZE;
       } else {
         // Only use the upper HALF_WORD_SIZE-bit of the current WORD_SIZE-bit
@@ -527,9 +535,9 @@ struct BigInt {
         rem <<= HALF_WORD_SIZE;
         rem += d;
         q_tmp <<= last_shift;
-        x64 <<= HALF_WORD_SIZE - last_shift;
-        q_tmp += rem / x64;
-        rem %= x64;
+        x_word <<= HALF_WORD_SIZE - last_shift;
+        q_tmp += rem / x_word;
+        rem %= x_word;
       }
 
       quotient.val[0] += q_tmp;
@@ -717,7 +725,7 @@ struct BigInt {
     const size_t shift = s % WORD_SIZE; // Bit shift in the remaining words.
 
     size_t i = 0;
-    uint64_t sign = Signed ? (val[WORD_COUNT - 1] >> (WORD_SIZE - 1)) : 0;
+    WordType sign = (Signed && (val[WORD_COUNT - 1] >> (WORD_SIZE - 1))) ? ~WordType(0) : 0;
 
     if (drop < WORD_COUNT) {
       if (shift > 0) {
diff --git a/libc/src/__support/integer_utils.h b/libc/src/__support/integer_utils.h
index 1d9a134934cc55..dd407f9b2ef9a6 100644
--- a/libc/src/__support/integer_utils.h
+++ b/libc/src/__support/integer_utils.h
@@ -19,7 +19,28 @@
 
 namespace LIBC_NAMESPACE {
 
-template <typename T> NumberPair<T> full_mul(T a, T b);
+template <typename T> NumberPair<T> full_mul(T a, T b) {
+  NumberPair<T> pa = split(a);
+  NumberPair<T> pb = split(b);
+  NumberPair<T> prod;
+
+  prod.lo = pa.lo * pb.lo;                    // exact
+  prod.hi = pa.hi * pb.hi;                    // exact
+  NumberPair<T> lo_hi = split(pa.lo * pb.hi); // exact
+  NumberPair<T> hi_lo = split(pa.hi * pb.lo); // exact
+
+  constexpr size_t HALF_BIT_WIDTH = sizeof(T) * CHAR_BIT / 2;
+
+  auto r1 = add_with_carry(prod.lo, lo_hi.lo << HALF_BIT_WIDTH, T(0));
+  prod.lo = r1.sum;
+  prod.hi = add_with_carry(prod.hi, lo_hi.hi, r1.carry).sum;
+
+  auto r2 = add_with_carry(prod.lo, hi_lo.lo << HALF_BIT_WIDTH, T(0));
+  prod.lo = r2.sum;
+  prod.hi = add_with_carry(prod.hi, hi_lo.hi, r2.carry).sum;
+
+  return prod;
+}
 
 template <>
 LIBC_INLINE NumberPair<uint32_t> full_mul<uint32_t>(uint32_t a, uint32_t b) {
@@ -30,35 +51,16 @@ LIBC_INLINE NumberPair<uint32_t> full_mul<uint32_t>(uint32_t a, uint32_t b) {
   return result;
 }
 
+#ifdef __SIZEOF_INT128__
 template <>
 LIBC_INLINE NumberPair<uint64_t> full_mul<uint64_t>(uint64_t a, uint64_t b) {
-#ifdef __SIZEOF_INT128__
   __uint128_t prod = __uint128_t(a) * __uint128_t(b);
   NumberPair<uint64_t> result;
   result.lo = uint64_t(prod);
   result.hi = uint64_t(prod >> 64);
   return result;
-#else
-  NumberPair<uint64_t> pa = split(a);
-  NumberPair<uint64_t> pb = split(b);
-  NumberPair<uint64_t> prod;
-
-  prod.lo = pa.lo * pb.lo;                           // exact
-  prod.hi = pa.hi * pb.hi;                           // exact
-  NumberPair<uint64_t> lo_hi = split(pa.lo * pb.hi); // exact
-  NumberPair<uint64_t> hi_lo = split(pa.hi * pb.lo); // exact
-
-  auto r1 = add_with_carry(prod.lo, lo_hi.lo << 32, uint64_t(0));
-  prod.lo = r1.sum;
-  prod.hi = add_with_carry(prod.hi, lo_hi.hi, r1.carry).sum;
-
-  auto r2 = add_with_carry(prod.lo, hi_lo.lo << 32, uint64_t(0));
-  prod.lo = r2.sum;
-  prod.hi = add_with_carry(prod.hi, hi_lo.hi, r2.carry).sum;
-
-  return prod;
-#endif // __SIZEOF_INT128__
 }
+#endif // __SIZEOF_INT128__
 
 } // namespace LIBC_NAMESPACE
 
diff --git a/libc/test/src/__support/uint_test.cpp b/libc/test/src/__support/uint_test.cpp
index 9d85e341475638..1a1171b46781e8 100644
--- a/libc/test/src/__support/uint_test.cpp
+++ b/libc/test/src/__support/uint_test.cpp
@@ -676,6 +676,37 @@ TEST(LlvmLibcUIntClassTest, ConstructorFromUInt128Tests) {
   ASSERT_EQ(LL_UInt192(e + f), LL_UInt192(a + b));
 }
 
+TEST(LlvmLibcUIntClassTest, WordTypeUInt128Tests) {
+  using LL_UInt256_128 = cpp::BigInt<256, false, __uint128_t>;
+  using LL_UInt128_128 = cpp::BigInt<128, false, __uint128_t>;
+
+  LL_UInt256_128 a(1);
+
+  ASSERT_EQ(static_cast<int>(a), 1);
+  a = (a << 128) + 2;
+  ASSERT_EQ(static_cast<int>(a), 2);
+  ASSERT_EQ(static_cast<uint64_t>(a), uint64_t(2));
+  a = (a << 32) + 3;
+  ASSERT_EQ(static_cast<int>(a), 3);
+  ASSERT_EQ(static_cast<uint64_t>(a), uint64_t(0x2'0000'0003));
+  ASSERT_EQ(static_cast<int>(a >> 32), 2);
+  ASSERT_EQ(static_cast<int>(a >> (128 + 32)), 1);
+
+  LL_UInt128_128 b(__uint128_t(1) << 127);
+  LL_UInt128_128 c(b);
+  a = b.ful_mul(c);
+
+  ASSERT_EQ(static_cast<int>(a >> 254), 1);
+
+  LL_UInt256_128 d = LL_UInt256_128(123) << 4;
+  ASSERT_EQ(static_cast<int>(d), 123 << 4);
+  LL_UInt256_128 e = a / d;
+  LL_UInt256_128 f = a % d;
+  LL_UInt256_128 r = *a.div_uint_half_times_pow_2(123, 4);
+  EXPECT_TRUE(e == a);
+  EXPECT_TRUE(f == r);
+}
+
 #endif // __SIZEOF_INT128__
 
 TEST(LlvmLibcUIntClassTest, OtherWordTypeTests) {



More information about the libc-commits mailing list