[llvm] Refactor Bitset to Be More Constexpr-Usable and Add More Member Functions (PR #172062)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 12 10:24:49 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-adt

Author: Jiachen Yuan (JiachenYuan)

<details>
<summary>Changes</summary>

This patch makes several essential `Bitset` member functions `constexpr` and adds new member functions, including `all()` and the shift operators (`<<`, `<<=`, `>>`, `>>=`). Unit tests have been added to `BitsetTest.cpp` to check correctness both at runtime and in constant-evaluated (`static_assert`) contexts.

---

Patch is 22.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172062.diff


2 Files Affected:

- (modified) llvm/include/llvm/ADT/Bitset.h (+109-15) 
- (modified) llvm/unittests/ADT/BitsetTest.cpp (+543) 


``````````diff
diff --git a/llvm/include/llvm/ADT/Bitset.h b/llvm/include/llvm/ADT/Bitset.h
index 1d4cbf8306230..f963e10646dce 100644
--- a/llvm/include/llvm/ADT/Bitset.h
+++ b/llvm/include/llvm/ADT/Bitset.h
@@ -16,7 +16,7 @@
 #ifndef LLVM_ADT_BITSET_H
 #define LLVM_ADT_BITSET_H
 
-#include <llvm/ADT/STLExtras.h>
+#include "llvm/ADT/bit.h"
 #include <array>
 #include <climits>
 #include <cstdint>
@@ -31,6 +31,10 @@ class Bitset {
   using BitWord = uintptr_t;
 
   static constexpr unsigned BitwordBits = sizeof(BitWord) * CHAR_BIT;
+  static constexpr unsigned RemainderNumBits = NumBits % BitwordBits;
+  static constexpr BitWord RemainderMask =
+      RemainderNumBits == 0 ? ~BitWord(0)
+                            : ((BitWord(1) << RemainderNumBits) - 1);
 
   static_assert(BitwordBits == 64 || BitwordBits == 32,
                 "Unsupported word size");
@@ -41,6 +45,11 @@ class Bitset {
   using StorageType = std::array<BitWord, NumWords>;
   StorageType Bits{};
 
+  constexpr void maskLastWord() {
+    if constexpr (RemainderNumBits != 0)
+      Bits[NumWords - 1] &= RemainderMask;
+  }
+
 protected:
   constexpr Bitset(const std::array<uint64_t, (NumBits + 63) / 64> &B) {
     if constexpr (sizeof(BitWord) == sizeof(uint64_t)) {
@@ -52,12 +61,13 @@ class Bitset {
         uint64_t Elt = B[I];
         // On a 32-bit system the storage type will be 32-bit, so we may only
         // need half of a uint64_t.
-        for (size_t offset = 0; offset != 2 && BitsToAssign; ++offset) {
-          Bits[2 * I + offset] = static_cast<uint32_t>(Elt >> (32 * offset));
+        for (size_t Offset = 0; Offset != 2 && BitsToAssign; ++Offset) {
+          Bits[2 * I + Offset] = static_cast<uint32_t>(Elt >> (32 * Offset));
           BitsToAssign = BitsToAssign >= 32 ? BitsToAssign - 32 : 0;
         }
       }
     }
+    maskLastWord();
   }
 
 public:
@@ -67,8 +77,11 @@ class Bitset {
       set(I);
   }
 
-  Bitset &set() {
-    llvm::fill(Bits, -BitWord(0));
+  constexpr Bitset &set() {
+    constexpr const BitWord AllOnes = ~BitWord(0);
+    for (BitWord &B : Bits)
+      B = AllOnes;
+    maskLastWord();
     return *this;
   }
 
@@ -96,14 +109,28 @@ class Bitset {
 
   constexpr size_t size() const { return NumBits; }
 
-  bool any() const {
-    return llvm::any_of(Bits, [](BitWord I) { return I != 0; });
+  constexpr bool any() const {
+    for (unsigned I = 0; I < NumWords - 1; ++I)
+      if (Bits[I] != 0)
+        return true;
+    return (Bits[NumWords - 1] & RemainderMask) != 0;
   }
-  bool none() const { return !any(); }
-  size_t count() const {
+
+  constexpr bool none() const { return !any(); }
+
+  constexpr bool all() const {
+    constexpr const BitWord AllOnes = ~BitWord(0);
+    for (unsigned I = 0; I < NumWords - 1; ++I)
+      if (Bits[I] != AllOnes)
+        return false;
+    return (Bits[NumWords - 1] & RemainderMask) == RemainderMask;
+  }
+
+  constexpr size_t count() const {
     size_t Count = 0;
-    for (auto B : Bits)
-      Count += llvm::popcount(B);
+    for (unsigned I = 0; I < NumWords - 1; ++I)
+      Count += popcount(Bits[I]);
+    Count += popcount(Bits[NumWords - 1] & RemainderMask);
     return Count;
   }
 
@@ -146,16 +173,21 @@ class Bitset {
     Bitset Result = *this;
     for (auto &B : Result.Bits)
       B = ~B;
+    Result.maskLastWord();
     return Result;
   }
 
-  bool operator==(const Bitset &RHS) const {
-    return std::equal(std::begin(Bits), std::end(Bits), std::begin(RHS.Bits));
+  constexpr bool operator==(const Bitset &RHS) const {
+    for (unsigned I = 0; I < NumWords - 1; ++I)
+      if (Bits[I] != RHS.Bits[I])
+        return false;
+    return (Bits[NumWords - 1] & RemainderMask) ==
+           (RHS.Bits[NumWords - 1] & RemainderMask);
   }
 
-  bool operator!=(const Bitset &RHS) const { return !(*this == RHS); }
+  constexpr bool operator!=(const Bitset &RHS) const { return !(*this == RHS); }
 
-  bool operator < (const Bitset &Other) const {
+  constexpr bool operator<(const Bitset &Other) const {
     for (unsigned I = 0, E = size(); I != E; ++I) {
       bool LHS = test(I), RHS = Other.test(I);
       if (LHS != RHS)
@@ -163,6 +195,68 @@ class Bitset {
     }
     return false;
   }
+
+  constexpr Bitset &operator<<=(unsigned N) {
+    if (N == 0)
+      return *this;
+    if (N >= NumBits) {
+      return *this = Bitset();
+    }
+    const unsigned WordShift = N / BitwordBits;
+    const unsigned BitShift = N % BitwordBits;
+    if (BitShift == 0) {
+      for (int I = NumWords - 1; I >= static_cast<int>(WordShift); --I)
+        Bits[I] = Bits[I - WordShift];
+    } else {
+      const unsigned CarryShift = BitwordBits - BitShift;
+      for (int I = NumWords - 1; I > static_cast<int>(WordShift); --I) {
+        Bits[I] = (Bits[I - WordShift] << BitShift) |
+                  (Bits[I - WordShift - 1] >> CarryShift);
+      }
+      Bits[WordShift] = Bits[0] << BitShift;
+    }
+    for (unsigned I = 0; I < WordShift; ++I)
+      Bits[I] = 0;
+    maskLastWord();
+    return *this;
+  }
+
+  constexpr Bitset operator<<(unsigned N) const {
+    Bitset Result(*this);
+    Result <<= N;
+    return Result;
+  }
+
+  constexpr Bitset &operator>>=(unsigned N) {
+    if (N == 0)
+      return *this;
+    if (N >= NumBits) {
+      return *this = Bitset();
+    }
+    const unsigned WordShift = N / BitwordBits;
+    const unsigned BitShift = N % BitwordBits;
+    if (BitShift == 0) {
+      for (unsigned I = 0; I < NumWords - WordShift; ++I)
+        Bits[I] = Bits[I + WordShift];
+    } else {
+      const unsigned CarryShift = BitwordBits - BitShift;
+      for (unsigned I = 0; I < NumWords - WordShift - 1; ++I) {
+        Bits[I] = (Bits[I + WordShift] >> BitShift) |
+                  (Bits[I + WordShift + 1] << CarryShift);
+      }
+      Bits[NumWords - WordShift - 1] = Bits[NumWords - 1] >> BitShift;
+    }
+    for (unsigned I = NumWords - WordShift; I < NumWords; ++I)
+      Bits[I] = 0;
+    maskLastWord();
+    return *this;
+  }
+
+  constexpr Bitset operator>>(unsigned N) const {
+    Bitset Result(*this);
+    Result >>= N;
+    return Result;
+  }
 };
 
 } // end namespace llvm
diff --git a/llvm/unittests/ADT/BitsetTest.cpp b/llvm/unittests/ADT/BitsetTest.cpp
index 0ecd213d6a781..c0013ce385e7e 100644
--- a/llvm/unittests/ADT/BitsetTest.cpp
+++ b/llvm/unittests/ADT/BitsetTest.cpp
@@ -68,4 +68,547 @@ TEST(BitsetTest, Construction) {
   EXPECT_TRUE(Test33.verifyValue(TestSingleVal));
   Test33.verifyStorageSize(1, 2);
 }
+
+TEST(BitsetTest, SetAndQuery) {
+  // Test set() with all bits.
+  Bitset<64> A;
+  A.set();
+  EXPECT_TRUE(A.all());
+  EXPECT_TRUE(A.any());
+  EXPECT_FALSE(A.none());
+
+  static_assert(Bitset<64>().set().all());
+  static_assert(Bitset<33>().set().all());
+
+  // Test set() with single bit.
+  Bitset<64> B;
+  B.set(10);
+  B.set(20);
+  EXPECT_TRUE(B.test(10));
+  EXPECT_TRUE(B.test(20));
+  EXPECT_FALSE(B.test(15));
+
+  static_assert(Bitset<64>().set(10).test(10));
+  static_assert(Bitset<64>().set(0).set(63).test(0) &&
+                Bitset<64>().set(0).set(63).test(63));
+  static_assert(Bitset<33>().set(32).test(32));
+  static_assert(Bitset<128>().set(64).set(127).test(64) &&
+                Bitset<128>().set(64).set(127).test(127));
+
+  // Test reset() with single bit.
+  Bitset<64> C({10, 20, 30});
+  C.reset(20);
+  EXPECT_TRUE(C.test(10));
+  EXPECT_FALSE(C.test(20));
+  EXPECT_TRUE(C.test(30));
+
+  static_assert(!Bitset<64>({10, 20}).reset(10).test(10));
+  static_assert(Bitset<64>({10, 20}).reset(10).test(20));
+  static_assert(!Bitset<96>({31, 32, 63}).reset(32).test(32));
+  static_assert(Bitset<33>({0, 32}).reset(0).test(32));
+
+  // Test flip() with single bit.
+  Bitset<64> D({10, 20});
+  D.flip(10);
+  D.flip(30);
+  EXPECT_FALSE(D.test(10));
+  EXPECT_TRUE(D.test(20));
+  EXPECT_TRUE(D.test(30));
+
+  static_assert(!Bitset<64>({10, 20}).flip(10).test(10));
+  static_assert(Bitset<64>({10, 20}).flip(30).test(30));
+  static_assert(Bitset<100>({50, 99}).flip(50).test(99) &&
+                !Bitset<100>({50, 99}).flip(50).test(50));
+  static_assert(Bitset<33>().flip(32).test(32));
+
+  // Test operator[].
+  Bitset<64> E({5, 15, 25});
+  EXPECT_TRUE(E[5]);
+  EXPECT_FALSE(E[10]);
+  EXPECT_TRUE(E[15]);
+
+  static_assert(Bitset<64>({10, 20})[10]);
+  static_assert(!Bitset<64>({10, 20})[15]);
+  static_assert(Bitset<128>({127})[127]);
+  static_assert(Bitset<96>({63, 64})[63] && Bitset<96>({63, 64})[64]);
+
+  // Test size().
+  EXPECT_EQ(A.size(), 64u);
+  Bitset<33> F;
+  EXPECT_EQ(F.size(), 33u);
+
+  static_assert(Bitset<64>().size() == 64);
+  static_assert(Bitset<128>().size() == 128);
+  static_assert(Bitset<33>().size() == 33);
+
+  // Test any() and none().
+  static_assert(!Bitset<64>().any());
+  static_assert(Bitset<64>().none());
+  static_assert(Bitset<64>({10}).any());
+  static_assert(!Bitset<64>({10}).none());
+}
+
+TEST(BitsetTest, ComparisonOperators) {
+  // Test operator==.
+  Bitset<64> A({10, 20, 30});
+  Bitset<64> B({10, 20, 30});
+  Bitset<64> C({10, 20, 31});
+  EXPECT_TRUE(A == B);
+  EXPECT_FALSE(A == C);
+
+  static_assert(Bitset<64>({10, 20}) == Bitset<64>({10, 20}));
+  static_assert(Bitset<64>({10, 20}) != Bitset<64>({10, 21}));
+
+  // Test operator< (lexicographic from bit 0; lowest differing bit decides).
+  static_assert(Bitset<64>({5, 11}) <
+                Bitset<64>({5, 10})); // At bit 10: LHS=0, RHS=1.
+  static_assert(!(Bitset<64>({5, 10}) < Bitset<64>({5, 10})));
+}
+
+TEST(BitsetTest, BitwiseNot) {
+  // Test operator~.
+  Bitset<64> A;
+  A.set();
+  Bitset<64> B = ~A;
+  EXPECT_TRUE(B.none());
+
+  static_assert((~Bitset<64>()).all());
+  static_assert((~Bitset<64>().set()).none());
+  static_assert((~Bitset<33>().set()).none());
+}
+
+TEST(BitsetTest, BitwiseOperators) {
+  // Test operator&.
+  Bitset<64> A({10, 20, 30});
+  Bitset<64> B({20, 30, 40});
+  Bitset<64> Result1 = A & B;
+  EXPECT_FALSE(Result1.test(10));
+  EXPECT_TRUE(Result1.test(20));
+  EXPECT_TRUE(Result1.test(30));
+  EXPECT_FALSE(Result1.test(40));
+  EXPECT_EQ(Result1.count(), 2u);
+
+  static_assert((Bitset<64>({10, 20}) & Bitset<64>({20, 30})).test(20));
+  static_assert(!(Bitset<64>({10, 20}) & Bitset<64>({20, 30})).test(10));
+  static_assert((Bitset<64>({10, 20}) & Bitset<64>({20, 30})).count() == 1);
+  static_assert(
+      (Bitset<96>({31, 32, 63, 64}) & Bitset<96>({32, 64, 95})).count() == 2);
+  static_assert((Bitset<33>({0, 32}) & Bitset<33>({32})).test(32));
+
+  // Test operator&=.
+  Bitset<64> C({10, 20, 30});
+  C &= Bitset<64>({20, 30, 40});
+  EXPECT_FALSE(C.test(10));
+  EXPECT_TRUE(C.test(20));
+  EXPECT_TRUE(C.test(30));
+  EXPECT_FALSE(C.test(40));
+
+  constexpr Bitset<64> TestAnd = [] {
+    Bitset<64> X({10, 20, 30});
+    X &= Bitset<64>({20, 30, 40});
+    return X;
+  }();
+  static_assert(TestAnd.test(20) && TestAnd.test(30) && !TestAnd.test(10));
+
+  constexpr Bitset<100> TestAnd100 = [] {
+    Bitset<100> X({10, 50, 99});
+    X &= Bitset<100>({50, 99});
+    return X;
+  }();
+  static_assert(TestAnd100.count() == 2 && TestAnd100.test(50) &&
+                TestAnd100.test(99));
+
+  // Test operator|.
+  Bitset<64> D({10, 20});
+  Bitset<64> E({20, 30});
+  Bitset<64> Result2 = D | E;
+  EXPECT_TRUE(Result2.test(10));
+  EXPECT_TRUE(Result2.test(20));
+  EXPECT_TRUE(Result2.test(30));
+  EXPECT_EQ(Result2.count(), 3u);
+
+  static_assert((Bitset<64>({10}) | Bitset<64>({20})).count() == 2);
+  static_assert((Bitset<128>({0, 64, 127}) | Bitset<128>({64, 100})).count() ==
+                4);
+  static_assert((Bitset<33>({0, 16}) | Bitset<33>({16, 32})).count() == 3);
+
+  // Test operator|=.
+  Bitset<64> F({10, 20});
+  F |= Bitset<64>({20, 30});
+  EXPECT_TRUE(F.test(10));
+  EXPECT_TRUE(F.test(20));
+  EXPECT_TRUE(F.test(30));
+
+  constexpr Bitset<64> TestOr = [] {
+    Bitset<64> X({10});
+    X |= Bitset<64>({20});
+    return X;
+  }();
+  static_assert(TestOr.test(10) && TestOr.test(20));
+
+  constexpr Bitset<96> TestOr96 = [] {
+    Bitset<96> X({31, 63});
+    X |= Bitset<96>({32, 64});
+    return X;
+  }();
+  static_assert(TestOr96.count() == 4);
+
+  // Test operator^.
+  Bitset<64> G({10, 20, 30});
+  Bitset<64> H({20, 30, 40});
+  Bitset<64> Result3 = G ^ H;
+  EXPECT_TRUE(Result3.test(10));
+  EXPECT_FALSE(Result3.test(20));
+  EXPECT_FALSE(Result3.test(30));
+  EXPECT_TRUE(Result3.test(40));
+  EXPECT_EQ(Result3.count(), 2u);
+
+  static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(10));
+  static_assert(!(Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(20));
+  static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(30));
+  static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).count() == 2);
+  static_assert((Bitset<100>({0, 50, 99}) ^ Bitset<100>({50})).count() == 2);
+  static_assert((Bitset<33>({0, 32}) ^ Bitset<33>({0, 16})).count() == 2);
+
+  // Test operator^=.
+  Bitset<64> I({10, 20, 30});
+  I ^= Bitset<64>({20, 30, 40});
+  EXPECT_TRUE(I.test(10));
+  EXPECT_FALSE(I.test(20));
+  EXPECT_FALSE(I.test(30));
+  EXPECT_TRUE(I.test(40));
+
+  constexpr Bitset<64> TestXor = [] {
+    Bitset<64> X({10, 20});
+    X ^= Bitset<64>({20, 30});
+    return X;
+  }();
+  static_assert(TestXor.test(10) && !TestXor.test(20) && TestXor.test(30));
+
+  constexpr Bitset<128> TestXor128 = [] {
+    Bitset<128> X({0, 64, 127});
+    X ^= Bitset<128>({64});
+    return X;
+  }();
+  static_assert(TestXor128.count() == 2 && TestXor128.test(0) &&
+                TestXor128.test(127));
+}
+
+TEST(BitsetTest, LeftShiftOperator) {
+  // Test shift by 0 (should be identity).
+  Bitset<64> A({0, 10, 20, 30});
+  Bitset<64> Result0 = A << 0;
+  EXPECT_TRUE(Result0 == A);
+
+  static_assert((Bitset<64>({0, 10}) << 0) == Bitset<64>({0, 10}));
+
+  // Test simple left shift.
+  Bitset<64> B({0, 10, 20});
+  Bitset<64> Result1 = B << 5;
+  EXPECT_TRUE(Result1.test(5));
+  EXPECT_TRUE(Result1.test(15));
+  EXPECT_TRUE(Result1.test(25));
+  EXPECT_FALSE(Result1.test(0));
+  EXPECT_FALSE(Result1.test(10));
+  EXPECT_FALSE(Result1.test(20));
+  EXPECT_EQ(Result1.count(), 3u);
+
+  constexpr Bitset<64> TestShift = Bitset<64>({0, 10, 20}) << 5;
+  static_assert(TestShift.test(5) && TestShift.test(15) && TestShift.test(25) &&
+                !TestShift.test(0));
+  static_assert(TestShift.count() == 3);
+
+  // Test shift across word boundary (32-bit and 64-bit).
+  Bitset<64> C({0, 31});
+  Bitset<64> Result2 = C << 1;
+  EXPECT_TRUE(Result2.test(1));
+  EXPECT_TRUE(Result2.test(32));
+  EXPECT_FALSE(Result2.test(0));
+  EXPECT_FALSE(Result2.test(31));
+
+  // Test word-aligned shift.
+  Bitset<128> D({0, 10, 20});
+  Bitset<128> Result3 = D << 64;
+  EXPECT_TRUE(Result3.test(64));
+  EXPECT_TRUE(Result3.test(74));
+  EXPECT_TRUE(Result3.test(84));
+  EXPECT_FALSE(Result3.test(0));
+  EXPECT_FALSE(Result3.test(10));
+  EXPECT_FALSE(Result3.test(20));
+
+  // Test shift that moves bits out of range.
+  Bitset<64> E({50, 60, 63});
+  Bitset<64> Result4 = E << 10;
+  EXPECT_TRUE(Result4.test(60));
+  EXPECT_EQ(Result4.count(), 1u);
+
+  static_assert((Bitset<64>({50, 60, 63}) << 10).count() == 1);
+
+  // Test shift by NumBits or more (should result in all zeros).
+  Bitset<64> F({0, 10, 20, 30});
+  Bitset<64> Result5 = F << 64;
+  EXPECT_TRUE(Result5.none());
+
+  static_assert((Bitset<64>({0, 10}) << 64).none());
+
+  Bitset<64> G({0, 10, 20, 30});
+  Bitset<64> Result6 = G << 100;
+  EXPECT_TRUE(Result6.none());
+
+  // Test with non-multiple of word size.
+  Bitset<33> H({0, 16, 32});
+  Bitset<33> Result7 = H << 1;
+  EXPECT_TRUE(Result7.test(1));
+  EXPECT_TRUE(Result7.test(17));
+  EXPECT_EQ(Result7.count(), 2u);
+
+  static_assert((Bitset<33>({0, 16, 32}) << 1).count() == 2);
+  static_assert(Bitset<64>().count() == 0);
+  static_assert(Bitset<64>().set().count() == 64);
+  static_assert(Bitset<128>({0, 10, 64, 127}).count() == 4);
+}
+
+TEST(BitsetTest, LeftShiftAssignOperator) {
+  // Test simple left shift assignment.
+  Bitset<64> A({0, 10, 20});
+  A <<= 5;
+  EXPECT_TRUE(A.test(5));
+  EXPECT_TRUE(A.test(15));
+  EXPECT_TRUE(A.test(25));
+  EXPECT_FALSE(A.test(0));
+  EXPECT_EQ(A.count(), 3u);
+
+  constexpr Bitset<64> TestShiftAssign = [] {
+    Bitset<64> X({0, 10});
+    X <<= 5;
+    return X;
+  }();
+  static_assert(TestShiftAssign.test(5) && TestShiftAssign.test(15));
+
+  // Test chained operations.
+  Bitset<64> B({0});
+  B <<= 1;
+  B <<= 2;
+  EXPECT_TRUE(B.test(3));
+  EXPECT_FALSE(B.test(0));
+  EXPECT_FALSE(B.test(1));
+
+  // Test shift by 0.
+  Bitset<64> C({5, 10, 15});
+  Bitset<64> Original = C;
+  C <<= 0;
+  EXPECT_TRUE(C == Original);
+
+  // Test shift to all zeros.
+  Bitset<64> D({0, 10, 20});
+  D <<= 64;
+  EXPECT_TRUE(D.none());
+}
+
+TEST(BitsetTest, RightShiftOperator) {
+  // Test shift by 0 (should be identity).
+  Bitset<64> A({10, 20, 30, 40});
+  Bitset<64> Result0 = A >> 0;
+  EXPECT_TRUE(Result0 == A);
+
+  static_assert((Bitset<64>({10, 20}) >> 0) == Bitset<64>({10, 20}));
+
+  // Test simple right shift.
+  Bitset<64> B({10, 20, 30});
+  Bitset<64> Result1 = B >> 5;
+  EXPECT_TRUE(Result1.test(5));
+  EXPECT_TRUE(Result1.test(15));
+  EXPECT_TRUE(Result1.test(25));
+  EXPECT_FALSE(Result1.test(10));
+  EXPECT_FALSE(Result1.test(20));
+  EXPECT_FALSE(Result1.test(30));
+  EXPECT_EQ(Result1.count(), 3u);
+
+  constexpr Bitset<64> TestRShift = Bitset<64>({10, 20, 30}) >> 5;
+  static_assert(TestRShift.test(5) && TestRShift.test(15) &&
+                TestRShift.test(25) && !TestRShift.test(10));
+
+  // Test shift across word boundary.
+  Bitset<64> C({32, 33});
+  Bitset<64> Result2 = C >> 1;
+  EXPECT_TRUE(Result2.test(31));
+  EXPECT_TRUE(Result2.test(32));
+  EXPECT_FALSE(Result2.test(33));
+
+  // Test word-aligned shift.
+  Bitset<128> D({64, 74, 84});
+  Bitset<128> Result3 = D >> 64;
+  EXPECT_TRUE(Result3.test(0));
+  EXPECT_TRUE(Result3.test(10));
+  EXPECT_TRUE(Result3.test(20));
+  EXPECT_FALSE(Result3.test(64));
+  EXPECT_FALSE(Result3.test(74));
+
+  // Test shift that moves bits out of range.
+  Bitset<64> E({0, 5, 10});
+  Bitset<64> Result4 = E >> 8;
+  EXPECT_TRUE(Result4.test(2));
+  EXPECT_FALSE(Result4.test(0));
+  EXPECT_FALSE(Result4.test(5));
+  EXPECT_EQ(Result4.count(), 1u);
+
+  // Test shift by NumBits or more (should result in all zeros).
+  Bitset<64> F({10, 20, 30, 40});
+  Bitset<64> Result5 = F >> 64;
+  EXPECT_TRUE(Result5.none());
+
+  static_assert((Bitset<64>({10, 20}) >> 64).none());
+
+  Bitset<64> G({10, 20, 30, 40});
+  Bitset<64> Result6 = G >> 100;
+  EXPECT_TRUE(Result6.none());
+
+  // Test with non-multiple of word size.
+  Bitset<33> H({1, 17, 32});
+  Bitset<33> Result7 = H >> 1;
+  EXPECT_TRUE(Result7.test(0));
+  EXPECT_TRUE(Result7.test(16));
+  EXPECT_TRUE(Result7.test(31));
+  EXPECT_EQ(Result7.count(), 3u);
+
+  // Test right shift of the last bit.
+  Bitset<64> I({63});
+  Bitset<64> Result8 = I >> 1;
+  EXPECT_TRUE(Result8.test(62));
+  EXPECT_FALSE(Result8.test(63));
+}
+
+TEST(BitsetTest, RightShiftAssignOperator) {
+  // Test simple right shift assignment.
+  Bitset<64> A({10, 20, 30});
+  A >>= 5;
+  EXPECT_TRUE(A.test(5));
+  EXPECT_TRUE(A.test(15));
+  EXPECT_TRUE(A.test(25));
+  EXPECT_FALSE(A.test(10));
+  EXPECT_EQ(A.count(), 3u);
+
+  constexpr Bitset<64> TestRShiftAssign = [] {
+    Bitset<64> X({10, 20});
+    X >>= 5;
+    return X;
+  }();
+  static_assert(TestRShiftAssign.test(5) && TestRShiftAssign.test(15));
+
+  // Test chained operations.
+  Bitset<64> B({8});
+  B >>= 1;
+  B >>= 2;
+  EXPECT_TRUE(B.test(5));
+  EXPECT_FALSE(B.test(7));
+  EXPECT_FALSE(B.test(8));
+
+  // Test shift by 0.
+  Bitset<64> C({5, 10, 15});
+  Bitset<64> Original = C;
+  C >>= 0;
+  EXPECT_TRUE(C == Original);
+
+  // Test shift to all zeros.
+  Bitset<64> D({10, 20, 30});
+  D >>= 64;
+  EXPECT_TRUE(D.none());
+}
+
+TEST(BitsetTest, ShiftEdgeCases) {
+  // Test shift at exact word boundaries.
+  Bitset<96> A({31, 32, 63, 64});
+  Bitset<96> Result1 = A << 32;
+  EXPECT_TRUE(Result1.test(63));
+  EXPECT_TRUE(Result1.test(64));
+  EXPECT_TRUE(Result1.test(95));
+  // Bit 64 would move to bit 96, which is out of range for Bitset<96>.
+  EXPECT_EQ(Result1.count(), 3u);
+
+  constexpr Bitset<128> TestWordShift = Bitset<128>({64, 74}) >> 64;
+  static_assert(TestWordShift.test(0) && TestWordShift.test(10));
+
+  // Test shift at exact word boundaries for 64-bit systems.
+  Bitset<128> B({63, 64, 65});
+  Bitset<128> Result2 = B...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/172062


More information about the llvm-commits mailing list