[libc-commits] [libc] [libc] Add bigint casting between word types (PR #111914)
Michael Jones via libc-commits
libc-commits at lists.llvm.org
Fri Oct 11 15:42:28 PDT 2024
https://github.com/michaelrj-google updated https://github.com/llvm/llvm-project/pull/111914
>From 432b59ab4b6e74494c7ad934c5fe7f76fd411a74 Mon Sep 17 00:00:00 2001
From: Michael Jones <michaelrj at google.com>
Date: Thu, 10 Oct 2024 13:38:07 -0700
Subject: [PATCH 1/2] [libc] Add bigint casting between word types
Previously you could cast between bigints with different numbers of
bits, but only if they had the same underlying word type. This patch adds
the ability to cast between bigints with different underlying word types,
which is needed for #110894.
---
libc/src/__support/big_int.h | 95 +++++++++++++--
libc/test/src/__support/big_int_test.cpp | 142 ++++++++++++++++++++++-
2 files changed, 225 insertions(+), 12 deletions(-)
diff --git a/libc/src/__support/big_int.h b/libc/src/__support/big_int.h
index 681782d57319e5..9ab50391c7469d 100644
--- a/libc/src/__support/big_int.h
+++ b/libc/src/__support/big_int.h
@@ -14,7 +14,7 @@
#include "src/__support/CPP/limits.h"
#include "src/__support/CPP/optional.h"
#include "src/__support/CPP/type_traits.h"
-#include "src/__support/macros/attributes.h" // LIBC_INLINE
+#include "src/__support/macros/attributes.h" // LIBC_INLINE
#include "src/__support/macros/config.h"
#include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
#include "src/__support/macros/properties/compiler.h" // LIBC_COMPILER_IS_CLANG
@@ -361,17 +361,90 @@ struct BigInt {
LIBC_INLINE constexpr BigInt(const BigInt &other) = default;
- template <size_t OtherBits, bool OtherSigned>
+ template <size_t OtherBits, bool OtherSigned, typename OtherWordType>
LIBC_INLINE constexpr BigInt(
- const BigInt<OtherBits, OtherSigned, WordType> &other) {
- if (OtherBits >= Bits) { // truncate
- for (size_t i = 0; i < WORD_COUNT; ++i)
- val[i] = other[i];
- } else { // zero or sign extend
- size_t i = 0;
- for (; i < OtherBits / WORD_SIZE; ++i)
- val[i] = other[i];
- extend(i, Signed && other.is_neg());
+ const BigInt<OtherBits, OtherSigned, OtherWordType> &other) {
+ using BigIntOther = BigInt<OtherBits, OtherSigned, OtherWordType>;
+ const bool should_sign_extend = Signed && other.is_neg();
+
+ if constexpr (BigIntOther::WORD_SIZE < WORD_SIZE) {
+ // OtherWordType is smaller
+ constexpr size_t WORD_SIZE_RATIO = WORD_SIZE / BigIntOther::WORD_SIZE;
+ static_assert(
+ (WORD_SIZE % BigIntOther::WORD_SIZE) == 0 &&
+ "Word types must be multiples of each other for correct conversion.");
+ if (OtherBits >= Bits) { // truncate
+ // for each big word
+ for (size_t i = 0; i < WORD_COUNT; ++i) {
+ WordType cur_word = 0;
+ // combine WORD_SIZE_RATIO small words into a big word
+ for (size_t j = 0; j < WORD_SIZE_RATIO; ++j)
+ cur_word |= static_cast<WordType>(other[(i * WORD_SIZE_RATIO) + j])
+ << (BigIntOther::WORD_SIZE * j);
+
+ val[i] = cur_word;
+ }
+ } else { // zero or sign extend
+ size_t i = 0;
+ WordType cur_word = 0;
+ // for each small word
+ for (; i < BigIntOther::WORD_COUNT; ++i) {
+ // combine WORD_SIZE_RATIO small words into a big word
+ cur_word |= static_cast<WordType>(other[i])
+ << (BigIntOther::WORD_SIZE * (i % WORD_SIZE_RATIO));
+ // if we've completed a big word, copy it into place and reset
+ if ((i % WORD_SIZE_RATIO) == WORD_SIZE_RATIO - 1) {
+ val[i / WORD_SIZE_RATIO] = cur_word;
+ cur_word = 0;
+ }
+ }
+ // Pretend there are extra words of the correct sign extension as needed
+
+ const WordType extension_bits =
+ should_sign_extend ? cpp::numeric_limits<WordType>::max()
+ : cpp::numeric_limits<WordType>::min();
+ if ((i % WORD_SIZE_RATIO) != 0) {
+ cur_word |= static_cast<WordType>(extension_bits)
+ << (BigIntOther::WORD_SIZE * (i % WORD_SIZE_RATIO));
+ }
+ // Copy the last word into place.
+ val[(i / WORD_SIZE_RATIO)] = cur_word;
+ extend((i / WORD_SIZE_RATIO) + 1, should_sign_extend);
+ }
+ } else if constexpr (BigIntOther::WORD_SIZE == WORD_SIZE) {
+ if (OtherBits >= Bits) { // truncate
+ for (size_t i = 0; i < WORD_COUNT; ++i)
+ val[i] = other[i];
+ } else { // zero or sign extend
+ size_t i = 0;
+ for (; i < BigIntOther::WORD_COUNT; ++i)
+ val[i] = other[i];
+ extend(i, should_sign_extend);
+ }
+ } else {
+ // OtherWordType is bigger.
+ constexpr size_t WORD_SIZE_RATIO = BigIntOther::WORD_SIZE / WORD_SIZE;
+ static_assert(
+ (BigIntOther::WORD_SIZE % WORD_SIZE) == 0 &&
+ "Word types must be multiples of each other for correct conversion.");
+ if (OtherBits >= Bits) { // truncate
+ // for each small word
+ for (size_t i = 0; i < WORD_COUNT; ++i) {
+ // split each big word into WORD_SIZE_RATIO small words
+ val[i] = static_cast<WordType>(other[i / WORD_SIZE_RATIO] >>
+ ((i % WORD_SIZE_RATIO) * WORD_SIZE));
+ }
+ } else { // zero or sign extend
+ size_t i = 0;
+ // for each big word
+ for (; i < BigIntOther::WORD_COUNT; ++i) {
+ // split each big word into WORD_SIZE_RATIO small words
+ for (size_t j = 0; j < WORD_SIZE_RATIO; ++j)
+ val[(i * WORD_SIZE_RATIO) + j] =
+ static_cast<WordType>(other[i] >> (j * WORD_SIZE));
+ }
+ extend(i * WORD_SIZE_RATIO, should_sign_extend);
+ }
}
}
diff --git a/libc/test/src/__support/big_int_test.cpp b/libc/test/src/__support/big_int_test.cpp
index a1ce69baaae290..471ca72a8f6e0c 100644
--- a/libc/test/src/__support/big_int_test.cpp
+++ b/libc/test/src/__support/big_int_test.cpp
@@ -8,7 +8,7 @@
#include "src/__support/CPP/optional.h"
#include "src/__support/big_int.h"
-#include "src/__support/integer_literals.h" // parse_unsigned_bigint
+#include "src/__support/integer_literals.h" // parse_unsigned_bigint
#include "src/__support/macros/config.h"
#include "src/__support/macros/properties/types.h" // LIBC_TYPES_HAS_INT128
@@ -208,6 +208,7 @@ TYPED_TEST(LlvmLibcUIntClassTest, CountBits, Types) {
}
using LL_UInt16 = UInt<16>;
+using LL_UInt32 = UInt<32>;
using LL_UInt64 = UInt<64>;
// We want to test UInt<128> explicitly. So, for
// convenience, we use a sugar which does not conflict with the UInt128 type
@@ -927,4 +928,143 @@ TEST(LlvmLibcUIntClassTest, OtherWordTypeTests) {
ASSERT_EQ(static_cast<int>(a >> 64), 1);
}
+TEST(LlvmLibcUIntClassTest, OtherWordTypeCastTests) {
+ using LL_UInt96 = BigInt<96, false, uint32_t>;
+
+ LL_UInt96 a({123, 456, 789});
+
+ ASSERT_EQ(static_cast<int>(a), 123);
+ ASSERT_EQ(static_cast<int>(a >> 32), 456);
+ ASSERT_EQ(static_cast<int>(a >> 64), 789);
+
+ // Smaller word with fewer bits to bigger word with more bits.
+ LL_UInt128 b(a);
+
+ ASSERT_EQ(static_cast<int>(b), 123);
+ ASSERT_EQ(static_cast<int>(b >> 32), 456);
+ ASSERT_EQ(static_cast<int>(b >> 64), 789);
+ ASSERT_EQ(static_cast<int>(b >> 96), 0);
+
+ b = (b << 32) + 987;
+
+ ASSERT_EQ(static_cast<int>(b), 987);
+ ASSERT_EQ(static_cast<int>(b >> 32), 123);
+ ASSERT_EQ(static_cast<int>(b >> 64), 456);
+ ASSERT_EQ(static_cast<int>(b >> 96), 789);
+
+ // Bigger word with more bits to smaller word with fewer bits.
+ LL_UInt96 c(b);
+
+ ASSERT_EQ(static_cast<int>(c), 987);
+ ASSERT_EQ(static_cast<int>(c >> 32), 123);
+ ASSERT_EQ(static_cast<int>(c >> 64), 456);
+
+ // Smaller word with more bits to bigger word with fewer bits.
+ LL_UInt64 d(c);
+
+ ASSERT_EQ(static_cast<int>(d), 987);
+ ASSERT_EQ(static_cast<int>(d >> 32), 123);
+
+ // Bigger word with fewer bits to smaller word with more bits.
+
+ LL_UInt96 e(d);
+
+ ASSERT_EQ(static_cast<int>(e), 987);
+ ASSERT_EQ(static_cast<int>(e >> 32), 123);
+
+ e = (e << 32) + 654;
+
+ ASSERT_EQ(static_cast<int>(e), 654);
+ ASSERT_EQ(static_cast<int>(e >> 32), 987);
+ ASSERT_EQ(static_cast<int>(e >> 64), 123);
+}
+
+TEST(LlvmLibcUIntClassTest, SignedOtherWordTypeCastTests) {
+ using LL_Int64 = BigInt<64, true, uint64_t>;
+ using LL_Int96 = BigInt<96, true, uint32_t>;
+
+ LL_Int64 zero_64(0);
+ LL_Int96 zero_96(0);
+ LL_Int192 zero_192(0);
+
+ LL_Int96 plus_a({0x1234, 0x5678, 0x9ABC});
+
+ ASSERT_EQ(static_cast<int>(plus_a), 0x1234);
+ ASSERT_EQ(static_cast<int>(plus_a >> 32), 0x5678);
+ ASSERT_EQ(static_cast<int>(plus_a >> 64), 0x9ABC);
+
+ LL_Int96 minus_a(-plus_a);
+
+ // The reason that the numbers are inverted and not negated is that we're
+ // using two's complement. To negate a two's complement number you flip the
+ // bits and add 1, so minus_a is {~0x1234, ~0x5678, ~0x9ABC} + {1,0,0}.
+ ASSERT_EQ(static_cast<int>(minus_a), (~0x1234) + 1);
+ ASSERT_EQ(static_cast<int>(minus_a >> 32), ~0x5678);
+ ASSERT_EQ(static_cast<int>(minus_a >> 64), ~0x9ABC);
+
+ ASSERT_TRUE(plus_a + minus_a == zero_96);
+
+ // Use 192 bits so there's an extra block for the sign extension to fill.
+ LL_Int192 bigger_plus_a(plus_a);
+
+ ASSERT_EQ(static_cast<int>(bigger_plus_a), 0x1234);
+ ASSERT_EQ(static_cast<int>(bigger_plus_a >> 32), 0x5678);
+ ASSERT_EQ(static_cast<int>(bigger_plus_a >> 64), 0x9ABC);
+ ASSERT_EQ(static_cast<int>(bigger_plus_a >> 96), 0);
+ ASSERT_EQ(static_cast<int>(bigger_plus_a >> 128), 0);
+ ASSERT_EQ(static_cast<int>(bigger_plus_a >> 160), 0);
+
+ LL_Int192 bigger_minus_a(minus_a);
+
+ ASSERT_EQ(static_cast<int>(bigger_minus_a), (~0x1234) + 1);
+ ASSERT_EQ(static_cast<int>(bigger_minus_a >> 32), ~0x5678);
+ ASSERT_EQ(static_cast<int>(bigger_minus_a >> 64), ~0x9ABC);
+ ASSERT_EQ(static_cast<int>(bigger_minus_a >> 96), ~0);
+ ASSERT_EQ(static_cast<int>(bigger_minus_a >> 128), ~0);
+ ASSERT_EQ(static_cast<int>(bigger_minus_a >> 160), ~0);
+
+ ASSERT_TRUE(bigger_plus_a + bigger_minus_a == zero_192);
+
+ LL_Int64 smaller_plus_a(plus_a);
+
+ ASSERT_EQ(static_cast<int>(smaller_plus_a), 0x1234);
+ ASSERT_EQ(static_cast<int>(smaller_plus_a >> 32), 0x5678);
+
+ LL_Int64 smaller_minus_a(minus_a);
+
+ ASSERT_EQ(static_cast<int>(smaller_minus_a), (~0x1234) + 1);
+ ASSERT_EQ(static_cast<int>(smaller_minus_a >> 32), ~0x5678);
+
+ ASSERT_TRUE(smaller_plus_a + smaller_minus_a == zero_64);
+
+ // Also try going from bigger word size to smaller word size
+ LL_Int96 smaller_back_plus_a(smaller_plus_a);
+
+ ASSERT_EQ(static_cast<int>(smaller_back_plus_a), 0x1234);
+ ASSERT_EQ(static_cast<int>(smaller_back_plus_a >> 32), 0x5678);
+ ASSERT_EQ(static_cast<int>(smaller_back_plus_a >> 64), 0);
+
+ LL_Int96 smaller_back_minus_a(smaller_minus_a);
+
+ ASSERT_EQ(static_cast<int>(smaller_back_minus_a), (~0x1234) + 1);
+ ASSERT_EQ(static_cast<int>(smaller_back_minus_a >> 32), ~0x5678);
+ ASSERT_EQ(static_cast<int>(smaller_back_minus_a >> 64), ~0);
+
+ ASSERT_TRUE(smaller_back_plus_a + smaller_back_minus_a == zero_96);
+
+ LL_Int96 bigger_back_plus_a(bigger_plus_a);
+
+ ASSERT_EQ(static_cast<int>(bigger_back_plus_a), 0x1234);
+ ASSERT_EQ(static_cast<int>(bigger_back_plus_a >> 32), 0x5678);
+ ASSERT_EQ(static_cast<int>(bigger_back_plus_a >> 64), 0x9ABC);
+
+ LL_Int96 bigger_back_minus_a(bigger_minus_a);
+
+ ASSERT_EQ(static_cast<int>(bigger_back_minus_a), (~0x1234) + 1);
+ ASSERT_EQ(static_cast<int>(bigger_back_minus_a >> 32), ~0x5678);
+ ASSERT_EQ(static_cast<int>(bigger_back_minus_a >> 64), ~0x9ABC);
+
+ ASSERT_TRUE(bigger_back_plus_a + bigger_back_minus_a == zero_96);
+}
+
} // namespace LIBC_NAMESPACE_DECL
>From 482c4d1d642e8f016eb2020d45ba5b12649fd409 Mon Sep 17 00:00:00 2001
From: Michael Jones <michaelrj at google.com>
Date: Fri, 11 Oct 2024 15:42:10 -0700
Subject: [PATCH 2/2] add constexpr to some ifs
---
libc/src/__support/big_int.h | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/libc/src/__support/big_int.h b/libc/src/__support/big_int.h
index 9ab50391c7469d..a03e203769e293 100644
--- a/libc/src/__support/big_int.h
+++ b/libc/src/__support/big_int.h
@@ -373,7 +373,7 @@ struct BigInt {
static_assert(
(WORD_SIZE % BigIntOther::WORD_SIZE) == 0 &&
"Word types must be multiples of each other for correct conversion.");
- if (OtherBits >= Bits) { // truncate
+ if constexpr (OtherBits >= Bits) { // truncate
// for each big word
for (size_t i = 0; i < WORD_COUNT; ++i) {
WordType cur_word = 0;
@@ -412,7 +412,7 @@ struct BigInt {
extend((i / WORD_SIZE_RATIO) + 1, should_sign_extend);
}
} else if constexpr (BigIntOther::WORD_SIZE == WORD_SIZE) {
- if (OtherBits >= Bits) { // truncate
+ if constexpr (OtherBits >= Bits) { // truncate
for (size_t i = 0; i < WORD_COUNT; ++i)
val[i] = other[i];
} else { // zero or sign extend
@@ -427,7 +427,7 @@ struct BigInt {
static_assert(
(BigIntOther::WORD_SIZE % WORD_SIZE) == 0 &&
"Word types must be multiples of each other for correct conversion.");
- if (OtherBits >= Bits) { // truncate
+ if constexpr (OtherBits >= Bits) { // truncate
// for each small word
for (size_t i = 0; i < WORD_COUNT; ++i) {
// split each big word into WORD_SIZE_RATIO small words
More information about the libc-commits
mailing list