[libc-commits] [libc] [llvm] [libc] Refactor `BigInt` (PR #86137)
Nick Desaulniers via libc-commits
libc-commits at lists.llvm.org
Wed Mar 27 09:02:58 PDT 2024
================
@@ -17,79 +17,317 @@
#include "src/__support/macros/attributes.h" // LIBC_INLINE
#include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
#include "src/__support/macros/properties/types.h" // LIBC_TYPES_HAS_INT128, LIBC_TYPES_HAS_INT64
-#include "src/__support/math_extras.h" // SumCarry, DiffBorrow
+#include "src/__support/math_extras.h" // add_with_carry, sub_with_borrow
#include "src/__support/number_pair.h"
#include <stddef.h> // For size_t
#include <stdint.h>
namespace LIBC_NAMESPACE {
-namespace internal {
-template <typename T> struct half_width;
+namespace multiword {
-template <> struct half_width<uint64_t> : cpp::type_identity<uint32_t> {};
-template <> struct half_width<uint32_t> : cpp::type_identity<uint16_t> {};
+// A type trait mapping unsigned integers to their half-width unsigned
+// counterparts.
+// Only the primary template is declared (never defined), so half_width_t<T>
+// is ill-formed for any T without one of the specializations below (e.g.
+// uint8_t has no half-width counterpart here).
+template <typename T> struct half_width;
template <> struct half_width<uint16_t> : cpp::type_identity<uint8_t> {};
+template <> struct half_width<uint32_t> : cpp::type_identity<uint16_t> {};
+// The 64-bit mapping is only provided when LIBC_TYPES_HAS_INT64 is defined,
+// and the 128-bit mapping additionally requires LIBC_TYPES_HAS_INT128 (its
+// guard is nested inside the 64-bit one).
+#ifdef LIBC_TYPES_HAS_INT64
+template <> struct half_width<uint64_t> : cpp::type_identity<uint32_t> {};
#ifdef LIBC_TYPES_HAS_INT128
template <> struct half_width<__uint128_t> : cpp::type_identity<uint64_t> {};
#endif // LIBC_TYPES_HAS_INT128
-
+#endif // LIBC_TYPES_HAS_INT64
template <typename T> using half_width_t = typename half_width<T>::type;
-template <typename T> constexpr NumberPair<T> full_mul(T a, T b) {
- NumberPair<T> pa = split(a);
- NumberPair<T> pb = split(b);
- NumberPair<T> prod;
+// A two-element array used as a double-width value in multiword operations.
+// Element 0 is the low word and element 1 the high word; the lo() and hi()
+// accessors read them back.
+template <typename T> struct Double final : cpp::array<T, 2> {
+  using UP = cpp::array<T, 2>;
+  using UP::UP;
+  // Builds the pair {low, high}.
+  LIBC_INLINE constexpr Double(T low, T high) : UP({low, high}) {}
+};
+
+// Converts an unsigned value into a Double<half_width_t<T>> holding
+// {low half, high half}.
+// The halves are extracted arithmetically (truncation for the low half, a
+// right shift for the high half) rather than by bit_cast-ing the whole value
+// to a two-element array: a bit_cast follows the target's byte order, so
+// element 0 would hold the *high* half on big-endian targets and lo()/hi()
+// would be swapped.
+template <typename T> LIBC_INLINE constexpr auto split(T value) {
+  static_assert(cpp::is_unsigned_v<T>);
+  using half_type = half_width_t<T>;
+  return Double<half_type>(
+      half_type(value),
+      half_type(value >> cpp::numeric_limits<half_type>::digits));
+}
+
+// Accessors for the halves of a Double: element 0 is the low part and
+// element 1 the high part.
+template <typename T> LIBC_INLINE constexpr T lo(const Double<T> &value) {
+  return value.front();
+}
+template <typename T> LIBC_INLINE constexpr T hi(const Double<T> &value) {
+  return value.back();
+}
+// The same accessors for a plain unsigned value, which is first split into a
+// Double of half-width words.
+template <typename T> LIBC_INLINE constexpr half_width_t<T> lo(T value) {
+  return lo(split(value));
+}
+template <typename T> LIBC_INLINE constexpr half_width_t<T> hi(T value) {
+  return hi(split(value));
+}
+
+// Returns the full product 'a * b' as a Double<word> holding
+// {low word, high word}. The product of two words always fits in two words,
+// so this cannot overflow.
+// When a native integer type twice as wide as 'word' is available (subject to
+// LIBC_TYPES_HAS_INT64 / LIBC_TYPES_HAS_INT128), the product is computed in a
+// wider type and split; otherwise the final branch performs a two-digit long
+// multiplication on half words.
+template <typename word>
+LIBC_INLINE constexpr Double<word> mul2(word a, word b) {
+ if constexpr (cpp::is_same_v<word, uint8_t>) {
+ return split<uint16_t>(uint16_t(a) * uint16_t(b));
+ } else if constexpr (cpp::is_same_v<word, uint16_t>) {
+ return split<uint32_t>(uint32_t(a) * uint32_t(b));
+ }
+#ifdef LIBC_TYPES_HAS_INT64
+ else if constexpr (cpp::is_same_v<word, uint32_t>) {
+ return split<uint64_t>(uint64_t(a) * uint64_t(b));
+ }
+#endif
+#ifdef LIBC_TYPES_HAS_INT128
+ else if constexpr (cpp::is_same_v<word, uint64_t>) {
+ return split<__uint128_t>(__uint128_t(a) * __uint128_t(b));
+ }
+#endif
+ else {
+ // Fallback when no double-width native type exists for 'word'.
+ using half_word = half_width_t<word>;
+ // shiftl moves the low half of a word into its high half; shiftr moves the
+ // high half down into the low half.
+ const auto shiftl = [](word value) -> word {
+ return value << cpp::numeric_limits<half_word>::digits;
+ };
+ const auto shiftr = [](word value) -> word {
+ return value >> cpp::numeric_limits<half_word>::digits;
+ };
+ // Here we do a one-digit multiplication where 'a' and 'b' are of type
+ // 'word'. We split 'a' and 'b' into half words and perform the classic long
+ // multiplication with 'a' and 'b' treated as two-digit numbers:
+
+ //       a           a_hi  a_lo
+ //     x b    =>   x b_hi  b_lo
+ //     ---        ------------
+ //       c           result
+ // We convert 'lo' and 'hi' from 'half_word' to 'word' so the partial
+ // products below don't overflow.
+ const word a_lo = lo(a);
+ const word b_lo = lo(b);
+ const word a_hi = hi(a);
+ const word b_hi = hi(b);
+ // Each partial product multiplies two half words, so it fits in a word.
+ const word step1 = b_lo * a_lo; // no overflow;
+ const word step2 = b_lo * a_hi; // no overflow;
+ const word step3 = b_hi * a_lo; // no overflow;
+ const word step4 = b_hi * a_hi; // no overflow;
+ // step1 starts as the low word and step4 as the high word. The cross terms
+ // step2 and step3 straddle the word boundary: their low halves (shifted up)
+ // are added into lo_digit, producing a carry, and their high halves plus
+ // that carry are added into hi_digit. hi_digit cannot overflow because the
+ // full product fits in two words.
+ word lo_digit = step1;
+ word hi_digit = step4;
+ const word no_carry = 0;
+ word carry;
+ lo_digit = add_with_carry<word>(lo_digit, shiftl(step2), no_carry, &carry);
+ hi_digit = add_with_carry<word>(hi_digit, shiftr(step2), carry);
+ lo_digit = add_with_carry<word>(lo_digit, shiftl(step3), no_carry, &carry);
+ hi_digit = add_with_carry<word>(hi_digit, shiftr(step3), carry);
+ return Double<word>(lo_digit, hi_digit);
+ }
+}
+
+// In-place 'dst op= rhs', where 'op_with_carry' is a single-word operation
+// that takes an incoming carry (or borrow) and reports the outgoing one
+// through its last (pointer) argument. 'rhs' may be shorter than 'dst': past
+// its end the operation is applied with a zero operand, stopping as soon as an
+// application produces no carry. Returns the final carry (or borrow).
+template <typename Function, typename word, size_t N, size_t M>
+LIBC_INLINE constexpr word inplace_binop(Function op_with_carry,
+                                         cpp::array<word, N> &dst,
+                                         const cpp::array<word, M> &rhs) {
+  static_assert(N >= M);
+  word carry = 0;
+  size_t idx = 0;
+  // Words covered by both operands.
+  for (; idx < M; ++idx) {
+    const word incoming = carry;
+    dst[idx] = op_with_carry(dst[idx], rhs[idx], incoming, &carry);
+  }
+  // Remaining words of 'dst': only the carry is left to propagate.
+  for (; idx < N; ++idx) {
+    const word incoming = carry;
+    dst[idx] = op_with_carry(dst[idx], word(0), incoming, &carry);
+    if (carry == 0)
+      break;
+  }
+  return carry;
+}
+
+// In-place multiword addition 'dst += rhs'; 'rhs' may have fewer words than
+// 'dst'. Returns the carry out of the most significant word of 'dst'.
+template <typename word, size_t N, size_t M>
+LIBC_INLINE constexpr word add_with_carry(cpp::array<word, N> &dst,
+                                          const cpp::array<word, M> &rhs) {
+  // Single-word primitive from the enclosing namespace (math_extras.h).
+  const auto word_add = LIBC_NAMESPACE::add_with_carry<word>;
+  return inplace_binop(word_add, dst, rhs);
+}
+
+// In-place multiword subtraction 'dst -= rhs'; 'rhs' may have fewer words
+// than 'dst'. Returns the borrow out of the most significant word of 'dst'.
+template <typename word, size_t N, size_t M>
+LIBC_INLINE constexpr word sub_with_borrow(cpp::array<word, N> &dst,
+                                           const cpp::array<word, M> &rhs) {
+  // Single-word primitive from the enclosing namespace (math_extras.h).
+  const auto word_sub = LIBC_NAMESPACE::sub_with_borrow<word>;
+  return inplace_binop(word_sub, dst, rhs);
+}
- prod.lo = pa.lo * pb.lo; // exact
- prod.hi = pa.hi * pb.hi; // exact
- NumberPair<T> lo_hi = split(pa.lo * pb.hi); // exact
- NumberPair<T> hi_lo = split(pa.hi * pb.lo); // exact
+// In-place multiply-add 'dst += b * c': the two-word product of 'b' and 'c' is
+// added into the low words of 'dst' and any carry is propagated upward.
+// Returns the carry out of the most significant word of 'dst'.
+template <typename word, size_t N>
+LIBC_INLINE constexpr word mul_add_with_carry(cpp::array<word, N> &dst, word b,
+                                              word c) {
+  const Double<word> product = mul2(b, c);
+  return add_with_carry(dst, product);
+}
- constexpr size_t HALF_BIT_WIDTH = sizeof(T) * CHAR_BIT / 2;
+// Two-word accumulator used by the multiword multiplication routines.
+// Element 0 is the word currently being accumulated (sum()) and element 1 the
+// word just above it (carry()).
+template <typename T> struct Accumulator final : cpp::array<T, 2> {
+  using UP = cpp::array<T, 2>;
+  LIBC_INLINE constexpr Accumulator() : UP({0, 0}) {}
+  // Moves the accumulator up by one word: returns the current low word, slides
+  // the high word into its place and stores 'carry_in' as the new high word.
+  LIBC_INLINE constexpr T advance(T carry_in) {
+    const T retired = (*this)[0];
+    (*this)[0] = (*this)[1];
+    (*this)[1] = carry_in;
+    return retired;
+  }
+  LIBC_INLINE constexpr T sum() const { return (*this)[0]; }
+  LIBC_INLINE constexpr T carry() const { return (*this)[1]; }
+};
- auto r1 = add_with_carry(prod.lo, lo_hi.lo << HALF_BIT_WIDTH, T(0));
- prod.lo = r1.sum;
- prod.hi = add_with_carry(prod.hi, lo_hi.hi, r1.carry).sum;
+// In-place multiplication of the multiword value 'dst' by the single word 'x'.
+// Returns the carry word, i.e. the part of 'dst * x' that does not fit in
+// 'dst' (the value that would occupy word N of the full product).
+template <typename word, size_t N>
+LIBC_INLINE constexpr word scalar_multiply_with_carry(cpp::array<word, N> &dst,
+                                                      word x) {
+  Accumulator<word> acc;
+  for (auto &val : dst) {
+    // acc += val * x; a carry out of the two-word accumulator becomes its new
+    // high word once the finished low word is retired into 'val'.
+    const word carry = mul_add_with_carry(acc, val, x);
+    val = acc.advance(carry);
+  }
+  // After the final advance() the pending overflow word sits in the LOW slot
+  // of the accumulator. The high slot (acc.carry()) is always 0 here, because
+  // dst * x < 2^(W * (N + 1)); returning it would silently drop the carry.
+  return acc.sum();
+}
- auto r2 = add_with_carry(prod.lo, hi_lo.lo << HALF_BIT_WIDTH, T(0));
- prod.lo = r2.sum;
- prod.hi = add_with_carry(prod.hi, hi_lo.hi, r2.carry).sum;
+// Column-wise (schoolbook) multiplication of 'lhs' by 'rhs' into 'dst'.
+// For each output word i, every partial product lhs[j] * rhs[i - j] is added
+// into a two-word accumulator, and carries out of the accumulator are counted
+// in 'carry'. The finished low word is then stored in dst[i] and the
+// accumulator moves up one word.
+// Returns acc.carry(). Because 'dst' holds at least M + N words
+// (static_assert below), the full product always fits, so the returned value
+// is expected to be 0.
+// This function is safe to use for signed numbers.
+// NOTE(review): this presumably refers to the low min(M, N) words, which agree
+// for signed and unsigned interpretations. Nothing here sign-extends negative
+// operands to O words, so the upper words differ from the signed product.
+// Confirm that callers only rely on the low part.
+// https://stackoverflow.com/a/20793834
+// https://pages.cs.wisc.edu/%7Emarkhill/cs354/Fall2008/beyond354/int.mult.html
+template <typename word, size_t O, size_t M, size_t N>
+LIBC_INLINE constexpr word multiply_with_carry(cpp::array<word, O> &dst,
+ const cpp::array<word, M> &lhs,
+ const cpp::array<word, N> &rhs) {
+ static_assert(O >= M + N);
+ Accumulator<word> acc;
+ for (size_t i = 0; i < O; ++i) {
+ // Clamp j so that both lhs[j] and rhs[i - j] are in range. For
+ // i >= M + N - 1 the range is empty, and only the pending accumulator word
+ // is stored.
+ const size_t lower_idx = i < N ? 0 : i - N + 1;
+ const size_t upper_idx = i < M ? i : M - 1;
+ word carry = 0;
+ for (size_t j = lower_idx; j <= upper_idx; ++j)
+ carry += mul_add_with_carry(acc, lhs[j], rhs[i - j]);
+ dst[i] = acc.advance(carry);
+ }
+ return acc.carry();
+}
- return prod;
+// Computes an approximation of the high half of 'lhs * rhs'. The upper N words
+// of the 2N-word product are written to 'dst', with dst[0] the least
+// significant of them.
+// To save work, the partial products of columns 0 .. N - 2 are never
+// computed. Accumulation starts at column N - 1, whose low word is discarded.
+// Any carries those lower columns would have propagated upward are lost, so
+// the result can be smaller than the exact high half (never larger).
+// NOTE(review): confirm that every caller tolerates this truncation error.
+template <typename word, size_t N>
+LIBC_INLINE constexpr void quick_mul_hi(cpp::array<word, N> &dst,
+ const cpp::array<word, N> &lhs,
+ const cpp::array<word, N> &rhs) {
+ Accumulator<word> acc;
+ word carry = 0;
+ // First round of accumulation for those at N - 1 in the full product.
+ for (size_t i = 0; i < N; ++i)
+ carry += mul_add_with_carry(acc, lhs[i], rhs[N - 1 - i]);
+ for (size_t i = N; i < 2 * N - 1; ++i) {
+ // Retire the previous column's low word. For column N - 1 it is dropped;
+ // later ones were already copied into 'dst'. Its pending carry moves up.
+ acc.advance(carry);
+ carry = 0;
+ // Column i: every lhs[j] * rhs[i - j] with both indices in [0, N).
+ for (size_t j = i - N + 1; j < N; ++j)
+ carry += mul_add_with_carry(acc, lhs[j], rhs[i - j]);
+ dst[i - N] = acc.sum();
+ }
+ // The accumulator's high word is column 2N - 1, the most significant word
+ // of the product.
+ dst.back() = acc.carry();
}
-template <>
-LIBC_INLINE constexpr NumberPair<uint32_t> full_mul<uint32_t>(uint32_t a,
- uint32_t b) {
- uint64_t prod = uint64_t(a) * uint64_t(b);
- NumberPair<uint32_t> result;
- result.lo = uint32_t(prod);
- result.hi = uint32_t(prod >> 32);
- return result;
+// Returns whether the multiword value in 'array' is negative when interpreted
+// as a two's-complement number, i.e. whether the sign bit of its most
+// significant word (array.back()) is set.
+// The array is only read, so it is taken by const reference. This lets callers
+// pass const arrays and temporaries.
+template <typename word, size_t N>
+LIBC_INLINE constexpr bool is_negative(const cpp::array<word, N> &array) {
+  using signed_word = cpp::make_signed_t<word>;
+  return cpp::bit_cast<signed_word>(array.back()) < 0;
}
+// Selects the kind of multiword bit shift performed by shift(): LEFT behaves
+// like '<<' and RIGHT like '>>'.
+enum Direction { LEFT = 0, RIGHT = 1 };
+
+// A bitwise shift on an array of elements.
+// TODO: Make the result UB when 'offset' is greater or equal to the number of
+// bits in 'array'. This will allow for better code performance.
+template <Direction direction, bool is_signed, typename word, size_t N>
+LIBC_INLINE constexpr cpp::array<word, N> shift(cpp::array<word, N> array,
+ size_t offset) {
+ constexpr size_t WORD_BITS = cpp::numeric_limits<word>::digits;
+ constexpr size_t TOTAL_BITS = N * WORD_BITS;
+ if (offset == 0)
+ return array;
+ if (offset >= TOTAL_BITS)
+ return {};
#ifdef LIBC_TYPES_HAS_INT128
-template <>
-LIBC_INLINE constexpr NumberPair<uint64_t> full_mul<uint64_t>(uint64_t a,
- uint64_t b) {
- __uint128_t prod = __uint128_t(a) * __uint128_t(b);
- NumberPair<uint64_t> result;
- result.lo = uint64_t(prod);
- result.hi = uint64_t(prod >> 64);
- return result;
+ if constexpr (TOTAL_BITS == 128) {
+ using type = cpp::conditional_t<is_signed, __int128_t, __uint128_t>;
+ auto tmp = cpp::bit_cast<type>(array);
+ if constexpr (direction == LEFT)
+ tmp <<= offset;
+ else if constexpr (direction == RIGHT)
+ tmp >>= offset;
+ return cpp::bit_cast<cpp::array<word, N>>(tmp);
----------------
nickdesaulniers wrote:
If `direction` can only be `LEFT` or `RIGHT`, consider adding a static assertion to enforce that; then the `else if constexpr (direction == RIGHT)` branch can become a plain `else`.
https://github.com/llvm/llvm-project/pull/86137
More information about the libc-commits
mailing list