[llvm] d067014 - [APInt] Added APInt::clearBits() method (#137098)

via llvm-commits llvm-commits at lists.llvm.org
Mon May 19 04:41:07 PDT 2025


Author: Liam Semeria
Date: 2025-05-19T12:41:04+01:00
New Revision: d067014f13871642888afde850cdc558c32f350c

URL: https://github.com/llvm/llvm-project/commit/d067014f13871642888afde850cdc558c32f350c
DIFF: https://github.com/llvm/llvm-project/commit/d067014f13871642888afde850cdc558c32f350c.diff

LOG: [APInt] Added APInt::clearBits() method (#137098)

Added APInt::clearBits(unsigned LoBit, unsigned HiBit), which clears the bits in the half-open range [LoBit, HiBit).

Fixes #136550

---------

Co-authored-by: Simon Pilgrim <llvm-dev at redking.me.uk>

Added: 
    

Modified: 
    llvm/include/llvm/ADT/APInt.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Support/APInt.cpp
    llvm/unittests/ADT/APIntTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 44260c7eca309..afdd2cfddb12b 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -1412,6 +1412,25 @@ class [[nodiscard]] APInt {
       U.pVal[whichWord(BitPosition)] &= Mask;
   }
 
+  /// Zero every bit in the half-open range [\p LoBit, \p HiBit).
+  /// Requires \p LoBit <= \p HiBit <= BitWidth; an empty range is a no-op.
+  void clearBits(unsigned LoBit, unsigned HiBit) {
+    assert(HiBit <= BitWidth && "HiBit out of range");
+    assert(LoBit <= HiBit && "LoBit greater than HiBit");
+    if (LoBit == HiBit)
+      return;
+    if (HiBit > APINT_BITS_PER_WORD) {
+      // The range extends beyond the first word: use the out-of-line path.
+      clearBitsSlowCase(LoBit, HiBit);
+      return;
+    }
+    // Keep only the bits outside [LoBit, HiBit) in the first word.
+    const unsigned NumBits = HiBit - LoBit;
+    const uint64_t Ones = WORDTYPE_MAX >> (APINT_BITS_PER_WORD - NumBits);
+    uint64_t &Word = isSingleWord() ? U.VAL : U.pVal[0];
+    Word &= ~(Ones << LoBit);
+  }
+
   /// Set bottom loBits bits to 0.
   void clearLowBits(unsigned loBits) {
     assert(loBits <= BitWidth && "More bits than bitwidth");
@@ -2052,6 +2071,9 @@ class [[nodiscard]] APInt {
   /// out-of-line slow case for setBits.
   void setBitsSlowCase(unsigned loBit, unsigned hiBit);
 
+  /// out-of-line slow case for clearBits.
+  void clearBitsSlowCase(unsigned LoBit, unsigned HiBit);
+
   /// out-of-line slow case for flipAllBits.
   void flipAllBitsSlowCase();
 

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 5d640c39a56d5..6ae22f49c3f21 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3546,7 +3546,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
     APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
     APInt DemandedSrcElts = DemandedElts;
-    DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
+    DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
 
     Known.One.setAllBits();
     Known.Zero.setAllBits();
@@ -5230,7 +5230,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
     APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
     APInt DemandedSrcElts = DemandedElts;
-    DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
+    DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
 
     Tmp = std::numeric_limits<unsigned>::max();
     if (!!DemandedSubElts) {

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index da999b5057d49..22b962958fc35 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1290,7 +1290,7 @@ bool TargetLowering::SimplifyDemandedBits(
     unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
     APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
     APInt DemandedSrcElts = DemandedElts;
-    DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
+    DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
 
     KnownBits KnownSub, KnownSrc;
     if (SimplifyDemandedBits(Sub, DemandedBits, DemandedSubElts, KnownSub, TLO,
@@ -3357,7 +3357,7 @@ bool TargetLowering::SimplifyDemandedVectorElts(
     unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
     APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
     APInt DemandedSrcElts = DemandedElts;
-    DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
+    DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
 
     APInt SubUndef, SubZero;
     if (SimplifyDemandedVectorElts(Sub, DemandedSubElts, SubUndef, SubZero, TLO,

diff  --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 4e45416b4598f..0119cb2f6e1f7 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -336,6 +336,33 @@ void APInt::setBitsSlowCase(unsigned loBit, unsigned hiBit) {
     U.pVal[word] = WORDTYPE_MAX;
 }
 
+// Out-of-line path of clearBits(): clears [LoBit, HiBit) when the range
+// reaches past word 0. Expects LoBit < HiBit <= BitWidth (asserted inline).
+void APInt::clearBitsSlowCase(unsigned LoBit, unsigned HiBit) {
+  unsigned LoWord = whichWord(LoBit);
+  unsigned HiWord = whichWord(HiBit);
+
+  // Keep-mask for the low word: ones strictly below LoBit.
+  uint64_t LoMask = ~(WORDTYPE_MAX << whichBit(LoBit));
+  // If HiBit is not word-aligned, HiWord is only partially cleared.
+  unsigned HiShiftAmt = whichBit(HiBit);
+  if (HiShiftAmt != 0) {
+    // Keep-mask for the high word: ones at and above HiBit.
+    uint64_t HiMask = ~(WORDTYPE_MAX >> (APINT_BITS_PER_WORD - HiShiftAmt));
+    // Both ends in one word: keep the union of the two sides. ANDing the
+    // disjoint masks would give 0 and wipe the whole word.
+    if (HiWord == LoWord)
+      LoMask |= HiMask;
+    else
+      U.pVal[HiWord] &= HiMask;
+  }
+  // Apply the mask to the low word.
+  U.pVal[LoWord] &= LoMask;
+  // Zero every whole word strictly between LoWord and HiWord.
+  for (unsigned Word = LoWord + 1; Word < HiWord; ++Word)
+    U.pVal[Word] = 0;
+}
+
 // Complement a bignum in-place.
 static void tcComplement(APInt::WordType *dst, unsigned parts) {
   for (unsigned i = 0; i < parts; i++)

diff  --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index b14366eac2185..a58fbd6deffa5 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2520,6 +2520,70 @@ TEST(APIntTest, setAllBits) {
   EXPECT_EQ(128u, i128.popcount());
 }
 
+TEST(APIntTest, clearBits) {
+  APInt Word32 = APInt::getAllOnes(32);
+  Word32.clearBits(1, 3); // Drops bits 1 and 2 only.
+  EXPECT_EQ(30u, Word32.popcount());
+  EXPECT_EQ(1u, Word32.countr_one());
+  EXPECT_EQ(0u, Word32.countr_zero());
+  EXPECT_EQ(29u, Word32.countl_one());
+  EXPECT_EQ(0u, Word32.countl_zero());
+  EXPECT_EQ(32u, Word32.getActiveBits());
+  // An empty range (LoBit == HiBit) must leave the value untouched.
+  Word32.clearBits(15, 15);
+  EXPECT_EQ(30u, Word32.popcount());
+  EXPECT_EQ(1u, Word32.countr_one());
+  EXPECT_EQ(0u, Word32.countr_zero());
+  EXPECT_EQ(29u, Word32.countl_one());
+  EXPECT_EQ(0u, Word32.countl_zero());
+  EXPECT_EQ(32u, Word32.getActiveBits());
+  // Clearing [28, 31) leaves the top bit set.
+  Word32.clearBits(28, 31);
+  EXPECT_EQ(27u, Word32.popcount());
+  EXPECT_EQ(1u, Word32.countr_one());
+  EXPECT_EQ(0u, Word32.countr_zero());
+  EXPECT_EQ(1u, Word32.countl_one());
+  EXPECT_EQ(0u, Word32.countl_zero());
+  EXPECT_EQ(32u, Word32.getActiveBits());
+  EXPECT_EQ(APInt(32, "8FFFFFF9", 16), Word32);
+  // Wide value whose cleared range [10, 250) crosses word boundaries.
+  APInt Wide256 = APInt::getAllOnes(256);
+  Wide256.clearBits(10, 250);
+  EXPECT_EQ(16u, Wide256.popcount());
+  EXPECT_EQ(10u, Wide256.countr_one());
+  EXPECT_EQ(0u, Wide256.countr_zero());
+  EXPECT_EQ(6u, Wide256.countl_one());
+  EXPECT_EQ(0u, Wide256.countl_zero());
+  EXPECT_EQ(256u, Wide256.getActiveBits());
+  // Width that is not a multiple of 64; the cleared range is [33, 99).
+  APInt Odd311 = APInt::getAllOnes(311);
+  Odd311.clearBits(33, 99);
+  EXPECT_EQ(245u, Odd311.popcount());
+  EXPECT_EQ(33u, Odd311.countr_one());
+  EXPECT_EQ(0u, Odd311.countr_zero());
+  EXPECT_EQ(212u, Odd311.countl_one());
+  EXPECT_EQ(0u, Odd311.countl_zero());
+  EXPECT_EQ(311u, Odd311.getActiveBits());
+  // 64-bit value: clear the low half, [0, 32).
+  APInt Half64 = APInt::getAllOnes(64);
+  Half64.clearBits(0, 32);
+  EXPECT_EQ(32u, Half64.popcount());
+  EXPECT_EQ(0u, Half64.countr_one());
+  EXPECT_EQ(32u, Half64.countr_zero());
+  EXPECT_EQ(32u, Half64.countl_one());
+  EXPECT_EQ(0u, Half64.countl_zero());
+  EXPECT_EQ(64u, Half64.getActiveBits());
+  // 64-bit value: clear the high half, [32, 64).
+  Half64 = APInt::getAllOnes(64);
+  Half64.clearBits(32, 64);
+  EXPECT_EQ(32u, Half64.popcount());
+  EXPECT_EQ(32u, Half64.countr_one());
+  EXPECT_EQ(0u, Half64.countr_zero());
+  EXPECT_EQ(0u, Half64.countl_one());
+  EXPECT_EQ(32u, Half64.countl_zero());
+  EXPECT_EQ(32u, Half64.getActiveBits());
+}
+
 TEST(APIntTest, getLoBits) {
   APInt i32(32, 0xfa);
   i32.setHighBits(1);


        


More information about the llvm-commits mailing list