[llvm] 2e6deb1 - [LoopInterchange] Fix overflow in cost calculation (#111807)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 14 00:40:48 PST 2024


Author: Sjoerd Meijer
Date: 2024-11-14T08:40:45Z
New Revision: 2e6deb1dd3a4422807633ba08773e8d786e43d4c

URL: https://github.com/llvm/llvm-project/commit/2e6deb1dd3a4422807633ba08773e8d786e43d4c
DIFF: https://github.com/llvm/llvm-project/commit/2e6deb1dd3a4422807633ba08773e8d786e43d4c.diff

LOG: [LoopInterchange] Fix overflow in cost calculation (#111807)

If the iteration count is really large, e.g. UINT_MAX, then the cost
calculation can overflow and trigger an assert. So saturate the cost to
INT64_MAX if this is the case, by using InstructionCost as the type, which
already supports this kind of overflow handling.

This fixes #104761

Added: 
    llvm/test/Analysis/LoopCacheAnalysis/interchange-refcost-overflow.ll

Modified: 
    llvm/include/llvm/Analysis/LoopCacheAnalysis.h
    llvm/lib/Analysis/LoopCacheAnalysis.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/LoopCacheAnalysis.h b/llvm/include/llvm/Analysis/LoopCacheAnalysis.h
index 4fd2485e39d6db..3e22487e5e349c 100644
--- a/llvm/include/llvm/Analysis/LoopCacheAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopCacheAnalysis.h
@@ -16,6 +16,7 @@
 
 #include "llvm/Analysis/LoopAnalysisManager.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Support/InstructionCost.h"
 #include <optional>
 
 namespace llvm {
@@ -31,7 +32,7 @@ class ScalarEvolution;
 class SCEV;
 class TargetTransformInfo;
 
-using CacheCostTy = int64_t;
+using CacheCostTy = InstructionCost;
 using LoopVectorTy = SmallVector<Loop *, 8>;
 
 /// Represents a memory reference as a base pointer and a set of indexing
@@ -192,8 +193,6 @@ class CacheCost {
   using LoopCacheCostTy = std::pair<const Loop *, CacheCostTy>;
 
 public:
-  static CacheCostTy constexpr InvalidCost = -1;
-
   /// Construct a CacheCost object for the loop nest described by \p Loops.
   /// The optional parameter \p TRT can be used to specify the max. distance
   /// between array elements accessed in a loop so that the elements are

diff  --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp
index 7ca9f15ad5fca0..2897b922f61e48 100644
--- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp
@@ -328,6 +328,8 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
       const SCEV *TripCount =
           computeTripCount(*AR->getLoop(), *Sizes.back(), SE);
       Type *WiderType = SE.getWiderType(RefCost->getType(), TripCount->getType());
+      // For the multiplication result to fit, request a type twice as wide.
+      WiderType = WiderType->getExtendedType();
       RefCost = SE.getMulExpr(SE.getNoopOrZeroExtend(RefCost, WiderType),
                               SE.getNoopOrZeroExtend(TripCount, WiderType));
     }
@@ -338,14 +340,18 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
   assert(RefCost && "Expecting a valid RefCost");
 
   // Attempt to fold RefCost into a constant.
+  // CacheCostTy is a signed integer, but the tripcount value can be large
+  // and may not fit, so saturate/limit the value to the maximum signed
+  // integer value.
   if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost))
-    return ConstantCost->getValue()->getZExtValue();
+    return ConstantCost->getValue()->getLimitedValue(
+        std::numeric_limits<int64_t>::max());
 
   LLVM_DEBUG(dbgs().indent(4)
              << "RefCost is not a constant! Setting to RefCost=InvalidCost "
                 "(invalid value).\n");
 
-  return CacheCost::InvalidCost;
+  return CacheCostTy::getInvalid();
 }
 
 bool IndexedReference::tryDelinearizeFixedSize(
@@ -696,7 +702,7 @@ CacheCostTy
 CacheCost::computeLoopCacheCost(const Loop &L,
                                 const ReferenceGroupsTy &RefGroups) const {
   if (!L.isLoopSimplifyForm())
-    return InvalidCost;
+    return CacheCostTy::getInvalid();
 
   LLVM_DEBUG(dbgs() << "Considering loop '" << L.getName()
                     << "' as innermost loop.\n");

diff  --git a/llvm/test/Analysis/LoopCacheAnalysis/interchange-refcost-overflow.ll b/llvm/test/Analysis/LoopCacheAnalysis/interchange-refcost-overflow.ll
new file mode 100644
index 00000000000000..7b6529601da32d
--- /dev/null
+++ b/llvm/test/Analysis/LoopCacheAnalysis/interchange-refcost-overflow.ll
@@ -0,0 +1,37 @@
+; RUN: opt <  %s  -passes='print<loop-cache-cost>' -disable-output 2>&1 | FileCheck  %s
+
+; For a loop with a very large iteration count, make sure the cost
+; calculation saturates at INT64_MAX instead of overflowing:
+;
+; void foo(int b) {
+;   for (int c = 0;; c += b)
+;     for (long d = 0; d < -3ULL; d += 2ULL)
+;       A[c][d][d] = 0;
+; }
+
+; CHECK: Loop 'outer.loop' has cost = 9223372036854775807
+; CHECK: Loop 'inner.loop' has cost = 9223372036854775807
+
+@A = local_unnamed_addr global [11 x [11 x [11 x i32]]] zeroinitializer, align 16
+
+define void @foo(i32 noundef %b) {
+entry:
+  %0 = sext i32 %b to i64
+  br label %outer.loop
+
+outer.loop:
+  %indvars.iv = phi i64 [ %indvars.iv.next, %outer.loop.cleanup ], [ 0, %entry ]
+  br label %inner.loop
+
+outer.loop.cleanup:
+  %indvars.iv.next = add nsw i64 %indvars.iv, %0
+  br label %outer.loop
+
+inner.loop:
+  %inner.iv = phi i64 [ 0, %outer.loop ], [ %add, %inner.loop ]
+  %arrayidx3 = getelementptr inbounds [11 x [11 x [11 x i32]]], ptr @A, i64 0, i64 %indvars.iv, i64 %inner.iv, i64 %inner.iv
+  store i32 0, ptr %arrayidx3, align 4
+  %add = add nuw i64 %inner.iv, 2
+  %cmp = icmp ult i64 %inner.iv, -5
+  br i1 %cmp, label %inner.loop, label %outer.loop.cleanup
+}


        


More information about the llvm-commits mailing list