[llvm] a71b1d2 - [ADT] Refactor Bitset to Be More Constexpr-Usable (#172062)

via llvm-commits llvm-commits at lists.llvm.org
Sat Dec 20 04:26:01 PST 2025


Author: Jiachen Yuan
Date: 2025-12-20T07:25:57-05:00
New Revision: a71b1d2a385da0f691f76176ed4a13d35e1f970f

URL: https://github.com/llvm/llvm-project/commit/a71b1d2a385da0f691f76176ed4a13d35e1f970f
DIFF: https://github.com/llvm/llvm-project/commit/a71b1d2a385da0f691f76176ed4a13d35e1f970f.diff

LOG: [ADT] Refactor Bitset to Be More Constexpr-Usable (#172062)

This patch makes several core `Bitset` member functions `constexpr`, adds
new member functions such as `all()`, and keeps unused high bits of the last
storage word zeroed. Unit tests have been added to `BitsetTest.cpp` to check
correctness both at runtime and in constant-evaluated (`static_assert`)
contexts.

The idea for this refactor came up in this discussion:
https://discourse.llvm.org/t/rfc-out-of-lanebitmask-bits-again/88613.

Added: 
    

Modified: 
    llvm/include/llvm/ADT/Bitset.h
    llvm/unittests/ADT/BitsetTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/Bitset.h b/llvm/include/llvm/ADT/Bitset.h
index 1d4cbf8306230..09c1239f15d2f 100644
--- a/llvm/include/llvm/ADT/Bitset.h
+++ b/llvm/include/llvm/ADT/Bitset.h
@@ -16,7 +16,7 @@
 #ifndef LLVM_ADT_BITSET_H
 #define LLVM_ADT_BITSET_H
 
-#include <llvm/ADT/STLExtras.h>
+#include "llvm/ADT/STLExtras.h"
 #include <array>
 #include <climits>
 #include <cstdint>
@@ -31,6 +31,10 @@ class Bitset {
   using BitWord = uintptr_t;
 
   static constexpr unsigned BitwordBits = sizeof(BitWord) * CHAR_BIT;
+  static constexpr unsigned RemainderNumBits = NumBits % BitwordBits;
+  static constexpr BitWord RemainderMask =
+      RemainderNumBits == 0 ? ~BitWord(0)
+                            : ((BitWord(1) << RemainderNumBits) - 1);
 
   static_assert(BitwordBits == 64 || BitwordBits == 32,
                 "Unsupported word size");
@@ -38,9 +42,16 @@ class Bitset {
   static constexpr unsigned NumWords =
       (NumBits + BitwordBits - 1) / BitwordBits;
 
+  // Returns the index of the last word (0-based). The last word may be
+  // partially filled and requires masking to maintain the invariant that
+  // unused high bits are always zero.
+  static constexpr unsigned getLastWordIndex() { return NumWords - 1; }
+
   using StorageType = std::array<BitWord, NumWords>;
   StorageType Bits{};
 
+  constexpr void maskLastWord() { Bits[getLastWordIndex()] &= RemainderMask; }
+
 protected:
   constexpr Bitset(const std::array<uint64_t, (NumBits + 63) / 64> &B) {
     if constexpr (sizeof(BitWord) == sizeof(uint64_t)) {
@@ -52,12 +63,13 @@ class Bitset {
         uint64_t Elt = B[I];
         // On a 32-bit system the storage type will be 32-bit, so we may only
         // need half of a uint64_t.
-        for (size_t offset = 0; offset != 2 && BitsToAssign; ++offset) {
-          Bits[2 * I + offset] = static_cast<uint32_t>(Elt >> (32 * offset));
+        for (size_t Offset = 0; Offset != 2 && BitsToAssign; ++Offset) {
+          Bits[2 * I + Offset] = static_cast<uint32_t>(Elt >> (32 * Offset));
           BitsToAssign = BitsToAssign >= 32 ? BitsToAssign - 32 : 0;
         }
       }
     }
+    maskLastWord();
   }
 
 public:
@@ -67,8 +79,11 @@ class Bitset {
       set(I);
   }
 
-  Bitset &set() {
-    llvm::fill(Bits, -BitWord(0));
+  constexpr Bitset &set() {
+    constexpr const BitWord AllOnes = ~BitWord(0);
+    for (BitWord &B : Bits)
+      B = AllOnes;
+    maskLastWord();
     return *this;
   }
 
@@ -96,14 +111,24 @@ class Bitset {
 
   constexpr size_t size() const { return NumBits; }
 
-  bool any() const {
+  constexpr bool any() const {
     return llvm::any_of(Bits, [](BitWord I) { return I != 0; });
   }
-  bool none() const { return !any(); }
-  size_t count() const {
+
+  constexpr bool none() const { return !any(); }
+
+  constexpr bool all() const {
+    constexpr const BitWord AllOnes = ~BitWord(0);
+    for (unsigned I = 0; I < getLastWordIndex(); ++I)
+      if (Bits[I] != AllOnes)
+        return false;
+    return Bits[getLastWordIndex()] == RemainderMask;
+  }
+
+  constexpr size_t count() const {
     size_t Count = 0;
-    for (auto B : Bits)
-      Count += llvm::popcount(B);
+    for (BitWord Word : Bits)
+      Count += popcount(Word);
     return Count;
   }
 
@@ -144,18 +169,22 @@ class Bitset {
 
   constexpr Bitset operator~() const {
     Bitset Result = *this;
-    for (auto &B : Result.Bits)
+    for (BitWord &B : Result.Bits)
       B = ~B;
+    Result.maskLastWord();
     return Result;
   }
 
-  bool operator==(const Bitset &RHS) const {
-    return std::equal(std::begin(Bits), std::end(Bits), std::begin(RHS.Bits));
+  constexpr bool operator==(const Bitset &RHS) const {
+    for (unsigned I = 0; I < NumWords; ++I)
+      if (Bits[I] != RHS.Bits[I])
+        return false;
+    return true;
   }
 
-  bool operator!=(const Bitset &RHS) const { return !(*this == RHS); }
+  constexpr bool operator!=(const Bitset &RHS) const { return !(*this == RHS); }
 
-  bool operator < (const Bitset &Other) const {
+  constexpr bool operator<(const Bitset &Other) const {
     for (unsigned I = 0, E = size(); I != E; ++I) {
       bool LHS = test(I), RHS = Other.test(I);
       if (LHS != RHS)

diff  --git a/llvm/unittests/ADT/BitsetTest.cpp b/llvm/unittests/ADT/BitsetTest.cpp
index 0ecd213d6a781..678197e31a379 100644
--- a/llvm/unittests/ADT/BitsetTest.cpp
+++ b/llvm/unittests/ADT/BitsetTest.cpp
@@ -68,4 +68,230 @@ TEST(BitsetTest, Construction) {
   EXPECT_TRUE(Test33.verifyValue(TestSingleVal));
   Test33.verifyStorageSize(1, 2);
 }
+
+TEST(BitsetTest, SetAndQuery) {
+  // Test set() with all bits.
+  Bitset<64> A;
+  A.set();
+  EXPECT_TRUE(A.all());
+  EXPECT_TRUE(A.any());
+  EXPECT_FALSE(A.none());
+
+  static_assert(Bitset<64>().set().all());
+  static_assert(Bitset<33>().set().all());
+
+  // Test set() with single bit.
+  Bitset<64> B;
+  B.set(10);
+  B.set(20);
+  EXPECT_TRUE(B.test(10));
+  EXPECT_TRUE(B.test(20));
+  EXPECT_FALSE(B.test(15));
+
+  static_assert(Bitset<64>().set(10).test(10));
+  static_assert(Bitset<64>().set(0).set(63).test(0) &&
+                Bitset<64>().set(0).set(63).test(63));
+  static_assert(Bitset<33>().set(32).test(32));
+  static_assert(Bitset<128>().set(64).set(127).test(64) &&
+                Bitset<128>().set(64).set(127).test(127));
+
+  // Test reset() with single bit.
+  Bitset<64> C({10, 20, 30});
+  C.reset(20);
+  EXPECT_TRUE(C.test(10));
+  EXPECT_FALSE(C.test(20));
+  EXPECT_TRUE(C.test(30));
+
+  static_assert(!Bitset<64>({10, 20}).reset(10).test(10));
+  static_assert(Bitset<64>({10, 20}).reset(10).test(20));
+  static_assert(!Bitset<96>({31, 32, 63}).reset(32).test(32));
+  static_assert(Bitset<33>({0, 32}).reset(0).test(32));
+
+  // Test flip() with single bit.
+  Bitset<64> D({10, 20});
+  D.flip(10);
+  D.flip(30);
+  EXPECT_FALSE(D.test(10));
+  EXPECT_TRUE(D.test(20));
+  EXPECT_TRUE(D.test(30));
+
+  static_assert(!Bitset<64>({10, 20}).flip(10).test(10));
+  static_assert(Bitset<64>({10, 20}).flip(30).test(30));
+  static_assert(Bitset<100>({50, 99}).flip(50).test(99) &&
+                !Bitset<100>({50, 99}).flip(50).test(50));
+  static_assert(Bitset<33>().flip(32).test(32));
+
+  // Test operator[].
+  Bitset<64> E({5, 15, 25});
+  EXPECT_TRUE(E[5]);
+  EXPECT_FALSE(E[10]);
+  EXPECT_TRUE(E[15]);
+
+  static_assert(Bitset<64>({10, 20})[10]);
+  static_assert(!Bitset<64>({10, 20})[15]);
+  static_assert(Bitset<128>({127})[127]);
+  static_assert(Bitset<96>({63, 64})[63] && Bitset<96>({63, 64})[64]);
+
+  // Test size().
+  EXPECT_EQ(A.size(), 64u);
+  Bitset<33> F;
+  EXPECT_EQ(F.size(), 33u);
+
+  static_assert(Bitset<64>().size() == 64);
+  static_assert(Bitset<128>().size() == 128);
+  static_assert(Bitset<33>().size() == 33);
+
+  // Test any() and none().
+  static_assert(!Bitset<64>().any());
+  static_assert(Bitset<64>().none());
+  static_assert(Bitset<64>({10}).any());
+  static_assert(!Bitset<64>({10}).none());
+}
+
+TEST(BitsetTest, ComparisonOperators) {
+  // Test operator== and operator!=.
+  Bitset<64> A({10, 20, 30});
+  Bitset<64> B({10, 20, 30});
+  Bitset<64> C({10, 20, 31});
+  EXPECT_TRUE(A == B);
+  EXPECT_FALSE(A == C);
+
+  static_assert(Bitset<64>({10, 20}) == Bitset<64>({10, 20}));
+  static_assert(Bitset<64>({10, 20}) != Bitset<64>({10, 21}));
+
+  // Test operator< (compares from bit 0 up; first differing bit decides).
+  static_assert(Bitset<64>({5, 11}) <
+                Bitset<64>({5, 10})); // At bit 10: LHS=0, RHS=1.
+  static_assert(!(Bitset<64>({5, 10}) < Bitset<64>({5, 10})));
+}
+
+TEST(BitsetTest, BitwiseNot) {
+  // Test operator~.
+  Bitset<64> A;
+  A.set();
+  Bitset<64> B = ~A;
+  EXPECT_TRUE(B.none());
+
+  static_assert((~Bitset<64>()).all());
+  static_assert((~Bitset<64>().set()).none());
+  static_assert((~Bitset<33>().set()).none());
+}
+
+TEST(BitsetTest, BitwiseOperators) {
+  // Test operator&.
+  Bitset<64> A({10, 20, 30});
+  Bitset<64> B({20, 30, 40});
+  Bitset<64> Result1 = A & B;
+  EXPECT_FALSE(Result1.test(10));
+  EXPECT_TRUE(Result1.test(20));
+  EXPECT_TRUE(Result1.test(30));
+  EXPECT_FALSE(Result1.test(40));
+  EXPECT_EQ(Result1.count(), 2u);
+
+  static_assert((Bitset<64>({10, 20}) & Bitset<64>({20, 30})).test(20));
+  static_assert(!(Bitset<64>({10, 20}) & Bitset<64>({20, 30})).test(10));
+  static_assert((Bitset<64>({10, 20}) & Bitset<64>({20, 30})).count() == 1);
+  static_assert(
+      (Bitset<96>({31, 32, 63, 64}) & Bitset<96>({32, 64, 95})).count() == 2);
+  static_assert((Bitset<33>({0, 32}) & Bitset<33>({32})).test(32));
+
+  // Test operator&=.
+  Bitset<64> C({10, 20, 30});
+  C &= Bitset<64>({20, 30, 40});
+  EXPECT_FALSE(C.test(10));
+  EXPECT_TRUE(C.test(20));
+  EXPECT_TRUE(C.test(30));
+  EXPECT_FALSE(C.test(40));
+
+  constexpr Bitset<64> TestAnd = [] {
+    Bitset<64> X({10, 20, 30});
+    X &= Bitset<64>({20, 30, 40});
+    return X;
+  }();
+  static_assert(TestAnd.test(20) && TestAnd.test(30) && !TestAnd.test(10));
+
+  constexpr Bitset<100> TestAnd100 = [] {
+    Bitset<100> X({10, 50, 99});
+    X &= Bitset<100>({50, 99});
+    return X;
+  }();
+  static_assert(TestAnd100.count() == 2 && TestAnd100.test(50) &&
+                TestAnd100.test(99));
+
+  // Test operator|.
+  Bitset<64> D({10, 20});
+  Bitset<64> E({20, 30});
+  Bitset<64> Result2 = D | E;
+  EXPECT_TRUE(Result2.test(10));
+  EXPECT_TRUE(Result2.test(20));
+  EXPECT_TRUE(Result2.test(30));
+  EXPECT_EQ(Result2.count(), 3u);
+
+  static_assert((Bitset<64>({10}) | Bitset<64>({20})).count() == 2);
+  static_assert((Bitset<128>({0, 64, 127}) | Bitset<128>({64, 100})).count() ==
+                4);
+  static_assert((Bitset<33>({0, 16}) | Bitset<33>({16, 32})).count() == 3);
+
+  // Test operator|=.
+  Bitset<64> F({10, 20});
+  F |= Bitset<64>({20, 30});
+  EXPECT_TRUE(F.test(10));
+  EXPECT_TRUE(F.test(20));
+  EXPECT_TRUE(F.test(30));
+
+  constexpr Bitset<64> TestOr = [] {
+    Bitset<64> X({10});
+    X |= Bitset<64>({20});
+    return X;
+  }();
+  static_assert(TestOr.test(10) && TestOr.test(20));
+
+  constexpr Bitset<96> TestOr96 = [] {
+    Bitset<96> X({31, 63});
+    X |= Bitset<96>({32, 64});
+    return X;
+  }();
+  static_assert(TestOr96.count() == 4);
+
+  // Test operator^.
+  Bitset<64> G({10, 20, 30});
+  Bitset<64> H({20, 30, 40});
+  Bitset<64> Result3 = G ^ H;
+  EXPECT_TRUE(Result3.test(10));
+  EXPECT_FALSE(Result3.test(20));
+  EXPECT_FALSE(Result3.test(30));
+  EXPECT_TRUE(Result3.test(40));
+  EXPECT_EQ(Result3.count(), 2u);
+
+  static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(10));
+  static_assert(!(Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(20));
+  static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(30));
+  static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).count() == 2);
+  static_assert((Bitset<100>({0, 50, 99}) ^ Bitset<100>({50})).count() == 2);
+  static_assert((Bitset<33>({0, 32}) ^ Bitset<33>({0, 16})).count() == 2);
+
+  // Test operator^=.
+  Bitset<64> I({10, 20, 30});
+  I ^= Bitset<64>({20, 30, 40});
+  EXPECT_TRUE(I.test(10));
+  EXPECT_FALSE(I.test(20));
+  EXPECT_FALSE(I.test(30));
+  EXPECT_TRUE(I.test(40));
+
+  constexpr Bitset<64> TestXor = [] {
+    Bitset<64> X({10, 20});
+    X ^= Bitset<64>({20, 30});
+    return X;
+  }();
+  static_assert(TestXor.test(10) && !TestXor.test(20) && TestXor.test(30));
+
+  constexpr Bitset<128> TestXor128 = [] {
+    Bitset<128> X({0, 64, 127});
+    X ^= Bitset<128>({64});
+    return X;
+  }();
+  static_assert(TestXor128.count() == 2 && TestXor128.test(0) &&
+                TestXor128.test(127));
+}
+
 } // namespace


        


More information about the llvm-commits mailing list