[libc-commits] [libc] [llvm] [libc] Refactor `BigInt` (PR #86137)

Michael Jones via libc-commits libc-commits at lists.llvm.org
Fri Mar 22 11:38:48 PDT 2024


================
@@ -17,79 +17,293 @@
 #include "src/__support/macros/attributes.h"   // LIBC_INLINE
 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
 #include "src/__support/macros/properties/types.h" // LIBC_TYPES_HAS_INT128, LIBC_TYPES_HAS_INT64
-#include "src/__support/math_extras.h" // SumCarry, DiffBorrow
+#include "src/__support/math_extras.h" // add_with_carry, sub_with_borrow
 #include "src/__support/number_pair.h"
 
 #include <stddef.h> // For size_t
 #include <stdint.h>
 
 namespace LIBC_NAMESPACE {
 
-namespace internal {
-template <typename T> struct half_width;
+namespace multiword {
 
-template <> struct half_width<uint64_t> : cpp::type_identity<uint32_t> {};
-template <> struct half_width<uint32_t> : cpp::type_identity<uint16_t> {};
+// A type trait mapping unsigned integers to their half-width unsigned
+// counterparts.
+template <typename T> struct half_width;
 template <> struct half_width<uint16_t> : cpp::type_identity<uint8_t> {};
+template <> struct half_width<uint32_t> : cpp::type_identity<uint16_t> {};
+#ifdef LIBC_TYPES_HAS_INT64
+template <> struct half_width<uint64_t> : cpp::type_identity<uint32_t> {};
 #ifdef LIBC_TYPES_HAS_INT128
 template <> struct half_width<__uint128_t> : cpp::type_identity<uint64_t> {};
 #endif // LIBC_TYPES_HAS_INT128
-
+#endif // LIBC_TYPES_HAS_INT64
 template <typename T> using half_width_t = typename half_width<T>::type;
 
-template <typename T> constexpr NumberPair<T> full_mul(T a, T b) {
-  NumberPair<T> pa = split(a);
-  NumberPair<T> pb = split(b);
-  NumberPair<T> prod;
+// An array of two elements that can be used in multiword operations.
+template <typename T> struct Double final : cpp::array<T, 2> {
+  using UP = cpp::array<T, 2>;
+  using UP::UP;
+  LIBC_INLINE constexpr Double(T lo, T hi) : UP({lo, hi}) {}
+};
 
-  prod.lo = pa.lo * pb.lo;                    // exact
-  prod.hi = pa.hi * pb.hi;                    // exact
-  NumberPair<T> lo_hi = split(pa.lo * pb.hi); // exact
-  NumberPair<T> hi_lo = split(pa.hi * pb.lo); // exact
+// Converts an unsigned value into a Double<half_width_t<T>>.
+template <typename T> LIBC_INLINE constexpr auto split(T value) {
+  static_assert(cpp::is_unsigned_v<T>);
+  return cpp::bit_cast<Double<half_width_t<T>>>(value);
+}
 
-  constexpr size_t HALF_BIT_WIDTH = sizeof(T) * CHAR_BIT / 2;
+// The low part of a Double value.
+template <typename T> LIBC_INLINE constexpr T lo(const Double<T> &value) {
+  return value[0];
+}
+// The high part of a Double value.
+template <typename T> LIBC_INLINE constexpr T hi(const Double<T> &value) {
+  return value[1];
+}
+// The low part of an unsigned value.
+template <typename T> LIBC_INLINE constexpr half_width_t<T> lo(T value) {
+  return lo(split(value));
+}
+// The high part of an unsigned value.
+template <typename T> LIBC_INLINE constexpr half_width_t<T> hi(T value) {
+  return hi(split(value));
+}
 
-  auto r1 = add_with_carry(prod.lo, lo_hi.lo << HALF_BIT_WIDTH, T(0));
-  prod.lo = r1.sum;
-  prod.hi = add_with_carry(prod.hi, lo_hi.hi, r1.carry).sum;
+// Returns 'a' times 'b' in a Double<word>. Cannot overflow by definition.
+template <typename word>
+LIBC_INLINE constexpr Double<word> mul2(word a, word b) {
+  if constexpr (cpp::is_same_v<word, uint8_t>) {
+    return split<uint16_t>(uint16_t(a) * uint16_t(b));
+  } else if constexpr (cpp::is_same_v<word, uint16_t>) {
+    return split<uint32_t>(uint32_t(a) * uint32_t(b));
+  }
+#ifdef LIBC_TYPES_HAS_INT64
+  else if constexpr (cpp::is_same_v<word, uint32_t>) {
+    return split<uint64_t>(uint64_t(a) * uint64_t(b));
+  }
+#endif
+#ifdef LIBC_TYPES_HAS_INT128
+  else if constexpr (cpp::is_same_v<word, uint64_t>) {
+    return split<__uint128_t>(__uint128_t(a) * __uint128_t(b));
+  }
+#endif
+  else {
+    using half_word = half_width_t<word>;
+    const auto shiftl = [](word value) -> word {
+      return value << cpp::numeric_limits<half_word>::digits;
+    };
+    const auto shiftr = [](word value) -> word {
+      return value >> cpp::numeric_limits<half_word>::digits;
+    };
+    // Here we do a one digit multiplication where 'a' and 'b' are of type
+    // word. We split 'a' and 'b' into half words and perform the classic long
+    // multiplication with 'a' and 'b' being two-digit numbers.
+
+    //    a      a_hi a_lo
+    //  x b => x b_hi b_lo
+    // ----    -----------
+    //    c         result
+    // We convert 'lo' and 'hi' from 'half_word' to 'word' so multiplication
+    // doesn't overflow.
+    const word a_lo = lo(a);
+    const word b_lo = lo(b);
+    const word a_hi = hi(a);
+    const word b_hi = hi(b);
+    const word step1 = b_lo * a_lo; // no overflow;
+    const word step2 = b_lo * a_hi; // no overflow;
+    const word step3 = b_hi * a_lo; // no overflow;
+    const word step4 = b_hi * a_hi; // no overflow;
+    word lo_digit = step1;
+    word hi_digit = step4;
+    const word zero_carry = 0;
+    word carry;
+    const auto add_with_carry = LIBC_NAMESPACE::add_with_carry<word>;
+    lo_digit = add_with_carry(lo_digit, shiftl(step2), zero_carry, &carry);
+    hi_digit = add_with_carry(hi_digit, shiftr(step2), carry, nullptr);
+    lo_digit = add_with_carry(lo_digit, shiftl(step3), zero_carry, &carry);
+    hi_digit = add_with_carry(hi_digit, shiftr(step3), carry, nullptr);
+    return Double<word>(lo_digit, hi_digit);
+  }
+}
 
-  auto r2 = add_with_carry(prod.lo, hi_lo.lo << HALF_BIT_WIDTH, T(0));
-  prod.lo = r2.sum;
-  prod.hi = add_with_carry(prod.hi, hi_lo.hi, r2.carry).sum;
+// In-place binary operation with carry propagation. Returns the final carry.
+template <typename Function, typename word, size_t N, size_t M>
+LIBC_INLINE constexpr word inplace_binop(Function op_with_carry,
+                                         cpp::array<word, N> &dst,
+                                         const cpp::array<word, M> &rhs) {
+  static_assert(N >= M);
+  word carry_out = 0;
+  for (size_t i = 0; i < N; ++i) {
+    const bool has_rhs_value = i < M;
+    const word rhs_value = has_rhs_value ? rhs[i] : 0;
+    const word carry_in = carry_out;
+    dst[i] = op_with_carry(dst[i], rhs_value, carry_in, &carry_out);
+    // Stop early once rhs is exhausted and there is no carry left to propagate.
+    if (!has_rhs_value && carry_out == 0)
+      break;
+  }
+  return carry_out;
+}
 
-  return prod;
+// In-place addition. Returns carry.
+template <typename word, size_t N, size_t M>
+LIBC_INLINE constexpr word add_with_carry(cpp::array<word, N> &dst,
+                                          const cpp::array<word, M> &rhs) {
+  return inplace_binop(LIBC_NAMESPACE::add_with_carry<word>, dst, rhs);
 }
 
-template <>
-LIBC_INLINE constexpr NumberPair<uint32_t> full_mul<uint32_t>(uint32_t a,
-                                                              uint32_t b) {
-  uint64_t prod = uint64_t(a) * uint64_t(b);
-  NumberPair<uint32_t> result;
-  result.lo = uint32_t(prod);
-  result.hi = uint32_t(prod >> 32);
-  return result;
+// In-place subtraction. Returns borrow.
+template <typename word, size_t N, size_t M>
+LIBC_INLINE constexpr word sub_with_borrow(cpp::array<word, N> &dst,
+                                           const cpp::array<word, M> &rhs) {
+  return inplace_binop(LIBC_NAMESPACE::sub_with_borrow<word>, dst, rhs);
 }
 
-#ifdef LIBC_TYPES_HAS_INT128
-template <>
-LIBC_INLINE constexpr NumberPair<uint64_t> full_mul<uint64_t>(uint64_t a,
-                                                              uint64_t b) {
-  __uint128_t prod = __uint128_t(a) * __uint128_t(b);
-  NumberPair<uint64_t> result;
-  result.lo = uint64_t(prod);
-  result.hi = uint64_t(prod >> 64);
-  return result;
+// In-place multiply-add. Returns carry.
+// i.e., 'dst += b x c'
+template <typename word, size_t N>
+LIBC_INLINE constexpr word mad_with_carry(cpp::array<word, N> &dst, word b,
----------------
michaelrj-google wrote:

`mad` here is ambiguous (unusual abbreviation, easy to mistake for `mod`). Would you be okay with `mul_add` instead?

https://github.com/llvm/llvm-project/pull/86137


More information about the libc-commits mailing list