[llvm] 704a395 - [APInt] Enable APInt to support zero bit integers.
Chris Lattner via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 9 22:44:01 PDT 2021
Author: Chris Lattner
Date: 2021-09-09T22:43:54-07:00
New Revision: 704a39569346401e96a6a3978ddc490dfa828ccc
URL: https://github.com/llvm/llvm-project/commit/704a39569346401e96a6a3978ddc490dfa828ccc
DIFF: https://github.com/llvm/llvm-project/commit/704a39569346401e96a6a3978ddc490dfa828ccc.diff
LOG: [APInt] Enable APInt to support zero bit integers.
Motivation: APInt not supporting zero bit values leads to
a lot of special cases in various bits of code, particularly
when using APInt as a bit vector (where you want to start with
zero bits and then concat on more). This is particularly
challenging in the CIRCT project, where the absence of zero-bit
ConstantOp forces duplication of ops and makes instcombine-like
logic far more complicated.
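Approach: zero bit integers are weird. There are two reasonable
approaches: either make it illegal to do general arithmetic on
them (e.g. sign extends), or treat them as implicitly having
a zero value. This patch takes the conservative (first) approach:
operations that need a sign bit, such as sign extension and ashr,
remain undefined, while the operations bitvector applications rely
on (zero extension, logical shifts, bitwise logic, etc.) are supported.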
Differential Revision: https://reviews.llvm.org/D109555
Added:
Modified:
llvm/include/llvm/ADT/APInt.h
llvm/lib/Support/APInt.cpp
llvm/unittests/ADT/APIntTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 4629250d5f344..4aad2c950f744 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -66,6 +66,11 @@ inline APInt operator-(APInt);
/// not.
/// * In general, the class tries to follow the style of computation that LLVM
/// uses in its IR. This simplifies its use for LLVM.
+/// * APInt supports zero-bit-width values, but operations that require bits
+/// are not defined on them (e.g. you cannot ask for the sign of a zero-bit
+/// integer). This means that operations like zero extension and logical
+/// shifts are defined, but sign extension and ashr are not. Zero bit values
+/// compare and hash equal to themselves, and countLeadingZeros returns 0.
///
class LLVM_NODISCARD APInt {
public:
@@ -102,7 +107,6 @@ class LLVM_NODISCARD APInt {
/// \param isSigned how to treat signedness of val
APInt(unsigned numBits, uint64_t val, bool isSigned = false)
: BitWidth(numBits) {
- assert(BitWidth && "bitwidth too small");
if (isSingleWord()) {
U.VAL = val;
clearUnusedBits();
@@ -142,11 +146,7 @@ class LLVM_NODISCARD APInt {
/// \param radix the radix to use for the conversion
APInt(unsigned numBits, StringRef str, uint8_t radix);
- /// Default constructor that creates an uninteresting APInt
- /// representing a 1-bit zero value.
- ///
- /// This is useful for object deserialization (pair this with the static
- /// method Read).
+ /// Default constructor that creates an APInt with a 1-bit zero value.
explicit APInt() : BitWidth(1) { U.VAL = 0; }
/// Copy Constructor.
@@ -179,6 +179,9 @@ class LLVM_NODISCARD APInt {
/// NOTE: This is soft-deprecated. Please use `getZero()` instead.
static APInt getNullValue(unsigned numBits) { return getZero(numBits); }
+ /// Return an APInt zero bits wide.
+ static APInt getZeroWidth() { return getZero(0); }
+
/// Gets maximum unsigned value of APInt for specific bit width.
static APInt getMaxValue(unsigned numBits) {
return getAllOnesValue(numBits);
@@ -238,7 +241,6 @@ class LLVM_NODISCARD APInt {
///
/// \returns An APInt value with the requested bits set.
static APInt getBitsSet(unsigned numBits, unsigned loBit, unsigned hiBit) {
- assert(loBit <= hiBit && "loBit greater than hiBit");
APInt Res(numBits, 0);
Res.setBits(loBit, hiBit);
return Res;
@@ -257,8 +259,6 @@ class LLVM_NODISCARD APInt {
return Res;
}
- /// Get a value with upper bits starting at loBit set.
- ///
/// Constructs an APInt value that has a contiguous range of bits set. The
/// bits from loBit (inclusive) to numBits (exclusive) will be set. All other
/// bits will be zero. For example, with parameters(32, 12) you would get
@@ -274,8 +274,6 @@ class LLVM_NODISCARD APInt {
return Res;
}
- /// Get a value with high bits set
- ///
/// Constructs an APInt value that has the top hiBitsSet bits set.
///
/// \param numBits the bitwidth of the result
@@ -286,8 +284,6 @@ class LLVM_NODISCARD APInt {
return Res;
}
- /// Get a value with low bits set
- ///
/// Constructs an APInt value that has the bottom loBitsSet bits set.
///
/// \param numBits the bitwidth of the result
@@ -351,8 +347,11 @@ class LLVM_NODISCARD APInt {
/// Determine if all bits are set.
bool isAllOnes() const {
- if (isSingleWord())
+ if (isSingleWord()) {
+ if (BitWidth == 0)
+ return false;
return U.VAL == WORDTYPE_MAX >> (APINT_BITS_PER_WORD - BitWidth);
+ }
return countTrailingOnesSlowCase() == BitWidth;
}
@@ -360,7 +359,11 @@ class LLVM_NODISCARD APInt {
bool isAllOnesValue() const { return isAllOnes(); }
/// Determine if this value is zero, i.e. all bits are clear.
- bool isZero() const { return !*this; }
+ bool isZero() const {
+ if (isSingleWord())
+ return U.VAL == 0;
+ return countLeadingZerosSlowCase() == BitWidth;
+ }
/// NOTE: This is soft-deprecated. Please use `isZero()` instead.
bool isNullValue() const { return isZero(); }
@@ -388,8 +391,10 @@ class LLVM_NODISCARD APInt {
/// This checks to see if the value of this APInt is the maximum signed
/// value for the APInt's bit width.
bool isMaxSignedValue() const {
- if (isSingleWord())
+ if (isSingleWord()) {
+ assert(BitWidth && "zero width values not allowed");
return U.VAL == ((WordType(1) << (BitWidth - 1)) - 1);
+ }
return !isNegative() && countTrailingOnesSlowCase() == BitWidth - 1;
}
@@ -404,29 +409,27 @@ class LLVM_NODISCARD APInt {
/// This checks to see if the value of this APInt is the minimum signed
/// value for the APInt's bit width.
bool isMinSignedValue() const {
- if (isSingleWord())
+ if (isSingleWord()) {
+ assert(BitWidth && "zero width values not allowed");
return U.VAL == (WordType(1) << (BitWidth - 1));
+ }
return isNegative() && countTrailingZerosSlowCase() == BitWidth - 1;
}
/// Check if this APInt has an N-bits unsigned integer value.
- bool isIntN(unsigned N) const {
- assert(N && "0 bit APInt not supported");
- return getActiveBits() <= N;
- }
+ bool isIntN(unsigned N) const { return getActiveBits() <= N; }
/// Check if this APInt has an N-bits signed integer value.
- bool isSignedIntN(unsigned N) const {
- assert(N && "0 bit APInt not supported");
- return getMinSignedBits() <= N;
- }
+ bool isSignedIntN(unsigned N) const { return getMinSignedBits() <= N; }
/// Check if this APInt's value is a power of two greater than zero.
///
/// \returns true if the argument APInt value is a power of two > 0.
bool isPowerOf2() const {
- if (isSingleWord())
+ if (isSingleWord()) {
+ assert(BitWidth && "zero width values not allowed");
return isPowerOf2_64(U.VAL);
+ }
return countPopulationSlowCase() == 1;
}
@@ -438,7 +441,7 @@ class LLVM_NODISCARD APInt {
/// Convert APInt to a boolean value.
///
/// This converts the APInt to a boolean value as a test against zero.
- bool getBoolValue() const { return !!*this; }
+ bool getBoolValue() const { return !isZero(); }
/// If this value is smaller than the specified limit, return it, otherwise
/// return the limit value. This causes the value to saturate to the limit.
@@ -487,16 +490,16 @@ class LLVM_NODISCARD APInt {
/// Compute an APInt containing numBits highbits from this APInt.
///
- /// Get an APInt with the same BitWidth as this APInt, just zero mask
- /// the low bits and right shift to the least significant bit.
+ /// Get an APInt with the same BitWidth as this APInt, just zero mask the low
+ /// bits and right shift to the least significant bit.
///
/// \returns the high "numBits" bits of this APInt.
APInt getHiBits(unsigned numBits) const;
/// Compute an APInt containing numBits lowbits from this APInt.
///
- /// Get an APInt with the same BitWidth as this APInt, just zero mask
- /// the high bits.
+ /// Get an APInt with the same BitWidth as this APInt, just zero mask the high
+ /// bits.
///
/// \returns the low "numBits" bits of this APInt.
APInt getLoBits(unsigned numBits) const;
@@ -529,9 +532,7 @@ class LLVM_NODISCARD APInt {
/// \name Unary Operators
/// @{
- /// Postfix increment operator.
- ///
- /// Increments *this by 1.
+ /// Postfix increment operator. Increment *this by 1.
///
/// \returns a new APInt value representing the original value of *this.
APInt operator++(int) {
@@ -545,9 +546,7 @@ class LLVM_NODISCARD APInt {
/// \returns *this incremented by one
APInt &operator++();
- /// Postfix decrement operator.
- ///
- /// Decrements *this by 1.
+ /// Postfix decrement operator. Decrement *this by 1.
///
/// \returns a new APInt value representing the original value of *this.
APInt operator--(int) {
@@ -561,16 +560,9 @@ class LLVM_NODISCARD APInt {
/// \returns *this decremented by one.
APInt &operator--();
- /// Logical negation operator.
- ///
- /// Performs logical negation operation on this APInt.
- ///
- /// \returns true if *this is zero, false otherwise.
- bool operator!() const {
- if (isSingleWord())
- return U.VAL == 0;
- return countLeadingZerosSlowCase() == BitWidth;
- }
+ /// Logical negation operation on this APInt returns true if zero, like normal
+ /// integers.
+ bool operator!() const { return isZero(); }
/// @}
/// \name Assignment Operators
@@ -580,11 +572,12 @@ class LLVM_NODISCARD APInt {
///
/// \returns *this after assignment of RHS.
APInt &operator=(const APInt &RHS) {
- // If the bitwidths are the same, we can avoid mucking with memory
+ // The common case (both source and dest being inline) doesn't require
+ // allocation or deallocation.
if (isSingleWord() && RHS.isSingleWord()) {
U.VAL = RHS.U.VAL;
BitWidth = RHS.BitWidth;
- return clearUnusedBits();
+ return *this;
}
AssignSlowCase(RHS);
@@ -608,7 +601,6 @@ class LLVM_NODISCARD APInt {
BitWidth = that.BitWidth;
that.BitWidth = 0;
-
return *this;
}
@@ -1264,8 +1256,6 @@ class LLVM_NODISCARD APInt {
clearUnusedBits();
}
- /// Set a given bit to 1.
- ///
/// Set the given bit to 1 whose position is given as "bitPosition".
void setBit(unsigned BitPosition) {
assert(BitPosition < BitWidth && "BitPosition out of range");
@@ -1449,8 +1439,10 @@ class LLVM_NODISCARD APInt {
/// uint64_t. The bitwidth must be <= 64 or the value must fit within a
/// uint64_t. Otherwise an assertion will result.
uint64_t getZExtValue() const {
- if (isSingleWord())
+ if (isSingleWord()) {
+ assert(BitWidth && "zero width values not allowed");
return U.VAL;
+ }
assert(getActiveBits() <= 64 && "Too many bits for uint64_t");
return U.pVal[0];
}
@@ -1498,8 +1490,11 @@ class LLVM_NODISCARD APInt {
/// \returns 0 if the high order bit is not set, otherwise returns the number
/// of 1 bits from the most significant to the least
unsigned countLeadingOnes() const {
- if (isSingleWord())
+ if (isSingleWord()) {
+ if (BitWidth == 0)
+ return 0;
return llvm::countLeadingOnes(U.VAL << (APINT_BITS_PER_WORD - BitWidth));
+ }
return countLeadingOnesSlowCase();
}
@@ -1807,10 +1802,9 @@ class LLVM_NODISCARD APInt {
friend class APSInt;
- /// Fast internal constructor
- ///
/// This constructor is used only internally for speed of construction of
- /// temporaries. It is unsafe for general use so it is not public.
+ /// temporaries. It is unsafe since it takes ownership of the pointer, so it
+ /// is not public.
APInt(uint64_t *val, unsigned bits) : BitWidth(bits) { U.pVal = val; }
/// Determine which word a bit is in.
@@ -1820,10 +1814,7 @@ class LLVM_NODISCARD APInt {
return bitPosition / APINT_BITS_PER_WORD;
}
- /// Determine which bit in a word a bit is in.
- ///
- /// \returns the bit position in a word for the specified bit position
- /// in the APInt.
+ /// Determine which bit in a word the specified bit position is in.
static unsigned whichBit(unsigned bitPosition) {
return bitPosition % APINT_BITS_PER_WORD;
}
@@ -1845,11 +1836,14 @@ class LLVM_NODISCARD APInt {
/// significant word is assigned a value to ensure that those bits are
/// zero'd out.
APInt &clearUnusedBits() {
- // Compute how many bits are used in the final word
+ // Compute how many bits are used in the final word.
unsigned WordBits = ((BitWidth - 1) % APINT_BITS_PER_WORD) + 1;
// Mask out the high bits.
uint64_t mask = WORDTYPE_MAX >> (APINT_BITS_PER_WORD - WordBits);
+ if (BitWidth == 0)
+ mask = 0;
+
if (isSingleWord())
U.VAL &= mask;
else
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 69787d757912f..39824905434cc 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -89,7 +89,6 @@ void APInt::initSlowCase(const APInt& that) {
}
void APInt::initFromArray(ArrayRef<uint64_t> bigVal) {
- assert(BitWidth && "Bitwidth too small");
assert(bigVal.data() && "Null pointer detected!");
if (isSingleWord())
U.VAL = bigVal[0];
@@ -105,19 +104,17 @@ void APInt::initFromArray(ArrayRef<uint64_t> bigVal) {
clearUnusedBits();
}
-APInt::APInt(unsigned numBits, ArrayRef<uint64_t> bigVal)
- : BitWidth(numBits) {
+APInt::APInt(unsigned numBits, ArrayRef<uint64_t> bigVal) : BitWidth(numBits) {
initFromArray(bigVal);
}
APInt::APInt(unsigned numBits, unsigned numWords, const uint64_t bigVal[])
- : BitWidth(numBits) {
+ : BitWidth(numBits) {
initFromArray(makeArrayRef(bigVal, numWords));
}
APInt::APInt(unsigned numbits, StringRef Str, uint8_t radix)
- : BitWidth(numbits) {
- assert(BitWidth && "Bitwidth too small");
+ : BitWidth(numbits) {
fromString(numbits, Str, radix);
}
@@ -233,9 +230,7 @@ APInt APInt::operator*(const APInt& RHS) const {
return APInt(BitWidth, U.VAL * RHS.U.VAL);
APInt Result(getMemory(getNumWords()), getBitWidth());
-
tcMultiply(Result.U.pVal, U.pVal, RHS.U.pVal, getNumWords());
-
Result.clearUnusedBits();
return Result;
}
@@ -258,8 +253,7 @@ void APInt::XorAssignSlowCase(const APInt &RHS) {
dst[i] ^= rhs[i];
}
-APInt& APInt::operator*=(const APInt& RHS) {
- assert(BitWidth == RHS.BitWidth && "Bit widths must be the same");
+APInt &APInt::operator*=(const APInt &RHS) {
*this = *this * RHS;
return *this;
}
@@ -714,6 +708,8 @@ APInt APInt::reverseBits() const {
return APInt(BitWidth, llvm::reverseBits<uint16_t>(U.VAL));
case 8:
return APInt(BitWidth, llvm::reverseBits<uint8_t>(U.VAL));
+ case 0:
+ return *this;
default:
break;
}
@@ -873,7 +869,6 @@ double APInt::roundToDouble(bool isSigned) const {
// Truncate to new width.
APInt APInt::trunc(unsigned width) const {
assert(width < BitWidth && "Invalid APInt Truncate request");
- assert(width && "Can't truncate to 0 bits");
if (width <= APINT_BITS_PER_WORD)
return APInt(width, getRawData()[0]);
@@ -896,7 +891,6 @@ APInt APInt::trunc(unsigned width) const {
// Truncate to new width with unsigned saturation.
APInt APInt::truncUSat(unsigned width) const {
assert(width < BitWidth && "Invalid APInt Truncate request");
- assert(width && "Can't truncate to 0 bits");
// Can we just losslessly truncate it?
if (isIntN(width))
@@ -908,7 +902,6 @@ APInt APInt::truncUSat(unsigned width) const {
// Truncate to new width with signed saturation.
APInt APInt::truncSSat(unsigned width) const {
assert(width < BitWidth && "Invalid APInt Truncate request");
- assert(width && "Can't truncate to 0 bits");
// Can we just losslessly truncate it?
if (isSignedIntN(width))
@@ -1071,6 +1064,8 @@ void APInt::shlSlowCase(unsigned ShiftAmt) {
// Calculate the rotate amount modulo the bit width.
static unsigned rotateModulo(unsigned BitWidth, const APInt &rotateAmt) {
+ if (BitWidth == 0)
+ return 0;
unsigned rotBitWidth = rotateAmt.getBitWidth();
APInt rot = rotateAmt;
if (rotBitWidth < BitWidth) {
@@ -1087,6 +1082,8 @@ APInt APInt::rotl(const APInt &rotateAmt) const {
}
APInt APInt::rotl(unsigned rotateAmt) const {
+ if (BitWidth == 0)
+ return *this;
rotateAmt %= BitWidth;
if (rotateAmt == 0)
return *this;
@@ -1098,6 +1095,8 @@ APInt APInt::rotr(const APInt &rotateAmt) const {
}
APInt APInt::rotr(unsigned rotateAmt) const {
+ if (BitWidth == 0)
+ return *this;
rotateAmt %= BitWidth;
if (rotateAmt == 0)
return *this;
@@ -2145,7 +2144,7 @@ void APInt::toString(SmallVectorImpl<char> &Str, unsigned Radix,
}
// First, check for a zero value and just short circuit the logic below.
- if (*this == 0) {
+ if (isZero()) {
while (*Prefix) {
Str.push_back(*Prefix);
++Prefix;
@@ -2713,7 +2712,7 @@ APInt llvm::APIntOps::RoundingUDiv(const APInt &A, const APInt &B,
case APInt::Rounding::UP: {
APInt Quo, Rem;
APInt::udivrem(A, B, Quo, Rem);
- if (Rem == 0)
+ if (Rem.isZero())
return Quo;
return Quo + 1;
}
@@ -2728,7 +2727,7 @@ APInt llvm::APIntOps::RoundingSDiv(const APInt &A, const APInt &B,
case APInt::Rounding::UP: {
APInt Quo, Rem;
APInt::sdivrem(A, B, Quo, Rem);
- if (Rem == 0)
+ if (Rem.isZero())
return Quo;
// This algorithm deals with arbitrary rounding mode used by sdivrem.
// We want to check whether the non-integer part of the mathematical value
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 74e0c469d7325..0f8f4626692b4 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -1422,7 +1422,6 @@ TEST(APIntTest, Log2) {
#ifdef GTEST_HAS_DEATH_TEST
#ifndef NDEBUG
TEST(APIntTest, StringDeath) {
- EXPECT_DEATH((void)APInt(0, "", 0), "Bitwidth too small");
EXPECT_DEATH((void)APInt(32, "", 0), "Invalid string length");
EXPECT_DEATH((void)APInt(32, "0", 0), "Radix should be 2, 8, 10, 16, or 36!");
EXPECT_DEATH((void)APInt(32, "", 10), "Invalid string length");
@@ -2908,4 +2907,79 @@ TEST(APIntTest, SignbitZeroChecks) {
EXPECT_FALSE(APInt(8, 1).isNonPositive());
}
+TEST(APIntTest, ZeroWidth) {
+ // Zero width Constructors.
+ auto ZW = APInt::getZeroWidth();
+ EXPECT_EQ(0U, ZW.getBitWidth());
+ EXPECT_EQ(0U, APInt(0, ArrayRef<uint64_t>({0, 1, 2})).getBitWidth());
+ EXPECT_EQ(0U, APInt(0, "0", 10).getBitWidth());
+
+ // Default constructor is single bit wide.
+ EXPECT_EQ(1U, APInt().getBitWidth());
+
+ // Copy ctor (move is down below).
+ APInt ZW2(ZW);
+ EXPECT_EQ(0U, ZW2.getBitWidth());
+ // Assignment
+ ZW = ZW2;
+ EXPECT_EQ(0U, ZW.getBitWidth());
+
+ // Methods like getLowBitsSet work with zero bits.
+ EXPECT_EQ(0U, APInt::getLowBitsSet(0, 0).getBitWidth());
+ EXPECT_EQ(0U, APInt::getSplat(0, ZW).getBitWidth());
+
+ // Logical operators.
+ ZW |= ZW2;
+ ZW &= ZW2;
+ ZW ^= ZW2;
+ ZW |= 42; // These ignore high bits of the literal.
+ ZW &= 42;
+ ZW ^= 42;
+ EXPECT_EQ(1, ZW.isIntN(0));
+
+ // Modulo Arithmetic. Divide/Rem aren't defined on division by zero, so they
+ // aren't supported.
+ ZW += ZW2;
+ ZW -= ZW2;
+ ZW *= ZW2;
+
+ // Logical Shifts and rotates, the amount must be <= bitwidth.
+ ZW <<= 0;
+ ZW.lshrInPlace(0);
+ (void)ZW.rotl(0);
+ (void)ZW.rotr(0);
+
+ // Comparisons.
+ EXPECT_EQ(1, ZW == ZW);
+ EXPECT_EQ(0, ZW != ZW);
+ EXPECT_EQ(0, ZW.ult(ZW));
+
+ // Mutations.
+ ZW.setBitsWithWrap(0, 0);
+ ZW.setBits(0, 0);
+ ZW.clearAllBits();
+ ZW.flipAllBits();
+
+ // Leading, trailing, ctpop, etc
+ EXPECT_EQ(0U, ZW.countLeadingZeros());
+ EXPECT_EQ(0U, ZW.countLeadingOnes());
+ EXPECT_EQ(0U, ZW.countPopulation());
+ EXPECT_EQ(0U, ZW.reverseBits().getBitWidth());
+ EXPECT_EQ(0U, ZW.getHiBits(0).getBitWidth());
+ EXPECT_EQ(0U, ZW.getLoBits(0).getBitWidth());
+ EXPECT_EQ(0, ZW.zext(4));
+ EXPECT_EQ(0U, APInt(4, 3).trunc(0).getBitWidth());
+
+ SmallString<42> STR;
+ ZW.toStringUnsigned(STR);
+ EXPECT_EQ("0", STR);
+
+ // Move ctor (keep at the end of the method since moves are destructive).
+ APInt MZW1(std::move(ZW));
+ EXPECT_EQ(0U, MZW1.getBitWidth());
+ // Move Assignment
+ MZW1 = std::move(ZW2);
+ EXPECT_EQ(0U, MZW1.getBitWidth());
+}
+
} // end anonymous namespace
More information about the llvm-commits
mailing list