[llvm] [mlir] MathExtras: avoid unnecessarily widening types (PR #95426)
Ramkumar Ramachandra via llvm-commits
llvm-commits at lists.llvm.org
Sat Jun 15 09:21:56 PDT 2024
https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/95426
>From fe0a9c4fd14036ae59d65322d245ce38a21c9914 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Wed, 12 Jun 2024 12:58:03 +0100
Subject: [PATCH 1/2] MathExtras: avoid unnecessarily widening types
Several multi-argument functions unnecessarily widen types beyond the
argument types. Templatize the functions and use std::common_type_t
to avoid this, thereby optimizing the functions. While at it, address the
usage issues raised in #95087. One of the requirements of this patch is
to add overflow checks; in addition, one caller in LoopVectorize and one
in AMDGPUBaseInfo are manually widened.
---
llvm/include/llvm/Support/MathExtras.h | 129 ++++++++++++++----
llvm/unittests/Support/MathExtrasTest.cpp | 41 +++++-
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 2 +-
.../Conversion/LLVMCommon/MemRefBuilder.cpp | 5 +-
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 2 +-
5 files changed, 143 insertions(+), 36 deletions(-)
diff --git a/llvm/include/llvm/Support/MathExtras.h b/llvm/include/llvm/Support/MathExtras.h
index 5bcefe4b6c361..ed5ea8befe0ea 100644
--- a/llvm/include/llvm/Support/MathExtras.h
+++ b/llvm/include/llvm/Support/MathExtras.h
@@ -23,6 +23,22 @@
#include <type_traits>
namespace llvm {
+/// Some template parameter helpers to optimize for bitwidth, for functions that
+/// take multiple arguments.
+
+// We can't verify signedness, since callers rely on implicit coercions to
+// signed/unsigned.
+template <typename T, typename U>
+using enableif_int =
+ std::enable_if_t<std::is_integral_v<T> && std::is_integral_v<U>>;
+
+// Use std::common_type_t to widen only up to the widest argument.
+template <typename T, typename U, typename = enableif_int<T, U>>
+using common_uint =
+ std::common_type_t<std::make_unsigned_t<T>, std::make_unsigned_t<U>>;
+template <typename T, typename U, typename = enableif_int<T, U>>
+using common_sint =
+ std::common_type_t<std::make_signed_t<T>, std::make_signed_t<U>>;
/// Mathematical constants.
namespace numbers {
@@ -346,7 +362,8 @@ inline unsigned Log2_64_Ceil(uint64_t Value) {
/// A and B are either alignments or offsets. Return the minimum alignment that
/// may be assumed after adding the two together.
-constexpr inline uint64_t MinAlign(uint64_t A, uint64_t B) {
+template <typename U, typename V, typename T = common_uint<U, V>>
+constexpr T MinAlign(U A, V B) {
// The largest power of 2 that divides both A and B.
//
// Replace "-Value" by "1+~Value" in the following commented code to avoid
@@ -375,7 +392,7 @@ inline uint64_t PowerOf2Ceil(uint64_t A) {
return UINT64_C(1) << Log2_64_Ceil(A);
}
-/// Returns the next integer (mod 2**64) that is greater than or equal to
+/// Returns the next integer (mod 2**nbits) that is greater than or equal to
/// \p Value and is a multiple of \p Align. \p Align must be non-zero.
///
/// Examples:
@@ -386,19 +403,50 @@ inline uint64_t PowerOf2Ceil(uint64_t A) {
/// alignTo(321, 255) = 510
/// \endcode
///
-/// May overflow.
-inline uint64_t alignTo(uint64_t Value, uint64_t Align) {
+/// Will overflow only if result is not representable.
+template <typename U, typename V, typename T = common_uint<U, V>>
+constexpr T alignTo(U Value, V Align) {
assert(Align != 0u && "Align can't be 0.");
- return (Value + Align - 1) / Align * Align;
+ T Bias = (Value != 0);
+ T CeilDiv = (Value - Bias) / Align + Bias;
+ // If Value is negative, wrap will occur in the cast.
+ if (Value > 0)
+ assert(CeilDiv <= (std::numeric_limits<T>::max() - 1) / Align &&
+ "alignTo would overflow");
+ return CeilDiv * Align;
}
-inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
+/// Fallback when arguments aren't integral.
+constexpr inline uint64_t alignTo(uint64_t Value, uint64_t Align) {
+ assert(Align != 0u && "Align can't be 0.");
+ uint64_t Bias = (Value != 0);
+ uint64_t CeilDiv = (Value - Bias) / Align + Bias;
+ return CeilDiv * Align;
+}
+
+/// Will overflow only if result is not representable.
+template <typename U, typename V, typename T = common_uint<U, V>>
+constexpr T alignToPowerOf2(U Value, V Align) {
assert(Align != 0 && (Align & (Align - 1)) == 0 &&
"Align must be a power of 2");
// Replace unary minus to avoid compilation error on Windows:
// "unary minus operator applied to unsigned type, result still unsigned"
- uint64_t negAlign = (~Align) + 1;
- return (Value + Align - 1) & negAlign;
+ T NegAlign = (~Align) + 1;
+ // If Value is negative, wrap will occur in the cast.
+ if (Value > 0)
+ assert(static_cast<T>(Value) - 1 <= std::numeric_limits<T>::max() - Align &&
+ "alignToPowerOf2 would overflow");
+ return (Value - 1 + Align) & NegAlign;
+}
+
+/// Fallback when arguments aren't integral.
+constexpr inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
+ assert(Align != 0 && (Align & (Align - 1)) == 0 &&
+ "Align must be a power of 2");
+ // Replace unary minus to avoid compilation error on Windows:
+ // "unary minus operator applied to unsigned type, result still unsigned"
+ uint64_t NegAlign = (~Align) + 1;
+ return (Value - 1 + Align) & NegAlign;
}
/// If non-zero \p Skew is specified, the return value will be a minimal integer
@@ -413,22 +461,41 @@ inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
/// alignTo(~0LL, 8, 3) = 3
/// alignTo(321, 255, 42) = 552
/// \endcode
-inline uint64_t alignTo(uint64_t Value, uint64_t Align, uint64_t Skew) {
+template <typename U, typename V, typename W,
+ typename T = common_uint<common_uint<U, V>, W>>
+constexpr T alignTo(U Value, V Align, W Skew) {
assert(Align != 0u && "Align can't be 0.");
Skew %= Align;
return alignTo(Value - Skew, Align) + Skew;
}
-/// Returns the next integer (mod 2**64) that is greater than or equal to
+/// Returns the next integer (mod 2**nbits) that is greater than or equal to
/// \p Value and is a multiple of \c Align. \c Align must be non-zero.
-template <uint64_t Align> constexpr inline uint64_t alignTo(uint64_t Value) {
+///
+/// Will overflow only if result is not representable.
+template <auto Align, typename V, typename T = common_uint<decltype(Align), V>>
+constexpr T alignTo(V Value) {
static_assert(Align != 0u, "Align must be non-zero");
- return (Value + Align - 1) / Align * Align;
+ T Bias = (Value != 0);
+ T CeilDiv = (Value - Bias) / Align + Bias;
+ // If Value is negative, wrap will occur in the cast.
+ if (Value > 0)
+ assert(CeilDiv <= (std::numeric_limits<T>::max() - 1) / Align &&
+ "alignTo would overflow");
+ return CeilDiv * Align;
}
/// Returns the integer ceil(Numerator / Denominator). Unsigned version.
/// Guaranteed to never overflow.
-inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
+template <typename U, typename V, typename T = common_uint<U, V>>
+constexpr T divideCeil(U Numerator, V Denominator) {
+ assert(Denominator && "Division by zero");
+ T Bias = (Numerator != 0);
+ return (Numerator - Bias) / Denominator + Bias;
+}
+
+/// Fallback when arguments aren't integral.
+constexpr inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
assert(Denominator && "Division by zero");
uint64_t Bias = (Numerator != 0);
return (Numerator - Bias) / Denominator + Bias;
@@ -436,12 +503,13 @@ inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
/// Returns the integer ceil(Numerator / Denominator). Signed version.
/// Guaranteed to never overflow.
-inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
+template <typename U, typename V, typename T = common_sint<U, V>>
+constexpr T divideCeilSigned(U Numerator, V Denominator) {
assert(Denominator && "Division by zero");
if (!Numerator)
return 0;
// C's integer division rounds towards 0.
- int64_t Bias = (Denominator >= 0 ? 1 : -1);
+ T Bias = Denominator >= 0 ? 1 : -1;
bool SameSign = (Numerator >= 0) == (Denominator >= 0);
return SameSign ? (Numerator - Bias) / Denominator + 1
: Numerator / Denominator;
@@ -449,12 +517,13 @@ inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
/// Returns the integer floor(Numerator / Denominator). Signed version.
/// Guaranteed to never overflow.
-inline int64_t divideFloorSigned(int64_t Numerator, int64_t Denominator) {
+template <typename U, typename V, typename T = common_sint<U, V>>
+constexpr T divideFloorSigned(U Numerator, V Denominator) {
assert(Denominator && "Division by zero");
if (!Numerator)
return 0;
// C's integer division rounds towards 0.
- int64_t Bias = Denominator >= 0 ? -1 : 1;
+ T Bias = Denominator >= 0 ? -1 : 1;
bool SameSign = (Numerator >= 0) == (Denominator >= 0);
return SameSign ? Numerator / Denominator
: (Numerator - Bias) / Denominator - 1;
@@ -462,23 +531,29 @@ inline int64_t divideFloorSigned(int64_t Numerator, int64_t Denominator) {
/// Returns the remainder of the Euclidean division of LHS by RHS. Result is
/// always non-negative.
-inline int64_t mod(int64_t Numerator, int64_t Denominator) {
+template <typename U, typename V, typename T = common_sint<U, V>>
+constexpr T mod(U Numerator, V Denominator) {
assert(Denominator >= 1 && "Mod by non-positive number");
- int64_t Mod = Numerator % Denominator;
+ T Mod = Numerator % Denominator;
return Mod < 0 ? Mod + Denominator : Mod;
}
/// Returns (Numerator / Denominator) rounded by round-half-up. Guaranteed to
/// never overflow.
-inline uint64_t divideNearest(uint64_t Numerator, uint64_t Denominator) {
+template <typename U, typename V, typename T = common_uint<U, V>>
+constexpr T divideNearest(U Numerator, V Denominator) {
assert(Denominator && "Division by zero");
- uint64_t Mod = Numerator % Denominator;
- return (Numerator / Denominator) + (Mod > (Denominator - 1) / 2);
+ T Mod = Numerator % Denominator;
+ return (Numerator / Denominator) +
+ (Mod > (static_cast<T>(Denominator) - 1) / 2);
}
-/// Returns the largest uint64_t less than or equal to \p Value and is
-/// \p Skew mod \p Align. \p Align must be non-zero
-inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) {
+/// Returns the largest unsigned integer less than or equal to \p Value and is
+/// \p Skew mod \p Align. \p Align must be non-zero. Guaranteed to never
+/// overflow.
+template <typename U, typename V, typename W = uint8_t,
+ typename T = common_uint<common_uint<U, V>, W>>
+constexpr T alignDown(U Value, V Align, W Skew = 0) {
assert(Align != 0u && "Align can't be 0.");
Skew %= Align;
return (Value - Skew) / Align * Align + Skew;
@@ -522,8 +597,8 @@ inline int64_t SignExtend64(uint64_t X, unsigned B) {
/// Subtract two unsigned integers, X and Y, of type T and return the absolute
/// value of the result.
-template <typename T>
-std::enable_if_t<std::is_unsigned_v<T>, T> AbsoluteDifference(T X, T Y) {
+template <typename U, typename V, typename T = common_uint<U, V>>
+constexpr T AbsoluteDifference(U X, V Y) {
return X > Y ? (X - Y) : (Y - X);
}
diff --git a/llvm/unittests/Support/MathExtrasTest.cpp b/llvm/unittests/Support/MathExtrasTest.cpp
index bd09bab9be004..81e20de26aa33 100644
--- a/llvm/unittests/Support/MathExtrasTest.cpp
+++ b/llvm/unittests/Support/MathExtrasTest.cpp
@@ -189,8 +189,27 @@ TEST(MathExtras, AlignTo) {
EXPECT_EQ(8u, alignTo(5, 8));
EXPECT_EQ(24u, alignTo(17, 8));
EXPECT_EQ(0u, alignTo(~0LL, 8));
- EXPECT_EQ(static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1,
- alignTo(std::numeric_limits<uint32_t>::max(), 2));
+ EXPECT_EQ(8u, alignTo(5ULL, 8ULL));
+ EXPECT_EQ(254u,
+ alignTo(static_cast<uint8_t>(200), static_cast<uint8_t>(127)));
+#ifndef NDEBUG
+ EXPECT_DEATH(alignTo(static_cast<uint8_t>(200), static_cast<uint8_t>(128)),
+ "alignTo would overflow");
+ EXPECT_DEATH(alignTo(std::numeric_limits<uint32_t>::max(), 2),
+ "alignTo would overflow");
+#endif
+
+ EXPECT_EQ(8u, alignTo<8>(5));
+ EXPECT_EQ(24u, alignTo<8>(17));
+ EXPECT_EQ(0u, alignTo<8>(~0LL));
+ EXPECT_EQ(254u,
+ alignTo<static_cast<uint8_t>(127)>(static_cast<uint8_t>(200)));
+#ifndef NDEBUG
+ EXPECT_DEATH(alignTo<static_cast<uint8_t>(128)>(static_cast<uint8_t>(200)),
+ "alignTo would overflow");
+ EXPECT_DEATH(alignTo<2>(std::numeric_limits<uint32_t>::max()),
+ "alignTo would overflow");
+#endif
EXPECT_EQ(7u, alignTo(5, 8, 7));
EXPECT_EQ(17u, alignTo(17, 8, 1));
@@ -198,14 +217,27 @@ TEST(MathExtras, AlignTo) {
EXPECT_EQ(552u, alignTo(321, 255, 42));
EXPECT_EQ(std::numeric_limits<uint32_t>::max(),
alignTo(std::numeric_limits<uint32_t>::max(), 2, 1));
+
+#ifndef NDEBUG
+ EXPECT_DEATH(alignTo(std::numeric_limits<uint32_t>::max(), 4, 2),
+ "alignTo would overflow");
+#endif
}
TEST(MathExtras, AlignToPowerOf2) {
EXPECT_EQ(8u, alignToPowerOf2(5, 8));
EXPECT_EQ(24u, alignToPowerOf2(17, 8));
EXPECT_EQ(0u, alignToPowerOf2(~0LL, 8));
- EXPECT_EQ(static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1,
- alignToPowerOf2(std::numeric_limits<uint32_t>::max(), 2));
+ EXPECT_EQ(8u, alignToPowerOf2(5ULL, 8ULL));
+ EXPECT_EQ(240u, alignToPowerOf2(static_cast<uint8_t>(240),
+ static_cast<uint8_t>(16)));
+#ifndef NDEBUG
+ EXPECT_DEATH(
+ alignToPowerOf2(static_cast<uint8_t>(200), static_cast<uint8_t>(128)),
+ "alignToPowerOf2 would overflow");
+ EXPECT_DEATH(alignToPowerOf2(std::numeric_limits<uint32_t>::max(), 2),
+ "alignToPowerOf2 would overflow");
+#endif
}
TEST(MathExtras, AlignDown) {
@@ -484,6 +516,7 @@ TEST(MathExtras, DivideCeil) {
EXPECT_EQ(divideCeil(3, 1), 3u);
EXPECT_EQ(divideCeil(3, 6), 1u);
EXPECT_EQ(divideCeil(3, 7), 1u);
+ EXPECT_EQ(divideCeil(3ULL, 7ULL), 1u);
EXPECT_EQ(divideCeil(std::numeric_limits<uint32_t>::max(), 2),
std::numeric_limits<uint32_t>::max() / 2 + 1);
EXPECT_EQ(divideCeil(std::numeric_limits<uint64_t>::max(), 2),
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 83157b60c590b..b27c9e81b3293 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -114,7 +114,7 @@ inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
return ShapedType::kDynamic;
assert(dimSize % shardCount == 0);
- return llvm::divideCeilSigned(dimSize, shardCount);
+ return dimSize / shardCount;
}
// Get the size of an unsharded dimension.
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 1b2d0258130cb..19c3ba1f95020 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -365,7 +365,7 @@ void UnrankedMemRefDescriptor::computeSizes(
Value two = createIndexAttrConstant(builder, loc, indexType, 2);
Value indexSize = createIndexAttrConstant(
builder, loc, indexType,
- llvm::divideCeilSigned(typeConverter.getIndexTypeBitwidth(), 8));
+ llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
sizes.reserve(sizes.size() + values.size());
for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
@@ -378,8 +378,7 @@ void UnrankedMemRefDescriptor::computeSizes(
// to data layout) into the unranked descriptor.
Value pointerSize = createIndexAttrConstant(
builder, loc, indexType,
- llvm::divideCeilSigned(typeConverter.getPointerBitwidth(addressSpace),
- 8));
+ llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
Value doublePointerSize =
builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 1933c1dfcfba4..054827d40f0f3 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -971,7 +971,7 @@ struct MemorySpaceCastOpLowering
resultUnderlyingDesc, resultElemPtrType);
int64_t bytesToSkip =
- 2 * llvm::divideCeilSigned(
+ 2 * llvm::divideCeil(
getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
>From 420f4d07ecc336cd0ba30a9e005548f81553cac1 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Sat, 15 Jun 2024 17:21:02 +0100
Subject: [PATCH 2/2] MathExtras: fix lld build; fallback for MinAlign
---
llvm/include/llvm/Support/MathExtras.h | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/llvm/include/llvm/Support/MathExtras.h b/llvm/include/llvm/Support/MathExtras.h
index ed5ea8befe0ea..de3503dc07470 100644
--- a/llvm/include/llvm/Support/MathExtras.h
+++ b/llvm/include/llvm/Support/MathExtras.h
@@ -372,6 +372,11 @@ constexpr T MinAlign(U A, V B) {
return (A | B) & (1 + ~(A | B));
}
+/// Fallback when arguments aren't integral.
+constexpr inline uint64_t MinAlign(uint64_t A, uint64_t B) {
+ return (A | B) & (1 + ~(A | B));
+}
+
/// Returns the next power of two (in 64-bits) that is strictly greater than A.
/// Returns zero on overflow.
constexpr inline uint64_t NextPowerOf2(uint64_t A) {
@@ -443,8 +448,6 @@ constexpr T alignToPowerOf2(U Value, V Align) {
constexpr inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
assert(Align != 0 && (Align & (Align - 1)) == 0 &&
"Align must be a power of 2");
- // Replace unary minus to avoid compilation error on Windows:
- // "unary minus operator applied to unsigned type, result still unsigned"
uint64_t NegAlign = (~Align) + 1;
return (Value - 1 + Align) & NegAlign;
}
More information about the llvm-commits
mailing list