[llvm] Refactor Bitset to Be More Constexpr-Usable and Add More Member Functions (PR #172062)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 12 10:24:49 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-adt
Author: Jiachen Yuan (JiachenYuan)
<details>
<summary>Changes</summary>
This patch makes several essential `Bitset` member functions `constexpr` and adds new member functions (`all()` and the shift operators `<<`, `<<=`, `>>`, `>>=`). Unit tests have been added to `BitsetTest.cpp` to verify correctness both at run time and in compile-time (`static_assert`) contexts.
---
Patch is 22.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172062.diff
2 Files Affected:
- (modified) llvm/include/llvm/ADT/Bitset.h (+109-15)
- (modified) llvm/unittests/ADT/BitsetTest.cpp (+543)
``````````diff
diff --git a/llvm/include/llvm/ADT/Bitset.h b/llvm/include/llvm/ADT/Bitset.h
index 1d4cbf8306230..f963e10646dce 100644
--- a/llvm/include/llvm/ADT/Bitset.h
+++ b/llvm/include/llvm/ADT/Bitset.h
@@ -16,7 +16,7 @@
#ifndef LLVM_ADT_BITSET_H
#define LLVM_ADT_BITSET_H
-#include <llvm/ADT/STLExtras.h>
+#include "llvm/ADT/bit.h"
#include <array>
#include <climits>
#include <cstdint>
@@ -31,6 +31,10 @@ class Bitset {
using BitWord = uintptr_t;
static constexpr unsigned BitwordBits = sizeof(BitWord) * CHAR_BIT;
+ static constexpr unsigned RemainderNumBits = NumBits % BitwordBits;
+ static constexpr BitWord RemainderMask =
+ RemainderNumBits == 0 ? ~BitWord(0)
+ : ((BitWord(1) << RemainderNumBits) - 1);
static_assert(BitwordBits == 64 || BitwordBits == 32,
"Unsupported word size");
@@ -41,6 +45,11 @@ class Bitset {
using StorageType = std::array<BitWord, NumWords>;
StorageType Bits{};
+ constexpr void maskLastWord() {
+ if constexpr (RemainderNumBits != 0)
+ Bits[NumWords - 1] &= RemainderMask;
+ }
+
protected:
constexpr Bitset(const std::array<uint64_t, (NumBits + 63) / 64> &B) {
if constexpr (sizeof(BitWord) == sizeof(uint64_t)) {
@@ -52,12 +61,13 @@ class Bitset {
uint64_t Elt = B[I];
// On a 32-bit system the storage type will be 32-bit, so we may only
// need half of a uint64_t.
- for (size_t offset = 0; offset != 2 && BitsToAssign; ++offset) {
- Bits[2 * I + offset] = static_cast<uint32_t>(Elt >> (32 * offset));
+ for (size_t Offset = 0; Offset != 2 && BitsToAssign; ++Offset) {
+ Bits[2 * I + Offset] = static_cast<uint32_t>(Elt >> (32 * Offset));
BitsToAssign = BitsToAssign >= 32 ? BitsToAssign - 32 : 0;
}
}
}
+ maskLastWord();
}
public:
@@ -67,8 +77,11 @@ class Bitset {
set(I);
}
- Bitset &set() {
- llvm::fill(Bits, -BitWord(0));
+ constexpr Bitset &set() {
+ constexpr const BitWord AllOnes = ~BitWord(0);
+ for (BitWord &B : Bits)
+ B = AllOnes;
+ maskLastWord();
return *this;
}
@@ -96,14 +109,28 @@ class Bitset {
constexpr size_t size() const { return NumBits; }
- bool any() const {
- return llvm::any_of(Bits, [](BitWord I) { return I != 0; });
+ constexpr bool any() const {
+ for (unsigned I = 0; I < NumWords - 1; ++I)
+ if (Bits[I] != 0)
+ return true;
+ return (Bits[NumWords - 1] & RemainderMask) != 0;
}
- bool none() const { return !any(); }
- size_t count() const {
+
+ constexpr bool none() const { return !any(); }
+
+ constexpr bool all() const {
+ constexpr const BitWord AllOnes = ~BitWord(0);
+ for (unsigned I = 0; I < NumWords - 1; ++I)
+ if (Bits[I] != AllOnes)
+ return false;
+ return (Bits[NumWords - 1] & RemainderMask) == RemainderMask;
+ }
+
+ constexpr size_t count() const {
size_t Count = 0;
- for (auto B : Bits)
- Count += llvm::popcount(B);
+ for (unsigned I = 0; I < NumWords - 1; ++I)
+ Count += popcount(Bits[I]);
+ Count += popcount(Bits[NumWords - 1] & RemainderMask);
return Count;
}
@@ -146,16 +173,21 @@ class Bitset {
Bitset Result = *this;
for (auto &B : Result.Bits)
B = ~B;
+ Result.maskLastWord();
return Result;
}
- bool operator==(const Bitset &RHS) const {
- return std::equal(std::begin(Bits), std::end(Bits), std::begin(RHS.Bits));
+ constexpr bool operator==(const Bitset &RHS) const {
+ for (unsigned I = 0; I < NumWords - 1; ++I)
+ if (Bits[I] != RHS.Bits[I])
+ return false;
+ return (Bits[NumWords - 1] & RemainderMask) ==
+ (RHS.Bits[NumWords - 1] & RemainderMask);
}
- bool operator!=(const Bitset &RHS) const { return !(*this == RHS); }
+ constexpr bool operator!=(const Bitset &RHS) const { return !(*this == RHS); }
- bool operator < (const Bitset &Other) const {
+ constexpr bool operator<(const Bitset &Other) const {
for (unsigned I = 0, E = size(); I != E; ++I) {
bool LHS = test(I), RHS = Other.test(I);
if (LHS != RHS)
@@ -163,6 +195,68 @@ class Bitset {
}
return false;
}
+
+ constexpr Bitset &operator<<=(unsigned N) {
+ if (N == 0)
+ return *this;
+ if (N >= NumBits) {
+ return *this = Bitset();
+ }
+ const unsigned WordShift = N / BitwordBits;
+ const unsigned BitShift = N % BitwordBits;
+ if (BitShift == 0) {
+ for (int I = NumWords - 1; I >= static_cast<int>(WordShift); --I)
+ Bits[I] = Bits[I - WordShift];
+ } else {
+ const unsigned CarryShift = BitwordBits - BitShift;
+ for (int I = NumWords - 1; I > static_cast<int>(WordShift); --I) {
+ Bits[I] = (Bits[I - WordShift] << BitShift) |
+ (Bits[I - WordShift - 1] >> CarryShift);
+ }
+ Bits[WordShift] = Bits[0] << BitShift;
+ }
+ for (unsigned I = 0; I < WordShift; ++I)
+ Bits[I] = 0;
+ maskLastWord();
+ return *this;
+ }
+
+ constexpr Bitset operator<<(unsigned N) const {
+ Bitset Result(*this);
+ Result <<= N;
+ return Result;
+ }
+
+ constexpr Bitset &operator>>=(unsigned N) {
+ if (N == 0)
+ return *this;
+ if (N >= NumBits) {
+ return *this = Bitset();
+ }
+ const unsigned WordShift = N / BitwordBits;
+ const unsigned BitShift = N % BitwordBits;
+ if (BitShift == 0) {
+ for (unsigned I = 0; I < NumWords - WordShift; ++I)
+ Bits[I] = Bits[I + WordShift];
+ } else {
+ const unsigned CarryShift = BitwordBits - BitShift;
+ for (unsigned I = 0; I < NumWords - WordShift - 1; ++I) {
+ Bits[I] = (Bits[I + WordShift] >> BitShift) |
+ (Bits[I + WordShift + 1] << CarryShift);
+ }
+ Bits[NumWords - WordShift - 1] = Bits[NumWords - 1] >> BitShift;
+ }
+ for (unsigned I = NumWords - WordShift; I < NumWords; ++I)
+ Bits[I] = 0;
+ maskLastWord();
+ return *this;
+ }
+
+ constexpr Bitset operator>>(unsigned N) const {
+ Bitset Result(*this);
+ Result >>= N;
+ return Result;
+ }
};
} // end namespace llvm
diff --git a/llvm/unittests/ADT/BitsetTest.cpp b/llvm/unittests/ADT/BitsetTest.cpp
index 0ecd213d6a781..c0013ce385e7e 100644
--- a/llvm/unittests/ADT/BitsetTest.cpp
+++ b/llvm/unittests/ADT/BitsetTest.cpp
@@ -68,4 +68,547 @@ TEST(BitsetTest, Construction) {
EXPECT_TRUE(Test33.verifyValue(TestSingleVal));
Test33.verifyStorageSize(1, 2);
}
+
+TEST(BitsetTest, SetAndQuery) {
+ // Test set() with all bits.
+ Bitset<64> A;
+ A.set();
+ EXPECT_TRUE(A.all());
+ EXPECT_TRUE(A.any());
+ EXPECT_FALSE(A.none());
+
+ static_assert(Bitset<64>().set().all());
+ static_assert(Bitset<33>().set().all());
+
+ // Test set() with single bit.
+ Bitset<64> B;
+ B.set(10);
+ B.set(20);
+ EXPECT_TRUE(B.test(10));
+ EXPECT_TRUE(B.test(20));
+ EXPECT_FALSE(B.test(15));
+
+ static_assert(Bitset<64>().set(10).test(10));
+ static_assert(Bitset<64>().set(0).set(63).test(0) &&
+ Bitset<64>().set(0).set(63).test(63));
+ static_assert(Bitset<33>().set(32).test(32));
+ static_assert(Bitset<128>().set(64).set(127).test(64) &&
+ Bitset<128>().set(64).set(127).test(127));
+
+ // Test reset() with single bit.
+ Bitset<64> C({10, 20, 30});
+ C.reset(20);
+ EXPECT_TRUE(C.test(10));
+ EXPECT_FALSE(C.test(20));
+ EXPECT_TRUE(C.test(30));
+
+ static_assert(!Bitset<64>({10, 20}).reset(10).test(10));
+ static_assert(Bitset<64>({10, 20}).reset(10).test(20));
+ static_assert(!Bitset<96>({31, 32, 63}).reset(32).test(32));
+ static_assert(Bitset<33>({0, 32}).reset(0).test(32));
+
+ // Test flip() with single bit.
+ Bitset<64> D({10, 20});
+ D.flip(10);
+ D.flip(30);
+ EXPECT_FALSE(D.test(10));
+ EXPECT_TRUE(D.test(20));
+ EXPECT_TRUE(D.test(30));
+
+ static_assert(!Bitset<64>({10, 20}).flip(10).test(10));
+ static_assert(Bitset<64>({10, 20}).flip(30).test(30));
+ static_assert(Bitset<100>({50, 99}).flip(50).test(99) &&
+ !Bitset<100>({50, 99}).flip(50).test(50));
+ static_assert(Bitset<33>().flip(32).test(32));
+
+ // Test operator[].
+ Bitset<64> E({5, 15, 25});
+ EXPECT_TRUE(E[5]);
+ EXPECT_FALSE(E[10]);
+ EXPECT_TRUE(E[15]);
+
+ static_assert(Bitset<64>({10, 20})[10]);
+ static_assert(!Bitset<64>({10, 20})[15]);
+ static_assert(Bitset<128>({127})[127]);
+ static_assert(Bitset<96>({63, 64})[63] && Bitset<96>({63, 64})[64]);
+
+ // Test size().
+ EXPECT_EQ(A.size(), 64u);
+ Bitset<33> F;
+ EXPECT_EQ(F.size(), 33u);
+
+ static_assert(Bitset<64>().size() == 64);
+ static_assert(Bitset<128>().size() == 128);
+ static_assert(Bitset<33>().size() == 33);
+
+ // Test any() and none().
+ static_assert(!Bitset<64>().any());
+ static_assert(Bitset<64>().none());
+ static_assert(Bitset<64>({10}).any());
+ static_assert(!Bitset<64>({10}).none());
+}
+
+TEST(BitsetTest, ComparisonOperators) {
+ // Test operator==.
+ Bitset<64> A({10, 20, 30});
+ Bitset<64> B({10, 20, 30});
+ Bitset<64> C({10, 20, 31});
+ EXPECT_TRUE(A == B);
+ EXPECT_FALSE(A == C);
+
+ static_assert(Bitset<64>({10, 20}) == Bitset<64>({10, 20}));
+ static_assert(Bitset<64>({10, 20}) != Bitset<64>({10, 21}));
+
+ // Test operator< (lexicographic from bit 0, the least significant bit).
+ static_assert(Bitset<64>({5, 11}) <
+ Bitset<64>({5, 10})); // At bit 10: A=0, B=1.
+ static_assert(!(Bitset<64>({5, 10}) < Bitset<64>({5, 10})));
+}
+
+TEST(BitsetTest, BitwiseNot) {
+ // Test operator~.
+ Bitset<64> A;
+ A.set();
+ Bitset<64> B = ~A;
+ EXPECT_TRUE(B.none());
+
+ static_assert((~Bitset<64>()).all());
+ static_assert((~Bitset<64>().set()).none());
+ static_assert((~Bitset<33>().set()).none());
+}
+
+TEST(BitsetTest, BitwiseOperators) {
+ // Test operator&.
+ Bitset<64> A({10, 20, 30});
+ Bitset<64> B({20, 30, 40});
+ Bitset<64> Result1 = A & B;
+ EXPECT_FALSE(Result1.test(10));
+ EXPECT_TRUE(Result1.test(20));
+ EXPECT_TRUE(Result1.test(30));
+ EXPECT_FALSE(Result1.test(40));
+ EXPECT_EQ(Result1.count(), 2u);
+
+ static_assert((Bitset<64>({10, 20}) & Bitset<64>({20, 30})).test(20));
+ static_assert(!(Bitset<64>({10, 20}) & Bitset<64>({20, 30})).test(10));
+ static_assert((Bitset<64>({10, 20}) & Bitset<64>({20, 30})).count() == 1);
+ static_assert(
+ (Bitset<96>({31, 32, 63, 64}) & Bitset<96>({32, 64, 95})).count() == 2);
+ static_assert((Bitset<33>({0, 32}) & Bitset<33>({32})).test(32));
+
+ // Test operator&=.
+ Bitset<64> C({10, 20, 30});
+ C &= Bitset<64>({20, 30, 40});
+ EXPECT_FALSE(C.test(10));
+ EXPECT_TRUE(C.test(20));
+ EXPECT_TRUE(C.test(30));
+ EXPECT_FALSE(C.test(40));
+
+ constexpr Bitset<64> TestAnd = [] {
+ Bitset<64> X({10, 20, 30});
+ X &= Bitset<64>({20, 30, 40});
+ return X;
+ }();
+ static_assert(TestAnd.test(20) && TestAnd.test(30) && !TestAnd.test(10));
+
+ constexpr Bitset<100> TestAnd100 = [] {
+ Bitset<100> X({10, 50, 99});
+ X &= Bitset<100>({50, 99});
+ return X;
+ }();
+ static_assert(TestAnd100.count() == 2 && TestAnd100.test(50) &&
+ TestAnd100.test(99));
+
+ // Test operator|.
+ Bitset<64> D({10, 20});
+ Bitset<64> E({20, 30});
+ Bitset<64> Result2 = D | E;
+ EXPECT_TRUE(Result2.test(10));
+ EXPECT_TRUE(Result2.test(20));
+ EXPECT_TRUE(Result2.test(30));
+ EXPECT_EQ(Result2.count(), 3u);
+
+ static_assert((Bitset<64>({10}) | Bitset<64>({20})).count() == 2);
+ static_assert((Bitset<128>({0, 64, 127}) | Bitset<128>({64, 100})).count() ==
+ 4);
+ static_assert((Bitset<33>({0, 16}) | Bitset<33>({16, 32})).count() == 3);
+
+ // Test operator|=.
+ Bitset<64> F({10, 20});
+ F |= Bitset<64>({20, 30});
+ EXPECT_TRUE(F.test(10));
+ EXPECT_TRUE(F.test(20));
+ EXPECT_TRUE(F.test(30));
+
+ constexpr Bitset<64> TestOr = [] {
+ Bitset<64> X({10});
+ X |= Bitset<64>({20});
+ return X;
+ }();
+ static_assert(TestOr.test(10) && TestOr.test(20));
+
+ constexpr Bitset<96> TestOr96 = [] {
+ Bitset<96> X({31, 63});
+ X |= Bitset<96>({32, 64});
+ return X;
+ }();
+ static_assert(TestOr96.count() == 4);
+
+ // Test operator^.
+ Bitset<64> G({10, 20, 30});
+ Bitset<64> H({20, 30, 40});
+ Bitset<64> Result3 = G ^ H;
+ EXPECT_TRUE(Result3.test(10));
+ EXPECT_FALSE(Result3.test(20));
+ EXPECT_FALSE(Result3.test(30));
+ EXPECT_TRUE(Result3.test(40));
+ EXPECT_EQ(Result3.count(), 2u);
+
+ static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(10));
+ static_assert(!(Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(20));
+ static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(30));
+ static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).count() == 2);
+ static_assert((Bitset<100>({0, 50, 99}) ^ Bitset<100>({50})).count() == 2);
+ static_assert((Bitset<33>({0, 32}) ^ Bitset<33>({0, 16})).count() == 2);
+
+ // Test operator^=.
+ Bitset<64> I({10, 20, 30});
+ I ^= Bitset<64>({20, 30, 40});
+ EXPECT_TRUE(I.test(10));
+ EXPECT_FALSE(I.test(20));
+ EXPECT_FALSE(I.test(30));
+ EXPECT_TRUE(I.test(40));
+
+ constexpr Bitset<64> TestXor = [] {
+ Bitset<64> X({10, 20});
+ X ^= Bitset<64>({20, 30});
+ return X;
+ }();
+ static_assert(TestXor.test(10) && !TestXor.test(20) && TestXor.test(30));
+
+ constexpr Bitset<128> TestXor128 = [] {
+ Bitset<128> X({0, 64, 127});
+ X ^= Bitset<128>({64});
+ return X;
+ }();
+ static_assert(TestXor128.count() == 2 && TestXor128.test(0) &&
+ TestXor128.test(127));
+}
+
+TEST(BitsetTest, LeftShiftOperator) {
+ // Test shift by 0 (should be identity).
+ Bitset<64> A({0, 10, 20, 30});
+ Bitset<64> Result0 = A << 0;
+ EXPECT_TRUE(Result0 == A);
+
+ static_assert((Bitset<64>({0, 10}) << 0) == Bitset<64>({0, 10}));
+
+ // Test simple left shift.
+ Bitset<64> B({0, 10, 20});
+ Bitset<64> Result1 = B << 5;
+ EXPECT_TRUE(Result1.test(5));
+ EXPECT_TRUE(Result1.test(15));
+ EXPECT_TRUE(Result1.test(25));
+ EXPECT_FALSE(Result1.test(0));
+ EXPECT_FALSE(Result1.test(10));
+ EXPECT_FALSE(Result1.test(20));
+ EXPECT_EQ(Result1.count(), 3u);
+
+ constexpr Bitset<64> TestShift = Bitset<64>({0, 10, 20}) << 5;
+ static_assert(TestShift.test(5) && TestShift.test(15) && TestShift.test(25) &&
+ !TestShift.test(0));
+ static_assert(TestShift.count() == 3);
+
+ // Test shift across word boundary (32-bit and 64-bit).
+ Bitset<64> C({0, 31});
+ Bitset<64> Result2 = C << 1;
+ EXPECT_TRUE(Result2.test(1));
+ EXPECT_TRUE(Result2.test(32));
+ EXPECT_FALSE(Result2.test(0));
+ EXPECT_FALSE(Result2.test(31));
+
+ // Test word-aligned shift.
+ Bitset<128> D({0, 10, 20});
+ Bitset<128> Result3 = D << 64;
+ EXPECT_TRUE(Result3.test(64));
+ EXPECT_TRUE(Result3.test(74));
+ EXPECT_TRUE(Result3.test(84));
+ EXPECT_FALSE(Result3.test(0));
+ EXPECT_FALSE(Result3.test(10));
+ EXPECT_FALSE(Result3.test(20));
+
+ // Test shift that moves bits out of range.
+ Bitset<64> E({50, 60, 63});
+ Bitset<64> Result4 = E << 10;
+ EXPECT_TRUE(Result4.test(60));
+ EXPECT_EQ(Result4.count(), 1u);
+
+ static_assert((Bitset<64>({50, 60, 63}) << 10).count() == 1);
+
+ // Test shift by NumBits or more (should result in all zeros).
+ Bitset<64> F({0, 10, 20, 30});
+ Bitset<64> Result5 = F << 64;
+ EXPECT_TRUE(Result5.none());
+
+ static_assert((Bitset<64>({0, 10}) << 64).none());
+
+ Bitset<64> G({0, 10, 20, 30});
+ Bitset<64> Result6 = G << 100;
+ EXPECT_TRUE(Result6.none());
+
+ // Test with non-multiple of word size.
+ Bitset<33> H({0, 16, 32});
+ Bitset<33> Result7 = H << 1;
+ EXPECT_TRUE(Result7.test(1));
+ EXPECT_TRUE(Result7.test(17));
+ EXPECT_EQ(Result7.count(), 2u);
+
+ static_assert((Bitset<33>({0, 16, 32}) << 1).count() == 2);
+ static_assert(Bitset<64>().count() == 0);
+ static_assert(Bitset<64>().set().count() == 64);
+ static_assert(Bitset<128>({0, 10, 64, 127}).count() == 4);
+}
+
+TEST(BitsetTest, LeftShiftAssignOperator) {
+ // Test simple left shift assignment.
+ Bitset<64> A({0, 10, 20});
+ A <<= 5;
+ EXPECT_TRUE(A.test(5));
+ EXPECT_TRUE(A.test(15));
+ EXPECT_TRUE(A.test(25));
+ EXPECT_FALSE(A.test(0));
+ EXPECT_EQ(A.count(), 3u);
+
+ constexpr Bitset<64> TestShiftAssign = [] {
+ Bitset<64> X({0, 10});
+ X <<= 5;
+ return X;
+ }();
+ static_assert(TestShiftAssign.test(5) && TestShiftAssign.test(15));
+
+ // Test chained operations.
+ Bitset<64> B({0});
+ B <<= 1;
+ B <<= 2;
+ EXPECT_TRUE(B.test(3));
+ EXPECT_FALSE(B.test(0));
+ EXPECT_FALSE(B.test(1));
+
+ // Test shift by 0.
+ Bitset<64> C({5, 10, 15});
+ Bitset<64> Original = C;
+ C <<= 0;
+ EXPECT_TRUE(C == Original);
+
+ // Test shift to all zeros.
+ Bitset<64> D({0, 10, 20});
+ D <<= 64;
+ EXPECT_TRUE(D.none());
+}
+
+TEST(BitsetTest, RightShiftOperator) {
+ // Test shift by 0 (should be identity).
+ Bitset<64> A({10, 20, 30, 40});
+ Bitset<64> Result0 = A >> 0;
+ EXPECT_TRUE(Result0 == A);
+
+ static_assert((Bitset<64>({10, 20}) >> 0) == Bitset<64>({10, 20}));
+
+ // Test simple right shift.
+ Bitset<64> B({10, 20, 30});
+ Bitset<64> Result1 = B >> 5;
+ EXPECT_TRUE(Result1.test(5));
+ EXPECT_TRUE(Result1.test(15));
+ EXPECT_TRUE(Result1.test(25));
+ EXPECT_FALSE(Result1.test(10));
+ EXPECT_FALSE(Result1.test(20));
+ EXPECT_FALSE(Result1.test(30));
+ EXPECT_EQ(Result1.count(), 3u);
+
+ constexpr Bitset<64> TestRShift = Bitset<64>({10, 20, 30}) >> 5;
+ static_assert(TestRShift.test(5) && TestRShift.test(15) &&
+ TestRShift.test(25) && !TestRShift.test(10));
+
+ // Test shift across word boundary.
+ Bitset<64> C({32, 33});
+ Bitset<64> Result2 = C >> 1;
+ EXPECT_TRUE(Result2.test(31));
+ EXPECT_TRUE(Result2.test(32));
+ EXPECT_FALSE(Result2.test(33));
+
+ // Test word-aligned shift.
+ Bitset<128> D({64, 74, 84});
+ Bitset<128> Result3 = D >> 64;
+ EXPECT_TRUE(Result3.test(0));
+ EXPECT_TRUE(Result3.test(10));
+ EXPECT_TRUE(Result3.test(20));
+ EXPECT_FALSE(Result3.test(64));
+ EXPECT_FALSE(Result3.test(74));
+
+ // Test shift that moves bits out of range.
+ Bitset<64> E({0, 5, 10});
+ Bitset<64> Result4 = E >> 8;
+ EXPECT_TRUE(Result4.test(2));
+ EXPECT_FALSE(Result4.test(0));
+ EXPECT_FALSE(Result4.test(5));
+ EXPECT_EQ(Result4.count(), 1u);
+
+ // Test shift by NumBits or more (should result in all zeros).
+ Bitset<64> F({10, 20, 30, 40});
+ Bitset<64> Result5 = F >> 64;
+ EXPECT_TRUE(Result5.none());
+
+ static_assert((Bitset<64>({10, 20}) >> 64).none());
+
+ Bitset<64> G({10, 20, 30, 40});
+ Bitset<64> Result6 = G >> 100;
+ EXPECT_TRUE(Result6.none());
+
+ // Test with non-multiple of word size.
+ Bitset<33> H({1, 17, 32});
+ Bitset<33> Result7 = H >> 1;
+ EXPECT_TRUE(Result7.test(0));
+ EXPECT_TRUE(Result7.test(16));
+ EXPECT_TRUE(Result7.test(31));
+ EXPECT_EQ(Result7.count(), 3u);
+
+ // Test right shift of the last bit.
+ Bitset<64> I({63});
+ Bitset<64> Result8 = I >> 1;
+ EXPECT_TRUE(Result8.test(62));
+ EXPECT_FALSE(Result8.test(63));
+}
+
+TEST(BitsetTest, RightShiftAssignOperator) {
+ // Test simple right shift assignment.
+ Bitset<64> A({10, 20, 30});
+ A >>= 5;
+ EXPECT_TRUE(A.test(5));
+ EXPECT_TRUE(A.test(15));
+ EXPECT_TRUE(A.test(25));
+ EXPECT_FALSE(A.test(10));
+ EXPECT_EQ(A.count(), 3u);
+
+ constexpr Bitset<64> TestRShiftAssign = [] {
+ Bitset<64> X({10, 20});
+ X >>= 5;
+ return X;
+ }();
+ static_assert(TestRShiftAssign.test(5) && TestRShiftAssign.test(15));
+
+ // Test chained operations.
+ Bitset<64> B({8});
+ B >>= 1;
+ B >>= 2;
+ EXPECT_TRUE(B.test(5));
+ EXPECT_FALSE(B.test(7));
+ EXPECT_FALSE(B.test(8));
+
+ // Test shift by 0.
+ Bitset<64> C({5, 10, 15});
+ Bitset<64> Original = C;
+ C >>= 0;
+ EXPECT_TRUE(C == Original);
+
+ // Test shift to all zeros.
+ Bitset<64> D({10, 20, 30});
+ D >>= 64;
+ EXPECT_TRUE(D.none());
+}
+
+TEST(BitsetTest, ShiftEdgeCases) {
+ // Test shift at exact word boundaries.
+ Bitset<96> A({31, 32, 63, 64});
+ Bitset<96> Result1 = A << 32;
+ EXPECT_TRUE(Result1.test(63));
+ EXPECT_TRUE(Result1.test(64));
+ EXPECT_TRUE(Result1.test(95));
+ // Bit 64 shifted by 32 lands on bit 96, out of range for Bitset<96>.
+ EXPECT_EQ(Result1.count(), 3u);
+
+ constexpr Bitset<128> TestWordShift = Bitset<128>({64, 74}) >> 64;
+ static_assert(TestWordShift.test(0) && TestWordShift.test(10));
+
+ // Test shift at exact word boundaries for 64-bit systems.
+ Bitset<128> B({63, 64, 65});
+ Bitset<128> Result2 = B...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/172062
More information about the llvm-commits
mailing list