[Mlir-commits] [llvm] [mlir] MathExtras: avoid unnecessarily widening types (PR #95426)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 13 08:59:32 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Ramkumar Ramachandra (artagnon)

<details>
<summary>Changes</summary>

Several multi-argument functions unnecessarily widen types beyond the argument types. Templatize these functions and use std::common_type_t so that they widen only up to the widest argument type, hence optimizing the functions. While at it, address the usage issues raised in https://github.com/llvm/llvm-project/pull/95087. One of the requirements of this patch is to add overflow checks; one caller in LoopVectorize, whose divideCeil call would otherwise overflow, is manually widened (marked as TODO).

-- 8<--
Based on #95425. I've tested the change in llvm and mlir, and am relying on the CI to test it in other projects.

---
Full diff: https://github.com/llvm/llvm-project/pull/95426.diff


6 Files Affected:

- (modified) llvm/include/llvm/Support/MathExtras.h (+89-25) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+2-1) 
- (modified) llvm/unittests/Support/MathExtrasTest.cpp (+53-2) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+1-1) 
- (modified) mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp (+2-3) 
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+1-1) 


``````````diff
diff --git a/llvm/include/llvm/Support/MathExtras.h b/llvm/include/llvm/Support/MathExtras.h
index 05d87e176dec1..1550a0e3f7458 100644
--- a/llvm/include/llvm/Support/MathExtras.h
+++ b/llvm/include/llvm/Support/MathExtras.h
@@ -23,6 +23,22 @@
 #include <type_traits>
 
 namespace llvm {
+/// Some template parameter helpers to optimize for bitwidth, for functions that
+/// take multiple arguments.
+
+// We can't verify signedness, since callers rely on implicit coercions to
+// signed/unsigned.
+template <typename T, typename U>
+using enableif_int =
+    std::enable_if_t<std::is_integral_v<T> && std::is_integral_v<U>>;
+
+// Use std::common_type_t to widen only up to the widest argument.
+template <typename T, typename U, typename = enableif_int<T, U>>
+using common_uint =
+    std::common_type_t<std::make_unsigned_t<T>, std::make_unsigned_t<U>>;
+template <typename T, typename U, typename = enableif_int<T, U>>
+using common_sint =
+    std::common_type_t<std::make_signed_t<T>, std::make_signed_t<U>>;
 
 /// Mathematical constants.
 namespace numbers {
@@ -346,7 +362,8 @@ inline unsigned Log2_64_Ceil(uint64_t Value) {
 
 /// A and B are either alignments or offsets. Return the minimum alignment that
 /// may be assumed after adding the two together.
-constexpr inline uint64_t MinAlign(uint64_t A, uint64_t B) {
+template <typename U, typename V, typename T = common_uint<U, V>>
+constexpr T MinAlign(U A, V B) {
   // The largest power of 2 that divides both A and B.
   //
   // Replace "-Value" by "1+~Value" in the following commented code to avoid
@@ -375,7 +392,7 @@ inline uint64_t PowerOf2Ceil(uint64_t A) {
   return UINT64_C(1) << Log2_64_Ceil(A);
 }
 
-/// Returns the next integer (mod 2**64) that is greater than or equal to
+/// Returns the next integer (mod 2**nbits) that is greater than or equal to
 /// \p Value and is a multiple of \p Align. \p Align must be non-zero.
 ///
 /// Examples:
@@ -385,18 +402,44 @@ inline uint64_t PowerOf2Ceil(uint64_t A) {
 ///   alignTo(~0LL, 8) = 0
 ///   alignTo(321, 255) = 510
 /// \endcode
-inline uint64_t alignTo(uint64_t Value, uint64_t Align) {
+template <typename U, typename V, typename T = common_uint<U, V>>
+constexpr T alignTo(U Value, V Align) {
+  assert(Align != 0u && "Align can't be 0.");
+  // If Value is negative, wrap will occur in the cast.
+  if (Value > 0)
+    assert((T)Value <= std::numeric_limits<T>::max() - (Align - 1) &&
+           "alignTo would overflow");
+  return (Value + Align - 1) / Align * Align;
+}
+
+// Fallback when arguments aren't integral.
+constexpr inline uint64_t alignTo(uint64_t Value, uint64_t Align) {
   assert(Align != 0u && "Align can't be 0.");
   return (Value + Align - 1) / Align * Align;
 }
 
-inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
+template <typename U, typename V, typename T = common_uint<U, V>>
+constexpr T alignToPowerOf2(U Value, V Align) {
   assert(Align != 0 && (Align & (Align - 1)) == 0 &&
          "Align must be a power of 2");
+  // If Value is negative, wrap will occur in the cast.
+  if (Value > 0)
+    assert((T)Value <= std::numeric_limits<T>::max() - (Align - 1) &&
+           "alignToPowerOf2 would overflow");
   // Replace unary minus to avoid compilation error on Windows:
   // "unary minus operator applied to unsigned type, result still unsigned"
-  uint64_t negAlign = (~Align) + 1;
-  return (Value + Align - 1) & negAlign;
+  T NegAlign = (~Align) + 1;
+  return (Value + Align - 1) & NegAlign;
+}
+
+// Fallback when arguments aren't integral.
+constexpr inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
+  assert(Align != 0 && (Align & (Align - 1)) == 0 &&
+         "Align must be a power of 2");
+  // Replace unary minus to avoid compilation error on Windows:
+  // "unary minus operator applied to unsigned type, result still unsigned"
+  uint64_t NegAlign = (~Align) + 1;
+  return (Value + Align - 1) & NegAlign;
 }
 
 /// If non-zero \p Skew is specified, the return value will be a minimal integer
@@ -411,7 +454,9 @@ inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
 ///   alignTo(~0LL, 8, 3) = 3
 ///   alignTo(321, 255, 42) = 552
 /// \endcode
-inline uint64_t alignTo(uint64_t Value, uint64_t Align, uint64_t Skew) {
+template <typename U, typename V, typename W,
+          typename T = common_uint<common_uint<U, V>, W>>
+constexpr T alignTo(U Value, V Align, W Skew) {
   assert(Align != 0u && "Align can't be 0.");
   Skew %= Align;
   return alignTo(Value - Skew, Align) + Skew;
@@ -419,56 +464,75 @@ inline uint64_t alignTo(uint64_t Value, uint64_t Align, uint64_t Skew) {
 
 /// Returns the next integer (mod 2**64) that is greater than or equal to
 /// \p Value and is a multiple of \c Align. \c Align must be non-zero.
-template <uint64_t Align> constexpr inline uint64_t alignTo(uint64_t Value) {
+template <uint64_t Align> constexpr uint64_t alignTo(uint64_t Value) {
   static_assert(Align != 0u, "Align must be non-zero");
   return (Value + Align - 1) / Align * Align;
 }
 
-/// Returns the integer ceil(Numerator / Denominator). Unsigned integer version.
-inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
+/// Returns the integer ceil(Numerator / Denominator). Unsigned version.
+template <typename U, typename V, typename T = common_uint<U, V>>
+constexpr T divideCeil(U Numerator, V Denominator) {
   return alignTo(Numerator, Denominator) / Denominator;
 }
 
-/// Returns the integer ceil(Numerator / Denominator). Signed integer version.
-inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
+// Fallback when arguments aren't integral.
+constexpr inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
+  return alignTo(Numerator, Denominator) / Denominator;
+}
+
+/// Returns the integer ceil(Numerator / Denominator). Signed version.
+/// Guaranteed to never overflow.
+template <typename U, typename V, typename T = common_sint<U, V>>
+constexpr T divideCeilSigned(U Numerator, V Denominator) {
   assert(Denominator && "Division by zero");
   if (!Numerator)
     return 0;
   // C's integer division rounds towards 0.
-  int64_t X = (Denominator > 0) ? -1 : 1;
+  T X = (Denominator > 0) ? -1 : 1;
   bool SameSign = (Numerator > 0) == (Denominator > 0);
   return SameSign ? ((Numerator + X) / Denominator) + 1
                   : Numerator / Denominator;
 }
 
-/// Returns the integer floor(Numerator / Denominator). Signed integer version.
-inline int64_t divideFloorSigned(int64_t Numerator, int64_t Denominator) {
+/// Returns the integer floor(Numerator / Denominator). Signed version.
+/// Guaranteed to never overflow.
+template <typename U, typename V, typename T = common_sint<U, V>>
+constexpr T divideFloorSigned(U Numerator, V Denominator) {
   assert(Denominator && "Division by zero");
   if (!Numerator)
     return 0;
   // C's integer division rounds towards 0.
-  int64_t X = (Denominator > 0) ? -1 : 1;
+  T X = (Denominator > 0) ? -1 : 1;
   bool SameSign = (Numerator > 0) == (Denominator > 0);
   return SameSign ? Numerator / Denominator
                   : -((-Numerator + X) / Denominator) - 1;
 }
 
 /// Returns the remainder of the Euclidean division of LHS by RHS. Result is
-/// always non-negative.
-inline int64_t mod(int64_t Numerator, int64_t Denominator) {
+/// always non-negative. Signed version. Guaranteed to never overflow.
+template <typename U, typename V, typename T = common_sint<U, V>>
+constexpr T mod(U Numerator, V Denominator) {
   assert(Denominator >= 1 && "Mod by non-positive number");
-  int64_t Mod = Numerator % Denominator;
+  T Mod = Numerator % Denominator;
   return Mod < 0 ? Mod + Denominator : Mod;
 }
 
 /// Returns the integer nearest(Numerator / Denominator).
-inline uint64_t divideNearest(uint64_t Numerator, uint64_t Denominator) {
+template <typename U, typename V, typename T = common_uint<U, V>>
+constexpr T divideNearest(U Numerator, V Denominator) {
+  // If Numerator is negative, wrap will occur in the cast.
+  if (Numerator > 0)
+    assert((T)Numerator <= std::numeric_limits<T>::max() - (Denominator / 2) &&
+           "divideNearest would overflow");
   return (Numerator + (Denominator / 2)) / Denominator;
 }
 
-/// Returns the largest uint64_t less than or equal to \p Value and is
-/// \p Skew mod \p Align. \p Align must be non-zero
-inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) {
+/// Returns the largest unsigned integer less than or equal to \p Value and is
+/// \p Skew mod \p Align. \p Align must be non-zero. Guaranteed to never
+/// overflow.
+template <typename U, typename V, typename W = uint8_t,
+          typename T = common_uint<common_uint<U, V>, W>>
+constexpr T alignDown(U Value, V Align, W Skew = 0) {
   assert(Align != 0u && "Align can't be 0.");
   Skew %= Align;
   return (Value - Skew) / Align * Align + Skew;
@@ -512,8 +576,8 @@ inline int64_t SignExtend64(uint64_t X, unsigned B) {
 
 /// Subtract two unsigned integers, X and Y, of type T and return the absolute
 /// value of the result.
-template <typename T>
-std::enable_if_t<std::is_unsigned_v<T>, T> AbsoluteDifference(T X, T Y) {
+template <typename U, typename V, typename T = common_uint<U, V>>
+constexpr T AbsoluteDifference(U X, V Y) {
   return X > Y ? (X - Y) : (Y - X);
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 37b8023e1fcf2..7a50f6e292a95 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -4803,7 +4803,8 @@ bool LoopVectorizationPlanner::isMoreProfitable(
     // different VFs we can use this to compare the total loop-body cost
     // expected after vectorization.
     if (CM.foldTailByMasking())
-      return VectorCost * divideCeil(MaxTripCount, VF);
+      // TODO: divideCeil will overflow, unless MaxTripCount is cast.
+      return VectorCost * divideCeil((uint64_t)MaxTripCount, VF);
     return VectorCost * (MaxTripCount / VF) + ScalarCost * (MaxTripCount % VF);
   };
 
diff --git a/llvm/unittests/Support/MathExtrasTest.cpp b/llvm/unittests/Support/MathExtrasTest.cpp
index e75700df67e69..42ad952637bbe 100644
--- a/llvm/unittests/Support/MathExtrasTest.cpp
+++ b/llvm/unittests/Support/MathExtrasTest.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Support/MathExtras.h"
 #include "gtest/gtest.h"
+#include <limits>
 
 using namespace llvm;
 
@@ -175,6 +176,7 @@ TEST(MathExtras, MinAlign) {
   EXPECT_EQ(2u, MinAlign(2, 4));
   EXPECT_EQ(1u, MinAlign(17, 64));
   EXPECT_EQ(256u, MinAlign(256, 512));
+  EXPECT_EQ(2u, MinAlign(0, 2));
 }
 
 TEST(MathExtras, NextPowerOf2) {
@@ -183,15 +185,38 @@ TEST(MathExtras, NextPowerOf2) {
   EXPECT_EQ(256u, NextPowerOf2(128));
 }
 
-TEST(MathExtras, alignTo) {
+TEST(MathExtras, AlignTo) {
   EXPECT_EQ(8u, alignTo(5, 8));
   EXPECT_EQ(24u, alignTo(17, 8));
   EXPECT_EQ(0u, alignTo(~0LL, 8));
+#ifndef NDEBUG
+  EXPECT_DEATH(alignTo(std::numeric_limits<uint32_t>::max(), 2),
+               "alignTo would overflow");
+#endif
 
   EXPECT_EQ(7u, alignTo(5, 8, 7));
   EXPECT_EQ(17u, alignTo(17, 8, 1));
   EXPECT_EQ(3u, alignTo(~0LL, 8, 3));
   EXPECT_EQ(552u, alignTo(321, 255, 42));
+  EXPECT_EQ(std::numeric_limits<uint32_t>::max(),
+            alignTo(std::numeric_limits<uint32_t>::max(), 2, 1));
+}
+
+TEST(MathExtras, AlignToPowerOf2) {
+  EXPECT_EQ(8u, alignToPowerOf2(5, 8));
+  EXPECT_EQ(24u, alignToPowerOf2(17, 8));
+  EXPECT_EQ(0u, alignToPowerOf2(~0LL, 8));
+#ifndef NDEBUG
+  EXPECT_DEATH(alignToPowerOf2(std::numeric_limits<uint32_t>::max(), 2),
+               "alignToPowerOf2 would overflow");
+#endif
+}
+
+TEST(MathExtras, AlignDown) {
+  EXPECT_EQ(0u, alignDown(5, 8));
+  EXPECT_EQ(16u, alignDown(17, 8));
+  EXPECT_EQ(std::numeric_limits<uint32_t>::max() - 1,
+            alignDown(std::numeric_limits<uint32_t>::max(), 2));
 }
 
 template <typename T> void SaturatingAddTestHelper() {
@@ -434,7 +459,25 @@ TEST(MathExtras, IsShiftedInt) {
   EXPECT_FALSE((isShiftedInt<6, 10>(int64_t(1) << 15)));
 }
 
-TEST(MathExtras, DivideCeilSigned) {
+TEST(MathExtras, DivideNearest) {
+  EXPECT_EQ(divideNearest(14, 3), 5u);
+  EXPECT_EQ(divideNearest(15, 3), 5u);
+  EXPECT_EQ(divideNearest(0, 3), 0u);
+#ifndef NDEBUG
+  EXPECT_DEATH(divideNearest(std::numeric_limits<uint32_t>::max(), 2),
+               "divideNearest would overflow");
+#endif
+}
+
+TEST(MathExtras, DivideCeil) {
+  EXPECT_EQ(divideCeil(14, 3), 5u);
+  EXPECT_EQ(divideCeil(15, 3), 5u);
+  EXPECT_EQ(divideCeil(0, 3), 0u);
+#ifndef NDEBUG
+  EXPECT_DEATH(divideCeil(std::numeric_limits<uint32_t>::max(), 2),
+               "alignTo would overflow");
+#endif
+
   EXPECT_EQ(divideCeilSigned(14, 3), 5);
   EXPECT_EQ(divideCeilSigned(15, 3), 5);
   EXPECT_EQ(divideCeilSigned(14, -3), -4);
@@ -443,6 +486,10 @@ TEST(MathExtras, DivideCeilSigned) {
   EXPECT_EQ(divideCeilSigned(-15, 3), -5);
   EXPECT_EQ(divideCeilSigned(0, 3), 0);
   EXPECT_EQ(divideCeilSigned(0, -3), 0);
+  EXPECT_EQ(divideCeilSigned(std::numeric_limits<int32_t>::max(), 2),
+            std::numeric_limits<int32_t>::max() / 2 + 1);
+  EXPECT_EQ(divideCeilSigned(std::numeric_limits<int32_t>::max(), -2),
+            std::numeric_limits<int32_t>::min() / 2 + 1);
 }
 
 TEST(MathExtras, DivideFloorSigned) {
@@ -454,6 +501,10 @@ TEST(MathExtras, DivideFloorSigned) {
   EXPECT_EQ(divideFloorSigned(-15, 3), -5);
   EXPECT_EQ(divideFloorSigned(0, 3), 0);
   EXPECT_EQ(divideFloorSigned(0, -3), 0);
+  EXPECT_EQ(divideFloorSigned(std::numeric_limits<int32_t>::max(), 2),
+            std::numeric_limits<int32_t>::max() / 2);
+  EXPECT_EQ(divideFloorSigned(std::numeric_limits<int32_t>::max(), -2),
+            std::numeric_limits<int32_t>::min() / 2);
 }
 
 TEST(MathExtras, Mod) {
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 83157b60c590b..b27c9e81b3293 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -114,7 +114,7 @@ inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
     return ShapedType::kDynamic;
 
   assert(dimSize % shardCount == 0);
-  return llvm::divideCeilSigned(dimSize, shardCount);
+  return dimSize / shardCount;
 }
 
 // Get the size of an unsharded dimension.
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 1b2d0258130cb..19c3ba1f95020 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -365,7 +365,7 @@ void UnrankedMemRefDescriptor::computeSizes(
   Value two = createIndexAttrConstant(builder, loc, indexType, 2);
   Value indexSize = createIndexAttrConstant(
       builder, loc, indexType,
-      llvm::divideCeilSigned(typeConverter.getIndexTypeBitwidth(), 8));
+      llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
 
   sizes.reserve(sizes.size() + values.size());
   for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
@@ -378,8 +378,7 @@ void UnrankedMemRefDescriptor::computeSizes(
     // to data layout) into the unranked descriptor.
     Value pointerSize = createIndexAttrConstant(
         builder, loc, indexType,
-        llvm::divideCeilSigned(typeConverter.getPointerBitwidth(addressSpace),
-                               8));
+        llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
     Value doublePointerSize =
         builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 7ef5a77fcb42c..80167ee3935a5 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -971,7 +971,7 @@ struct MemorySpaceCastOpLowering
                                resultUnderlyingDesc, resultElemPtrType);
 
       int64_t bytesToSkip =
-          2 * llvm::divideCeilSigned(
+          2 * llvm::divideCeil(
                   getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
       Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
           loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));

``````````

</details>


https://github.com/llvm/llvm-project/pull/95426


More information about the Mlir-commits mailing list