[llvm] [LoopInterchange] Fix overflow in cost calculation (PR #111807)
Sjoerd Meijer via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 10 04:20:29 PDT 2024
https://github.com/sjoerdmeijer updated https://github.com/llvm/llvm-project/pull/111807
>From 3265392b2fe7963de9ff7a4691935033ec99298a Mon Sep 17 00:00:00 2001
From: Sjoerd Meijer <smeijer at nvidia.com>
Date: Thu, 10 Oct 2024 02:57:50 -0700
Subject: [PATCH] [LoopInterchange] Fix overflow in cost calculation
If the iteration count is very large, e.g. UINT_MAX, the cost calculation
can overflow and trigger an assert. In that case, saturate the cost to the
maximum value of the signed cost type (INT64_MAX) instead of letting it wrap.
This fixes #104761
---
llvm/lib/Analysis/LoopCacheAnalysis.cpp | 16 +++++++-
.../interchange-refcost-overflow.ll | 37 +++++++++++++++++++
2 files changed, 51 insertions(+), 2 deletions(-)
create mode 100644 llvm/test/Analysis/LoopCacheAnalysis/interchange-refcost-overflow.ll
diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp
index 7ca9f15ad5fca0..60a4b17216b743 100644
--- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp
@@ -328,6 +328,8 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
const SCEV *TripCount =
computeTripCount(*AR->getLoop(), *Sizes.back(), SE);
Type *WiderType = SE.getWiderType(RefCost->getType(), TripCount->getType());
+ // For the multiplication result to fit, request a type twice as wide.
+ WiderType = WiderType->getExtendedType();
RefCost = SE.getMulExpr(SE.getNoopOrZeroExtend(RefCost, WiderType),
SE.getNoopOrZeroExtend(TripCount, WiderType));
}
@@ -338,8 +340,12 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
assert(RefCost && "Expecting a valid RefCost");
// Attempt to fold RefCost into a constant.
+ // CacheCostTy is a signed integer, but the tripcount value can be large
+ // and may not fit, so saturate/limit the value to the maximum signed
+ // integer value.
if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost))
- return ConstantCost->getValue()->getZExtValue();
+ return (CacheCostTy)ConstantCost->getValue()->getLimitedValue(
+ std::numeric_limits<CacheCostTy>::max());
LLVM_DEBUG(dbgs().indent(4)
<< "RefCost is not a constant! Setting to RefCost=InvalidCost "
@@ -712,7 +718,13 @@ CacheCost::computeLoopCacheCost(const Loop &L,
CacheCostTy LoopCost = 0;
for (const ReferenceGroupTy &RG : RefGroups) {
CacheCostTy RefGroupCost = computeRefGroupCacheCost(RG, L);
- LoopCost += RefGroupCost * TripCountsProduct;
+
+ // Saturate the cost to INT MAX if the value can overflow.
+ if (RefGroupCost >
+ (std::numeric_limits<CacheCostTy>::max() / TripCountsProduct))
+ LoopCost = std::numeric_limits<CacheCostTy>::max();
+ else
+ LoopCost += RefGroupCost * TripCountsProduct;
}
LLVM_DEBUG(dbgs().indent(2) << "Loop '" << L.getName()
diff --git a/llvm/test/Analysis/LoopCacheAnalysis/interchange-refcost-overflow.ll b/llvm/test/Analysis/LoopCacheAnalysis/interchange-refcost-overflow.ll
new file mode 100644
index 00000000000000..dad4e0bbfeb733
--- /dev/null
+++ b/llvm/test/Analysis/LoopCacheAnalysis/interchange-refcost-overflow.ll
@@ -0,0 +1,37 @@
+; RUN: opt < %s -passes='print<loop-cache-cost>' -disable-output 2>&1 | FileCheck %s
+
+; For a loop with a very large iteration count, make sure the cost
+; calculation does not overflow:
+;
+; void foo(int b) {
+; for (int c = 0;; c += b)
+; for (long d = 0; d < -3ULL; d += 2ULL)
+; A[c][d][d] = 0;
+; }
+
+; CHECK: Loop 'for.cond' has cost = 9223372036854775807
+; CHECK: Loop 'for.body' has cost = 9223372036854775807
+
+@A = local_unnamed_addr global [11 x [11 x [11 x i32]]] zeroinitializer, align 16 ; int A[11][11][11], accessed as A[c][d][d]
+
+define void @foo(i32 noundef %b) { ; %b is the outer-loop stride (c += b)
+entry:
+ %0 = sext i32 %b to i64 ; stride widened to i64 for the outer IV
+ br label %for.cond
+
+for.cond: ; outer loop header; this loop has no exit edge
+ %indvars.iv = phi i64 [ %indvars.iv.next, %for.cond.cleanup ], [ 0, %entry ] ; outer IV c, starts at 0
+ br label %for.body
+
+for.cond.cleanup: ; outer latch
+ %indvars.iv.next = add nsw i64 %indvars.iv, %0 ; c += b
+ br label %for.cond
+
+for.body: ; inner loop over d = 0, 2, 4, ...
+ %d.010 = phi i64 [ 0, %for.cond ], [ %add, %for.body ] ; inner IV d
+ %arrayidx3 = getelementptr inbounds [11 x [11 x [11 x i32]]], ptr @A, i64 0, i64 %indvars.iv, i64 %d.010, i64 %d.010 ; &A[c][d][d]
+ store i32 0, ptr %arrayidx3, align 4 ; A[c][d][d] = 0
+ %add = add nuw i64 %d.010, 2 ; d += 2
+ %cmp = icmp ult i64 %d.010, -5 ; continue while next d <u -3ULL: ~2^63 iterations, enough to overflow an int64 cost
+ br i1 %cmp, label %for.body, label %for.cond.cleanup
+}
More information about the llvm-commits
mailing list