[llvm] a54102a - [llvm] Add support for zero-width integers in MathExtras.h (#87193)

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 22 11:54:26 PDT 2024


Author: Théo Degioanni
Date: 2024-04-22T20:54:22+02:00
New Revision: a54102a093190f3a29add5d9327e62f13fce896a

URL: https://github.com/llvm/llvm-project/commit/a54102a093190f3a29add5d9327e62f13fce896a
DIFF: https://github.com/llvm/llvm-project/commit/a54102a093190f3a29add5d9327e62f13fce896a.diff

LOG: [llvm] Add support for zero-width integers in MathExtras.h (#87193)

MLIR uses zero-width integers, but also re-uses integer logic from LLVM
to avoid duplication. This creates issues when LLVM logic is used in
MLIR on integers which can be zero-width. In order to avoid
special-casing the bitwidth-related logic in MLIR, this PR adds support
for zero-width integers in LLVM's MathExtras (and consequently APInt).

While most of the logic in theory works the same way out of the box,
some special cases had to be added because shifting right by the entire
bit width is undefined behavior in C++ rather than yielding zero.
Fortunately, the performance penalty appears to be small. On x86, this
usually results in one additional conditional move. I checked that no
branch is inserted on Arm either.

This happens to fix a crash in `arith.extsi` canonicalization in MLIR. I
think a follow-up PR to add tests for i0 in arith would be beneficial.

Added: 
    

Modified: 
    llvm/include/llvm/Support/MathExtras.h
    llvm/unittests/ADT/APIntTest.cpp
    llvm/unittests/Support/MathExtrasTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/MathExtras.h b/llvm/include/llvm/Support/MathExtras.h
index aa4f4d2ed42e26..f0e4ee534ece38 100644
--- a/llvm/include/llvm/Support/MathExtras.h
+++ b/llvm/include/llvm/Support/MathExtras.h
@@ -66,7 +66,9 @@ template <typename T> T maskTrailingOnes(unsigned N) {
   static_assert(std::is_unsigned_v<T>, "Invalid type!");
   const unsigned Bits = CHAR_BIT * sizeof(T);
   assert(N <= Bits && "Invalid bit index");
-  return N == 0 ? 0 : (T(-1) >> (Bits - N));
+  if (N == 0)
+    return 0;
+  return T(-1) >> (Bits - N);
 }
 
 /// Create a bitmask with the N left-most bits set to 1, and all other
@@ -149,6 +151,8 @@ constexpr inline uint64_t Make_64(uint32_t High, uint32_t Low) {
 
 /// Checks if an integer fits into the given bit width.
 template <unsigned N> constexpr inline bool isInt(int64_t x) {
+  if constexpr (N == 0)
+    return 0 == x;
   if constexpr (N == 8)
     return static_cast<int8_t>(x) == x;
   if constexpr (N == 16)
@@ -164,15 +168,15 @@ template <unsigned N> constexpr inline bool isInt(int64_t x) {
 /// Checks if a signed integer is an N bit number shifted left by S.
 template <unsigned N, unsigned S>
 constexpr inline bool isShiftedInt(int64_t x) {
-  static_assert(
-      N > 0, "isShiftedInt<0> doesn't make sense (refers to a 0-bit number.");
+  static_assert(S < 64, "isShiftedInt<N, S> with S >= 64 is too much.");
   static_assert(N + S <= 64, "isShiftedInt<N, S> with N + S > 64 is too wide.");
   return isInt<N + S>(x) && (x % (UINT64_C(1) << S) == 0);
 }
 
 /// Checks if an unsigned integer fits into the given bit width.
 template <unsigned N> constexpr inline bool isUInt(uint64_t x) {
-  static_assert(N > 0, "isUInt<0> doesn't make sense");
+  if constexpr (N == 0)
+    return 0 == x;
   if constexpr (N == 8)
     return static_cast<uint8_t>(x) == x;
   if constexpr (N == 16)
@@ -188,39 +192,46 @@ template <unsigned N> constexpr inline bool isUInt(uint64_t x) {
 /// Checks if a unsigned integer is an N bit number shifted left by S.
 template <unsigned N, unsigned S>
 constexpr inline bool isShiftedUInt(uint64_t x) {
-  static_assert(
-      N > 0, "isShiftedUInt<0> doesn't make sense (refers to a 0-bit number)");
+  static_assert(S < 64, "isShiftedUInt<N, S> with S >= 64 is too much.");
   static_assert(N + S <= 64,
                 "isShiftedUInt<N, S> with N + S > 64 is too wide.");
-  // Per the two static_asserts above, S must be strictly less than 64.  So
-  // 1 << S is not undefined behavior.
+  // S must be strictly less than 64. So 1 << S is not undefined behavior.
   return isUInt<N + S>(x) && (x % (UINT64_C(1) << S) == 0);
 }
 
 /// Gets the maximum value for a N-bit unsigned integer.
 inline uint64_t maxUIntN(uint64_t N) {
-  assert(N > 0 && N <= 64 && "integer width out of range");
+  assert(N <= 64 && "integer width out of range");
 
   // uint64_t(1) << 64 is undefined behavior, so we can't do
   //   (uint64_t(1) << N) - 1
   // without checking first that N != 64.  But this works and doesn't have a
-  // branch.
+  // branch for N != 0.
+  // Unfortunately, shifting a uint64_t right by 64 bit is undefined
+  // behavior, so the condition on N == 0 is necessary. Fortunately, most
+  // optimizers do not emit branches for this check.
+  if (N == 0)
+    return 0;
   return UINT64_MAX >> (64 - N);
 }
 
 /// Gets the minimum value for a N-bit signed integer.
 inline int64_t minIntN(int64_t N) {
-  assert(N > 0 && N <= 64 && "integer width out of range");
+  assert(N <= 64 && "integer width out of range");
 
+  if (N == 0)
+    return 0;
   return UINT64_C(1) + ~(UINT64_C(1) << (N - 1));
 }
 
 /// Gets the maximum value for a N-bit signed integer.
 inline int64_t maxIntN(int64_t N) {
-  assert(N > 0 && N <= 64 && "integer width out of range");
+  assert(N <= 64 && "integer width out of range");
 
   // This relies on two's complement wraparound when N == 64, so we convert to
   // int64_t only at the very end to avoid UB.
+  if (N == 0)
+    return 0;
   return (UINT64_C(1) << (N - 1)) - 1;
 }
 
@@ -432,34 +443,38 @@ inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) {
 }
 
 /// Sign-extend the number in the bottom B bits of X to a 32-bit integer.
-/// Requires 0 < B <= 32.
+/// Requires B <= 32.
 template <unsigned B> constexpr inline int32_t SignExtend32(uint32_t X) {
-  static_assert(B > 0, "Bit width can't be 0.");
   static_assert(B <= 32, "Bit width out of range.");
+  if constexpr (B == 0)
+    return 0;
   return int32_t(X << (32 - B)) >> (32 - B);
 }
 
 /// Sign-extend the number in the bottom B bits of X to a 32-bit integer.
-/// Requires 0 < B <= 32.
+/// Requires B <= 32.
 inline int32_t SignExtend32(uint32_t X, unsigned B) {
-  assert(B > 0 && "Bit width can't be 0.");
   assert(B <= 32 && "Bit width out of range.");
+  if (B == 0)
+    return 0;
   return int32_t(X << (32 - B)) >> (32 - B);
 }
 
 /// Sign-extend the number in the bottom B bits of X to a 64-bit integer.
-/// Requires 0 < B <= 64.
+/// Requires B <= 64.
 template <unsigned B> constexpr inline int64_t SignExtend64(uint64_t x) {
-  static_assert(B > 0, "Bit width can't be 0.");
   static_assert(B <= 64, "Bit width out of range.");
+  if constexpr (B == 0)
+    return 0;
   return int64_t(x << (64 - B)) >> (64 - B);
 }
 
 /// Sign-extend the number in the bottom B bits of X to a 64-bit integer.
-/// Requires 0 < B <= 64.
+/// Requires B <= 64.
 inline int64_t SignExtend64(uint64_t X, unsigned B) {
-  assert(B > 0 && "Bit width can't be 0.");
   assert(B <= 64 && "Bit width out of range.");
+  if (B == 0)
+    return 0;
   return int64_t(X << (64 - B)) >> (64 - B);
 }
 
@@ -564,7 +579,6 @@ SaturatingMultiplyAdd(T X, T Y, T A, bool *ResultOverflowed = nullptr) {
 /// Use this rather than HUGE_VALF; the latter causes warnings on MSVC.
 extern const float huge_valf;
 
-
 /// Add two signed integers, computing the two's complement truncated result,
 /// returning true if overflow occurred.
 template <typename T>
@@ -644,6 +658,6 @@ std::enable_if_t<std::is_signed_v<T>, T> MulOverflow(T X, T Y, T &Result) {
     return UX > (static_cast<U>(std::numeric_limits<T>::max())) / UY;
 }
 
-} // End llvm namespace
+} // namespace llvm
 
 #endif

diff  --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 76fc26412407e7..46aaa47ee64503 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2797,6 +2797,9 @@ TEST(APIntTest, sext) {
   EXPECT_EQ(63U, i32_neg1.countl_one());
   EXPECT_EQ(0U, i32_neg1.countr_zero());
   EXPECT_EQ(63U, i32_neg1.popcount());
+
+  EXPECT_EQ(APInt(32u, 0), APInt(0u, 0).sext(32));
+  EXPECT_EQ(APInt(64u, 0), APInt(0u, 0).sext(64));
 }
 
 TEST(APIntTest, trunc) {

diff  --git a/llvm/unittests/Support/MathExtrasTest.cpp b/llvm/unittests/Support/MathExtrasTest.cpp
index 72c765d9ba3004..218655851ca0d7 100644
--- a/llvm/unittests/Support/MathExtrasTest.cpp
+++ b/llvm/unittests/Support/MathExtrasTest.cpp
@@ -41,6 +41,9 @@ TEST(MathExtras, onesMask) {
 TEST(MathExtras, isIntN) {
   EXPECT_TRUE(isIntN(16, 32767));
   EXPECT_FALSE(isIntN(16, 32768));
+  EXPECT_TRUE(isUIntN(0, 0));
+  EXPECT_FALSE(isUIntN(0, 1));
+  EXPECT_FALSE(isUIntN(0, -1));
 }
 
 TEST(MathExtras, isUIntN) {
@@ -48,6 +51,8 @@ TEST(MathExtras, isUIntN) {
   EXPECT_FALSE(isUIntN(16, 65536));
   EXPECT_TRUE(isUIntN(1, 0));
   EXPECT_TRUE(isUIntN(6, 63));
+  EXPECT_TRUE(isUIntN(0, 0));
+  EXPECT_FALSE(isUIntN(0, 1));
 }
 
 TEST(MathExtras, maxIntN) {
@@ -55,6 +60,7 @@ TEST(MathExtras, maxIntN) {
   EXPECT_EQ(2147483647, maxIntN(32));
   EXPECT_EQ(std::numeric_limits<int32_t>::max(), maxIntN(32));
   EXPECT_EQ(std::numeric_limits<int64_t>::max(), maxIntN(64));
+  EXPECT_EQ(0, maxIntN(0));
 }
 
 TEST(MathExtras, minIntN) {
@@ -62,6 +68,7 @@ TEST(MathExtras, minIntN) {
   EXPECT_EQ(-64LL, minIntN(7));
   EXPECT_EQ(std::numeric_limits<int32_t>::min(), minIntN(32));
   EXPECT_EQ(std::numeric_limits<int64_t>::min(), minIntN(64));
+  EXPECT_EQ(0, minIntN(0));
 }
 
 TEST(MathExtras, maxUIntN) {
@@ -70,6 +77,7 @@ TEST(MathExtras, maxUIntN) {
   EXPECT_EQ(0xffffffffffffffffULL, maxUIntN(64));
   EXPECT_EQ(1ULL, maxUIntN(1));
   EXPECT_EQ(0x0fULL, maxUIntN(4));
+  EXPECT_EQ(0, maxUIntN(0));
 }
 
 TEST(MathExtras, reverseBits) {


        


More information about the llvm-commits mailing list