[llvm] d067014 - [APInt] Added APInt::clearBits() method (#137098)
via llvm-commits
llvm-commits at lists.llvm.org
Mon May 19 04:41:07 PDT 2025
Author: Liam Semeria
Date: 2025-05-19T12:41:04+01:00
New Revision: d067014f13871642888afde850cdc558c32f350c
URL: https://github.com/llvm/llvm-project/commit/d067014f13871642888afde850cdc558c32f350c
DIFF: https://github.com/llvm/llvm-project/commit/d067014f13871642888afde850cdc558c32f350c.diff
LOG: [APInt] Added APInt::clearBits() method (#137098)
Added APInt::clearBits(unsigned loBit, unsigned hiBit) that clears bits within a certain range.
Fixes #136550
---------
Co-authored-by: Simon Pilgrim <llvm-dev at redking.me.uk>
Added:
Modified:
llvm/include/llvm/ADT/APInt.h
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
llvm/lib/Support/APInt.cpp
llvm/unittests/ADT/APIntTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 44260c7eca309..afdd2cfddb12b 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -1412,6 +1412,25 @@ class [[nodiscard]] APInt {
U.pVal[whichWord(BitPosition)] &= Mask;
}
+ /// Clear the bits in the half-open range [\p LoBit, \p HiBit) to 0.
+ /// Requires \p LoBit <= \p HiBit <= BitWidth; an empty range is a no-op.
+ void clearBits(unsigned LoBit, unsigned HiBit) {
+ assert(HiBit <= BitWidth && "HiBit out of range");
+ assert(LoBit <= HiBit && "LoBit greater than HiBit");
+ if (LoBit == HiBit) // Empty range; also keeps the shift below under a full word.
+ return;
+ if (HiBit <= APINT_BITS_PER_WORD) { // Whole range lies in the lowest word.
+ uint64_t Mask = WORDTYPE_MAX >> (APINT_BITS_PER_WORD - (HiBit - LoBit)); // (HiBit - LoBit) low ones.
+ Mask = ~(Mask << LoBit); // Move the ones to [LoBit, HiBit), then invert into a keep-mask.
+ if (isSingleWord())
+ U.VAL &= Mask;
+ else
+ U.pVal[0] &= Mask;
+ } else { // Range extends past the lowest word; handled out of line.
+ clearBitsSlowCase(LoBit, HiBit);
+ }
+ }
+
/// Set bottom loBits bits to 0.
void clearLowBits(unsigned loBits) {
assert(loBits <= BitWidth && "More bits than bitwidth");
@@ -2052,6 +2071,9 @@ class [[nodiscard]] APInt {
/// out-of-line slow case for setBits.
void setBitsSlowCase(unsigned loBit, unsigned hiBit);
+ /// out-of-line slow case for clearBits.
+ void clearBitsSlowCase(unsigned LoBit, unsigned HiBit);
+
/// out-of-line slow case for flipAllBits.
void flipAllBitsSlowCase();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 5d640c39a56d5..6ae22f49c3f21 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3546,7 +3546,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
APInt DemandedSrcElts = DemandedElts;
- DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
+ DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
Known.One.setAllBits();
Known.Zero.setAllBits();
@@ -5230,7 +5230,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
APInt DemandedSrcElts = DemandedElts;
- DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
+ DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
Tmp = std::numeric_limits<unsigned>::max();
if (!!DemandedSubElts) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index da999b5057d49..22b962958fc35 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1290,7 +1290,7 @@ bool TargetLowering::SimplifyDemandedBits(
unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
APInt DemandedSrcElts = DemandedElts;
- DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
+ DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
KnownBits KnownSub, KnownSrc;
if (SimplifyDemandedBits(Sub, DemandedBits, DemandedSubElts, KnownSub, TLO,
@@ -3357,7 +3357,7 @@ bool TargetLowering::SimplifyDemandedVectorElts(
unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
APInt DemandedSrcElts = DemandedElts;
- DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
+ DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
APInt SubUndef, SubZero;
if (SimplifyDemandedVectorElts(Sub, DemandedSubElts, SubUndef, SubZero, TLO,
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 4e45416b4598f..0119cb2f6e1f7 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -336,6 +336,33 @@ void APInt::setBitsSlowCase(unsigned loBit, unsigned hiBit) {
U.pVal[word] = WORDTYPE_MAX;
}
+// Out-of-line path for clearBits(): clears bits [LoBit, HiBit) when the range
+// is not confined to the lowest word. Masks here are "keep" masks (ones mark
+// bits that survive) and are AND-ed into the affected words.
+void APInt::clearBitsSlowCase(unsigned LoBit, unsigned HiBit) {
+ unsigned LoWord = whichWord(LoBit);
+ unsigned HiWord = whichWord(HiBit);
+
+ // Create an initial keep-mask for the low word with ones below LoBit.
+ uint64_t LoMask = ~(WORDTYPE_MAX << whichBit(LoBit));
+
+ // If HiBit is not aligned, we need a high mask.
+ unsigned HiShiftAmt = whichBit(HiBit);
+ if (HiShiftAmt != 0) {
+ // Create a high keep-mask with ones at and above HiBit.
+ uint64_t HiMask = ~(WORDTYPE_MAX >> (APINT_BITS_PER_WORD - HiShiftAmt));
+ // If LoWord and HiWord are equal, the bits to keep are the union of both
+ // masks; intersecting them would always give 0 and wipe the whole word.
+ // Otherwise, clear the bits below HiBit in HiWord.
+ if (HiWord == LoWord)
+ LoMask |= HiMask;
+ else
+ U.pVal[HiWord] &= HiMask;
+ }
+ // Apply the mask to the low word.
+ U.pVal[LoWord] &= LoMask;
+
+ // Zero any whole words strictly between LoWord and HiWord.
+ for (unsigned Word = LoWord + 1; Word < HiWord; ++Word)
+ U.pVal[Word] = 0;
+}
+
// Complement a bignum in-place.
static void tcComplement(APInt::WordType *dst, unsigned parts) {
for (unsigned i = 0; i < parts; i++)
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index b14366eac2185..a58fbd6deffa5 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2520,6 +2520,70 @@ TEST(APIntTest, setAllBits) {
EXPECT_EQ(128u, i128.popcount());
}
+TEST(APIntTest, clearBits) { // Exercises clearBits() on 32-, 64-, 256- and 311-bit all-ones values.
+ APInt i32 = APInt::getAllOnes(32);
+ i32.clearBits(1, 3); // Clears bits 1 and 2 (HiBit is exclusive).
+ EXPECT_EQ(1u, i32.countr_one());
+ EXPECT_EQ(0u, i32.countr_zero());
+ EXPECT_EQ(32u, i32.getActiveBits());
+ EXPECT_EQ(0u, i32.countl_zero());
+ EXPECT_EQ(29u, i32.countl_one());
+ EXPECT_EQ(30u, i32.popcount());
+ // Empty range (LoBit == HiBit): the value must be unchanged.
+ i32.clearBits(15, 15);
+ EXPECT_EQ(1u, i32.countr_one());
+ EXPECT_EQ(0u, i32.countr_zero());
+ EXPECT_EQ(32u, i32.getActiveBits());
+ EXPECT_EQ(0u, i32.countl_zero());
+ EXPECT_EQ(29u, i32.countl_one());
+ EXPECT_EQ(30u, i32.popcount());
+ // Clear bits 28..30 as well; bit 31 stays set, giving 0x8FFFFFF9.
+ i32.clearBits(28, 31);
+ EXPECT_EQ(1u, i32.countr_one());
+ EXPECT_EQ(0u, i32.countr_zero());
+ EXPECT_EQ(32u, i32.getActiveBits());
+ EXPECT_EQ(0u, i32.countl_zero());
+ EXPECT_EQ(1u, i32.countl_one());
+ EXPECT_EQ(27u, i32.popcount());
+ EXPECT_EQ(APInt(32, "8FFFFFF9", 16), i32);
+ // 256-bit value: clear [10, 250), leaving the low 10 and high 6 bits set.
+ APInt i256 = APInt::getAllOnes(256);
+ i256.clearBits(10, 250);
+ EXPECT_EQ(10u, i256.countr_one());
+ EXPECT_EQ(0u, i256.countr_zero());
+ EXPECT_EQ(256u, i256.getActiveBits());
+ EXPECT_EQ(0u, i256.countl_zero());
+ EXPECT_EQ(6u, i256.countl_one());
+ EXPECT_EQ(16u, i256.popcount());
+ // 311-bit value: clear [33, 99); 66 bits are cleared, so 245 remain set.
+ APInt i311 = APInt::getAllOnes(311);
+ i311.clearBits(33, 99);
+ EXPECT_EQ(33u, i311.countr_one());
+ EXPECT_EQ(0u, i311.countr_zero());
+ EXPECT_EQ(311u, i311.getActiveBits());
+ EXPECT_EQ(0u, i311.countl_zero());
+ EXPECT_EQ(212u, i311.countl_one());
+ EXPECT_EQ(245u, i311.popcount());
+ // 64-bit value: clear the low half [0, 32).
+ APInt i64hi32 = APInt::getAllOnes(64);
+ i64hi32.clearBits(0, 32);
+ EXPECT_EQ(32u, i64hi32.countl_one());
+ EXPECT_EQ(0u, i64hi32.countl_zero());
+ EXPECT_EQ(64u, i64hi32.getActiveBits());
+ EXPECT_EQ(32u, i64hi32.countr_zero());
+ EXPECT_EQ(0u, i64hi32.countr_one());
+ EXPECT_EQ(32u, i64hi32.popcount());
+ // Reset to all ones, then clear the high half [32, 64); HiBit equals the bit width.
+ i64hi32 = APInt::getAllOnes(64);
+ i64hi32.clearBits(32, 64);
+ EXPECT_EQ(32u, i64hi32.countr_one());
+ EXPECT_EQ(0u, i64hi32.countr_zero());
+ EXPECT_EQ(32u, i64hi32.getActiveBits());
+ EXPECT_EQ(32u, i64hi32.countl_zero());
+ EXPECT_EQ(0u, i64hi32.countl_one());
+ EXPECT_EQ(32u, i64hi32.popcount());
+}
+
TEST(APIntTest, getLoBits) {
APInt i32(32, 0xfa);
i32.setHighBits(1);
More information about the llvm-commits
mailing list