[llvm] [LoopInterchange] Fix overflow in cost calculation (PR #111807)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 10 03:09:39 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Sjoerd Meijer (sjoerdmeijer)

<details>
<summary>Changes</summary>

If the iteration count is very large, e.g. UINT_MAX, the cost calculation can overflow and trigger an assert. In that case, saturate the cost to the maximum signed 64-bit value (INT64_MAX), because the cost is stored in a signed 64-bit integer.

This fixes issue #104761.

---
Full diff: https://github.com/llvm/llvm-project/pull/111807.diff


2 Files Affected:

- (modified) llvm/lib/Analysis/LoopCacheAnalysis.cpp (+12-2) 
- (added) llvm/test/Transforms/LoopInterchange/refcost-overflow.ll (+44) 


``````````diff
diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp
index 7ca9f15ad5fca0..3e03b5ba268cff 100644
--- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp
@@ -328,6 +328,8 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
       const SCEV *TripCount =
           computeTripCount(*AR->getLoop(), *Sizes.back(), SE);
       Type *WiderType = SE.getWiderType(RefCost->getType(), TripCount->getType());
+      // For the multiplication result to fit, request a type twice as wide.
+      WiderType = WiderType->getExtendedType();
       RefCost = SE.getMulExpr(SE.getNoopOrZeroExtend(RefCost, WiderType),
                               SE.getNoopOrZeroExtend(TripCount, WiderType));
     }
@@ -338,8 +340,11 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
   assert(RefCost && "Expecting a valid RefCost");
 
   // Attempt to fold RefCost into a constant.
+  // CacheCostTy is a signed integer, but the tripcount value can be large
+  // and may not fit, so saturate/limit the value to the maximum signed
+  // integer value.
   if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost))
-    return ConstantCost->getValue()->getZExtValue();
+    return (CacheCostTy)ConstantCost->getValue()->getLimitedValue(~0ULL >> 1);
 
   LLVM_DEBUG(dbgs().indent(4)
              << "RefCost is not a constant! Setting to RefCost=InvalidCost "
@@ -712,7 +717,12 @@ CacheCost::computeLoopCacheCost(const Loop &L,
   CacheCostTy LoopCost = 0;
   for (const ReferenceGroupTy &RG : RefGroups) {
     CacheCostTy RefGroupCost = computeRefGroupCacheCost(RG, L);
-    LoopCost += RefGroupCost * TripCountsProduct;
+
+    // Saturate the cost to INT MAX if the value can overflow.
+    if (RefGroupCost > (std::numeric_limits<int64_t>::max() / TripCountsProduct))
+      LoopCost = std::numeric_limits<int64_t>::max();
+    else
+      LoopCost += RefGroupCost * TripCountsProduct;
   }
 
   LLVM_DEBUG(dbgs().indent(2) << "Loop '" << L.getName()
diff --git a/llvm/test/Transforms/LoopInterchange/refcost-overflow.ll b/llvm/test/Transforms/LoopInterchange/refcost-overflow.ll
new file mode 100644
index 00000000000000..7a823b85ce32e7
--- /dev/null
+++ b/llvm/test/Transforms/LoopInterchange/refcost-overflow.ll
@@ -0,0 +1,44 @@
+; REQUIRES: asserts
+
+; RUN: opt < %s -passes=loop-interchange -S -debug 2>&1 | FileCheck %s
+
+; For a loop with a very large iteration count, make sure the cost
+; calculation does not overflow:
+;
+; void a(int b) {
+;   for (int c;; c += b)
+;     for (long d = 0; d < -3ULL; d += 2ULL)
+;       A[c][d][d] = 0;
+; }
+
+; CHECK:  Access is not consecutive: RefCost=922337203685477580700
+; CHECK:  Loop 'for.cond' has cost=9223372036854775807
+; CHECK:  TripCount=9223372036854775807
+; CHECK:  Access is not consecutive: RefCost=9223372036854775807
+
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128-Fn32"
+
+@A = local_unnamed_addr global [11 x [11 x [11 x i32]]] zeroinitializer, align 16
+
+define void @foo(i32 noundef %b) {
+entry:
+  %0 = sext i32 %b to i64
+  br label %for.cond
+
+for.cond:
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.cond.cleanup ], [ 0, %entry ]
+  br label %for.body
+
+for.cond.cleanup:
+  %indvars.iv.next = add nsw i64 %indvars.iv, %0
+  br label %for.cond
+
+for.body:
+  %d.010 = phi i64 [ 0, %for.cond ], [ %add, %for.body ]
+  %arrayidx3 = getelementptr inbounds [11 x [11 x [11 x i32]]], ptr @A, i64 0, i64 %indvars.iv, i64 %d.010, i64 %d.010
+  store i32 0, ptr %arrayidx3, align 4
+  %add = add nuw i64 %d.010, 2
+  %cmp = icmp ult i64 %d.010, -5
+  br i1 %cmp, label %for.body, label %for.cond.cleanup
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/111807


More information about the llvm-commits mailing list