[llvm] [SCEV] Infer loop max trip count from memory accesses (PR #70361)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 18 11:27:29 PST 2023


https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/70361

>From 679e95e5356513815ddf48e2e1b9e5d95b9bd66f Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Mon, 18 Dec 2023 13:50:38 -0500
Subject: [PATCH 1/2] [SCEV] Infer loop max trip count from memory accesses

Memory accesses in a loop are assumed not to touch elements beyond the
statically allocated size of the accessed object, since doing so would be
undefined behavior. We can therefore infer a maximum loop trip count from them.

This patch is refined from the original one (https://reviews.llvm.org/D155049)
authored by @Peakulorain.
---
 llvm/include/llvm/Analysis/ScalarEvolution.h  |   6 +
 llvm/lib/Analysis/ScalarEvolution.cpp         | 187 ++++++++++++++++-
 .../infer-trip-count-idx-wrap.ll              | 110 ++++++++++
 .../ScalarEvolution/infer-trip-count.ll       | 191 ++++++++++++++++++
 4 files changed, 493 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/Analysis/ScalarEvolution/infer-trip-count-idx-wrap.ll
 create mode 100644 llvm/test/Analysis/ScalarEvolution/infer-trip-count.ll

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 4f1237c4b1f92b..c98705d60500d9 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -854,6 +854,12 @@ class ScalarEvolution {
   unsigned getSmallConstantTripMultiple(const Loop *L,
                                         const BasicBlock *ExitingBlock);
 
+  /// Return the upper bound of the loop trip count infered from memory access.
+  /// This can not access bytes starting outside the statically allocated size
+  /// without being immediate UB. Returns SCEVCouldNotCompute if the trip count
+  /// could not be inferred.
+  const SCEV *getConstantMaxTripCountFromMemAccess(const Loop *L);
+
   /// The terms "backedge taken count" and "exit count" are used
   /// interchangeably to refer to the number of times the backedge of a loop
   /// has executed before the loop is exited.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 580fe112fcd7bd..2076c007834db7 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -249,6 +249,10 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
     cl::desc("Infer nuw/nsw flags using context where suitable"),
     cl::init(true));
 
+static cl::opt<bool> UseMemoryAccessUBForBEInference(
+    "scalar-evolution-infer-max-trip-count-from-memory-access", cl::Hidden,
+    cl::desc("Infer loop max trip count from memory access"), cl::init(false));
+
 //===----------------------------------------------------------------------===//
 //                           SCEV class definitions
 //===----------------------------------------------------------------------===//
@@ -8135,7 +8139,16 @@ ScalarEvolution::getSmallConstantTripCount(const Loop *L,
 unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
   const auto *MaxExitCount =
       dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
-  return getConstantTripCount(MaxExitCount);
+  unsigned MaxExitCountN = getConstantTripCount(MaxExitCount);
+  if (UseMemoryAccessUBForBEInference) {
+    auto *MaxInferCount = getConstantMaxTripCountFromMemAccess(L);
+    if (auto *InferCount = dyn_cast<SCEVConstant>(MaxInferCount)) {
+      unsigned InferValue = InferCount->getValue()->getZExtValue();
+      MaxExitCountN =
+          MaxExitCountN == 0 ? InferValue : std::min(MaxExitCountN, InferValue);
+    }
+  }
+  return MaxExitCountN;
 }
 
 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
@@ -8190,6 +8203,167 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
   return getSmallConstantTripMultiple(L, ExitCount);
 }
 
+/// Collect the loads/stores of innermost loop \p L that execute on every
+/// iteration; \p MemInsts is left untouched if \p L is not analyzable.
+static void
+collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
+                               SmallVector<Instruction *, 4> &MemInsts) {
+  // It is difficult to tell if the load/store instruction is executed on every
+  // iteration inside an irregular loop.
+  if (!L->isLoopSimplifyForm() || !L->isInnermost())
+    return;
+
+  // FIXME: To make the case more typical, we only analyze loops that have one
+  // exiting block and the block must be the latch. It is easier to capture
+  // loops with memory access that will be executed in every iteration.
+  const BasicBlock *LoopLatch = L->getLoopLatch();
+  assert(LoopLatch && "normal form loop doesn't have a latch");
+  if (L->getExitingBlock() != LoopLatch)
+    return;
+
+  // We will not continue if sanitizer is enabled.
+  const Function *F = LoopLatch->getParent();
+  if (F->hasFnAttribute(Attribute::SanitizeAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeThread) ||
+      F->hasFnAttribute(Attribute::SanitizeMemory) ||
+      F->hasFnAttribute(Attribute::SanitizeHWAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeMemTag))
+    return;
+
+  for (auto *BB : L->getBlocks()) {
+    // The execution count of a block must match that of the latch, so blocks
+    // that do not dominate it (like MemAccessBB below) are skipped:
+    //            Entry
+    //              │
+    //        ┌─────▼─────┐
+    //        │Loop Header◄─────┐
+    //        └──┬──────┬─┘     │
+    //           │      │       │
+    //  ┌────────▼──┐ ┌─▼─────┐ │
+    //  │MemAccessBB│ │OtherBB│ │
+    //  └────────┬──┘ └─┬─────┘ │
+    //           │      │       │
+    //         ┌─▼──────▼─┐     │
+    //         │Loop Latch├─────┘
+    //         └────┬─────┘
+    //              ▼
+    //             Exit
+    if (!DT.dominates(BB, LoopLatch))
+      continue;
+
+    for (Instruction &I : *BB) {
+      if (isa<LoadInst>(&I) || isa<StoreInst>(&I))
+        MemInsts.push_back(&I);
+    }
+  }
+}
+
+/// Return getObjectSize() of SCEVUnknown \p V as an \p RTy constant, or null.
+static const SCEV *getCertainSizeOfMem(const SCEV *V, Type *RTy,
+                                       const DataLayout &DL,
+                                       const TargetLibraryInfo &TLI,
+                                       ScalarEvolution *SE) {
+  const SCEVUnknown *PtrBase = dyn_cast<SCEVUnknown>(V);
+  if (!PtrBase)
+    return nullptr;
+  Value *Ptr = PtrBase->getValue();
+  uint64_t Size = 0;
+  if (!llvm::getObjectSize(Ptr, Size, DL, &TLI))
+    return nullptr;
+  return SE->getConstant(RTy, Size);
+}
+
+/// Return ceil((umax - umin) / Step) for \p AddRec, dividing as unsigned.
+static const SCEV *getIndexRange(const SCEVAddRecExpr *AddRec,
+                                 ScalarEvolution *SE) {
+  const SCEV *Range = SE->getConstant(SE->getUnsignedRangeMax(AddRec) -
+                                      SE->getUnsignedRangeMin(AddRec));
+  const SCEV *Step = AddRec->getStepRecurrence(*SE);
+  return SE->getUDivCeilSCEV(Range, Step);
+}
+
+/// Refine \p MaxExecCount with the value ranges of the GEP indices of \p Ptr.
+/// Only affine indices of the pointer's loop with a positive step and a
+/// constant range are used; returns SCEVCouldNotCompute if none qualifies.
+static const SCEV *checkIndexWrap(Value *Ptr, ScalarEvolution *SE,
+                                  const SCEVConstant *MaxExecCount) {
+  SmallVector<const SCEV *> InferCountColl;
+  auto *PtrGEP = dyn_cast<GetElementPtrInst>(Ptr);
+  auto *PtrAR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Ptr));
+  if (!PtrGEP || !PtrAR)
+    return SE->getCouldNotCompute();
+  for (Value *Index : PtrGEP->indices()) {
+    Value *V = Index;
+    if (isa<ZExtInst>(V) || isa<SExtInst>(V))
+      V = cast<Instruction>(Index)->getOperand(0);
+    auto *IdxSCEV = SE->getSCEV(V);
+    if (isa<SCEVCouldNotCompute>(IdxSCEV))
+      return SE->getCouldNotCompute();
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(IdxSCEV);
+    // An index evolving in another (outer) loop says nothing about this
+    // loop's trip count, and ceil(Range / Step) is only meaningful for a
+    // positive step (a negative one is a huge unsigned divisor).
+    if (!AddRec || !AddRec->isAffine() ||
+        AddRec->getLoop() != PtrAR->getLoop() ||
+        !SE->isKnownPositive(AddRec->getStepRecurrence(*SE)))
+      continue;
+    // Whether or not the index may self-wrap, the count is bounded by its
+    // range; symbolic ranges are skipped so that the result stays constant.
+    auto *IndexRange = dyn_cast<SCEVConstant>(getIndexRange(AddRec, SE));
+    if (!IndexRange)
+      continue;
+    InferCountColl.push_back(
+        SE->getUMinFromMismatchedTypes(IndexRange, MaxExecCount));
+  }
+
+  if (InferCountColl.empty())
+    return SE->getCouldNotCompute();
+
+  return SE->getUMinFromMismatchedTypes(InferCountColl);
+}
+
+const SCEV *
+ScalarEvolution::getConstantMaxTripCountFromMemAccess(const Loop *L) {
+  SmallVector<Instruction *, 4> MemInsts;
+  collectExecLoadStoreInsideLoop(L, DT, MemInsts);
+
+  SmallVector<const SCEV *> InferCountColl;
+  const DataLayout &DL = getDataLayout();
+
+  for (Instruction *I : MemInsts) {
+    Value *Ptr = getLoadStorePointerOperand(I);
+    assert(Ptr && "empty pointer operand");
+    // An AddRec of an outer loop is invariant in L and cannot bound L.
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(Ptr));
+    if (!AddRec || !AddRec->isAffine() || AddRec->getLoop() != L)
+      continue;
+    const SCEV *PtrBase = getPointerBase(AddRec);
+    const SCEV *Step = AddRec->getStepRecurrence(*this);
+    // MemSize /u Step only bounds the count for a positive step; a negative
+    // step is a huge unsigned divisor and would yield a bogus bound of 1.
+    if (!isKnownPositive(Step))
+      continue;
+    const SCEV *MemSize =
+        getCertainSizeOfMem(PtrBase, Step->getType(), DL, TLI, this);
+    if (!MemSize)
+      continue;
+    // At most ceil(MemSize / Step) accesses can stay inside the object.
+    auto *MaxExecCount = dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
+    if (!MaxExecCount || MaxExecCount->getAPInt().getActiveBits() > 32)
+      continue;
+    // Refine the bound with the ranges of the GEP indices; see checkIndexWrap.
+    auto *Res = checkIndexWrap(Ptr, this, MaxExecCount);
+    if (isa<SCEVCouldNotCompute>(Res))
+      continue;
+    InferCountColl.push_back(Res);
+  }
+
+  if (InferCountColl.empty())
+    return getCouldNotCompute();
+
+  return getUMinFromMismatchedTypes(InferCountColl);
+}
+
 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
                                           const BasicBlock *ExitingBlock,
                                           ExitCountKind Kind) {
@@ -13477,6 +13651,17 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
     OS << ": ";
     OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
   }
+
+  if (UseMemoryAccessUBForBEInference) {
+    unsigned SmallMaxTrip = SE->getSmallConstantMaxTripCount(L);
+    OS << "Loop ";
+    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+    OS << ": ";
+    if (SmallMaxTrip)
+      OS << "Small constant max trip is " << SmallMaxTrip << "\n";
+    else
+      OS << "Small constant max trip couldn't be computed.\n";
+  }
 }
 
 namespace llvm {
diff --git a/llvm/test/Analysis/ScalarEvolution/infer-trip-count-idx-wrap.ll b/llvm/test/Analysis/ScalarEvolution/infer-trip-count-idx-wrap.ll
new file mode 100644
index 00000000000000..341090e868cc7f
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/infer-trip-count-idx-wrap.ll
@@ -0,0 +1,110 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py
+; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" -scalar-evolution-classify-expressions=0 -scalar-evolution-infer-max-trip-count-from-memory-access 2>&1 | FileCheck %s
+
+define void @ComputeMaxTripCountFromArrayIdxWrap(i32 signext %len) {
+; CHECK-LABEL: 'ComputeMaxTripCountFromArrayIdxWrap'
+; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromArrayIdxWrap
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %len)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.body: Small constant max trip is 255
+;
+entry:
+  %a = alloca [256 x i32], align 4
+  %cmp4 = icmp sgt i32 %len, 0
+  br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:
+  br label %for.body
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %iv = phi i8 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+  %idxprom = zext i8 %iv to i64
+  %arrayidx = getelementptr inbounds [256 x i32], [256 x i32]* %a, i64 0, i64 %idxprom
+  store i32 0, i32* %arrayidx, align 4
+  %inc = add nuw i8 %iv, 1
+  %inc_zext = zext i8 %inc to i32
+  %cmp = icmp slt i32 %inc_zext, %len
+  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
+}
+
+define void @ComputeMaxTripCountFromArrayIdxWrap2(i32 signext %len) {
+; CHECK-LABEL: 'ComputeMaxTripCountFromArrayIdxWrap2'
+; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromArrayIdxWrap2
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %len)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.body: Small constant max trip is 127
+;
+entry:
+  %a = alloca [127 x i32], align 4
+  %cmp4 = icmp sgt i32 %len, 0
+  br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:
+  br label %for.body
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %iv = phi i8 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+  %idxprom = zext i8 %iv to i64
+  %arrayidx = getelementptr inbounds [127 x i32], [127 x i32]* %a, i64 0, i64 %idxprom
+  store i32 0, i32* %arrayidx, align 4
+  %inc = add nuw i8 %iv, 1
+  %inc_zext = zext i8 %inc to i32
+  %cmp = icmp slt i32 %inc_zext, %len
+  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
+}
+
+define void @ComputeMaxTripCountFromArrayIdxWrap3(i32 signext %len) {
+; CHECK-LABEL: 'ComputeMaxTripCountFromArrayIdxWrap3'
+; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromArrayIdxWrap3
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %len)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.body: Small constant max trip is 20
+;
+entry:
+  %a = alloca [20 x i32], align 4
+  %cmp4 = icmp sgt i32 %len, 0
+  br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:
+  br label %for.body
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %iv = phi i8 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+  %idxprom = zext i8 %iv to i64
+  %arrayidx = getelementptr inbounds [20 x i32], [20 x i32]* %a, i64 0, i64 %idxprom
+  store i32 0, i32* %arrayidx, align 4
+  %inc = add nuw nsw i8 %iv, 1
+  %inc_zext = zext i8 %inc to i32
+  %cmp = icmp slt i32 %inc_zext, %len
+  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
+}
diff --git a/llvm/test/Analysis/ScalarEvolution/infer-trip-count.ll b/llvm/test/Analysis/ScalarEvolution/infer-trip-count.ll
new file mode 100644
index 00000000000000..c8fe81a01b2c46
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/infer-trip-count.ll
@@ -0,0 +1,191 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py
+; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" -scalar-evolution-classify-expressions=0 -scalar-evolution-infer-max-trip-count-from-memory-access 2>&1 | FileCheck %s
+
+define void @ComputeMaxTripCountFromArrayNormal(i32 signext %len) {
+; CHECK-LABEL: 'ComputeMaxTripCountFromArrayNormal'
+; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromArrayNormal
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %len)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.body: Small constant max trip is 7
+;
+entry:
+  %a = alloca [7 x i32], align 4
+  %cmp4 = icmp sgt i32 %len, 0
+  br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:
+  br label %for.body
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %iv = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+  %idxprom = zext i32 %iv to i64
+  %arrayidx = getelementptr inbounds [7 x i32], [7 x i32]* %a, i64 0, i64 %idxprom
+  store i32 0, i32* %arrayidx, align 4
+  %inc = add nuw nsw i32 %iv, 1
+  %cmp = icmp slt i32 %inc, %len
+  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
+}
+
+
+define void @ComputeMaxTripCountFromZeroArray(i32 signext %len) {
+; CHECK-LABEL: 'ComputeMaxTripCountFromZeroArray'
+; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromZeroArray
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %len)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.body: Small constant max trip couldn't be computed.
+;
+entry:
+  %a = alloca [0 x i32], align 4
+  %cmp4 = icmp sgt i32 %len, 0
+  br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:
+  br label %for.body
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %iv = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+  %idxprom = zext i32 %iv to i64
+  %arrayidx = getelementptr inbounds [0 x i32], [0 x i32]* %a, i64 0, i64 %idxprom
+  store i32 0, i32* %arrayidx, align 4
+  %inc = add nuw nsw i32 %iv, 1
+  %cmp = icmp slt i32 %inc, %len
+  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
+}
+
+define void @ComputeMaxTripCountFromExtremArray(i32 signext %len) {
+; CHECK-LABEL: 'ComputeMaxTripCountFromExtremArray'
+; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromExtremArray
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %len)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.body: Small constant max trip is 2147483646
+;
+entry:
+  %a = alloca [4294967295 x i1], align 4
+  %cmp4 = icmp sgt i32 %len, 0
+  br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:
+  br label %for.body
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %iv = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+  %idxprom = zext i32 %iv to i64
+  %arrayidx = getelementptr inbounds [4294967295 x i1], [4294967295 x i1]* %a, i64 0, i64 %idxprom
+  store i1 0, i1* %arrayidx, align 4
+  %inc = add nuw nsw i32 %iv, 1
+  %cmp = icmp slt i32 %inc, %len
+  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
+}
+
+
+define void @ComputeMaxTripCountFromArrayInBranch(i32 signext %len) {
+; CHECK-LABEL: 'ComputeMaxTripCountFromArrayInBranch'
+; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromArrayInBranch
+; CHECK-NEXT:  Loop %for.cond: backedge-taken count is (0 smax %len)
+; CHECK-NEXT:  Loop %for.cond: constant max backedge-taken count is 2147483647
+; CHECK-NEXT:  Loop %for.cond: symbolic max backedge-taken count is (0 smax %len)
+; CHECK-NEXT:  Loop %for.cond: Predicated backedge-taken count is (0 smax %len)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:  Loop %for.cond: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.cond: Small constant max trip is 2147483648
+;
+entry:
+  %a = alloca [8 x i32], align 4
+  br label %for.cond
+
+for.cond:
+  %iv = phi i32 [ %inc, %for.inc ], [ 0, %entry ]
+  %cmp = icmp slt i32 %iv, %len
+  br i1 %cmp, label %for.body, label %for.cond.cleanup
+
+for.cond.cleanup:
+  br label %for.end
+
+for.body:
+  %cmp1 = icmp slt i32 %iv, 8
+  br i1 %cmp1, label %if.then, label %if.end
+
+if.then:
+  %idxprom = sext i32 %iv to i64
+  %arrayidx = getelementptr inbounds [8 x i32], [8 x i32]* %a, i64 0, i64 %idxprom
+  store i32 0, i32* %arrayidx, align 4
+  br label %if.end
+
+if.end:
+  br label %for.inc
+
+for.inc:
+  %inc = add nsw i32 %iv, 1
+  br label %for.cond
+
+for.end:
+  ret void
+}
+
+define void @ComputeMaxTripCountFromMultiDimArray(i32 signext %len) {
+; CHECK-LABEL: 'ComputeMaxTripCountFromMultiDimArray'
+; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromMultiDimArray
+; CHECK-NEXT:  Loop %for.cond: backedge-taken count is (0 smax %len)
+; CHECK-NEXT:  Loop %for.cond: constant max backedge-taken count is 2147483647
+; CHECK-NEXT:  Loop %for.cond: symbolic max backedge-taken count is (0 smax %len)
+; CHECK-NEXT:  Loop %for.cond: Predicated backedge-taken count is (0 smax %len)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:  Loop %for.cond: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.cond: Small constant max trip is 2147483648
+;
+entry:
+  %a = alloca [3 x [5 x i32]], align 4
+  br label %for.cond
+
+for.cond:
+  %iv = phi i32 [ %inc, %for.inc ], [ 0, %entry ]
+  %cmp = icmp slt i32 %iv, %len
+  br i1 %cmp, label %for.body, label %for.cond.cleanup
+
+for.cond.cleanup:
+  br label %for.end
+
+for.body:
+  %arrayidx = getelementptr inbounds [3 x [5 x i32]], [3 x [5 x i32]]* %a, i64 0, i64 3
+  %idxprom = sext i32 %iv to i64
+  %arrayidx1 = getelementptr inbounds [5 x i32], [5 x i32]* %arrayidx, i64 0, i64 %idxprom
+  store i32 0, i32* %arrayidx1, align 4
+  br label %for.inc
+
+for.inc:
+  %inc = add nsw i32 %iv, 1
+  br label %for.cond
+
+for.end:
+  ret void
+}

>From b773c5d8ca80fbbbf6ea8827e8988c0df81c40ea Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Mon, 18 Dec 2023 13:50:38 -0500
Subject: [PATCH 2/2] Rewrite in `computeExitLimit`

---
 llvm/include/llvm/Analysis/ScalarEvolution.h  | 11 ++-
 llvm/lib/Analysis/ScalarEvolution.cpp         | 80 +++++++++++++------
 .../infer-trip-count-idx-wrap.ll              | 12 +--
 .../ScalarEvolution/infer-trip-count.ll       | 10 +--
 4 files changed, 72 insertions(+), 41 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index c98705d60500d9..ffa9d62744c101 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -854,12 +854,6 @@ class ScalarEvolution {
   unsigned getSmallConstantTripMultiple(const Loop *L,
                                         const BasicBlock *ExitingBlock);
 
-  /// Return the upper bound of the loop trip count infered from memory access.
-  /// This can not access bytes starting outside the statically allocated size
-  /// without being immediate UB. Returns SCEVCouldNotCompute if the trip count
-  /// could not be inferred.
-  const SCEV *getConstantMaxTripCountFromMemAccess(const Loop *L);
-
   /// The terms "backedge taken count" and "exit count" are used
   /// interchangeably to refer to the number of times the backedge of a loop
   /// has executed before the loop is exited.
@@ -1159,6 +1153,8 @@ class ScalarEvolution {
                                      bool ExitIfTrue, bool ControlsOnlyExit,
                                      bool AllowPredicates = false);
 
+  ExitLimit computeExitLimitFromMemAccess(const Loop *L);
+
   /// A predicate is said to be monotonically increasing if may go from being
   /// false to being true as the loop iterates, but never the other way
   /// around.  A predicate is said to be monotonically decreasing if may go
@@ -1803,6 +1799,9 @@ class ScalarEvolution {
                                          Value *ExitCond, bool ExitIfTrue,
                                          bool ControlsOnlyExit,
                                          bool AllowPredicates);
+  ExitLimit computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
+                                                const Loop *L);
+  ExitLimit computeExitLimitFromMemAccessImpl(const Loop *L);
   std::optional<ScalarEvolution::ExitLimit> computeExitLimitFromCondFromBinOp(
       ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
       bool ControlsOnlyExit, bool AllowPredicates);
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 2076c007834db7..4fce8ac9b919b5 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8139,16 +8139,7 @@ ScalarEvolution::getSmallConstantTripCount(const Loop *L,
 unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
   const auto *MaxExitCount =
       dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
-  unsigned MaxExitCountN = getConstantTripCount(MaxExitCount);
-  if (UseMemoryAccessUBForBEInference) {
-    auto *MaxInferCount = getConstantMaxTripCountFromMemAccess(L);
-    if (auto *InferCount = dyn_cast<SCEVConstant>(MaxInferCount)) {
-      unsigned InferValue = InferCount->getValue()->getZExtValue();
-      MaxExitCountN =
-          MaxExitCountN == 0 ? InferValue : std::min(MaxExitCountN, InferValue);
-    }
-  }
-  return MaxExitCountN;
+  return getConstantTripCount(MaxExitCount);
 }
 
 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
@@ -8213,13 +8204,9 @@ collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
   if (!L->isLoopSimplifyForm() || !L->isInnermost())
     return;
 
-  // FIXME: To make the case more typical, we only analyze loops that have one
-  // exiting block and the block must be the latch. It is easier to capture
-  // loops with memory access that will be executed in every iteration.
   const BasicBlock *LoopLatch = L->getLoopLatch();
   assert(LoopLatch && "normal form loop doesn't have a latch");
-  if (L->getExitingBlock() != LoopLatch)
-    return;
+  assert(L->getExitingBlock() == LoopLatch);
 
   // We will not continue if sanitizer is enabled.
   const Function *F = LoopLatch->getParent();
@@ -8322,8 +8309,8 @@ static const SCEV *checkIndexWrap(Value *Ptr, ScalarEvolution *SE,
   return SE->getUMinFromMismatchedTypes(InferCountColl);
 }
 
-const SCEV *
-ScalarEvolution::getConstantMaxTripCountFromMemAccess(const Loop *L) {
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccessImpl(const Loop *L) {
   SmallVector<Instruction *, 4> MemInsts;
   collectExecLoadStoreInsideLoop(L, DT, MemInsts);
 
@@ -8361,7 +8348,27 @@ ScalarEvolution::getConstantMaxTripCountFromMemAccess(const Loop *L) {
   if (InferCountColl.empty())
     return getCouldNotCompute();
 
-  return getUMinFromMismatchedTypes(InferCountColl);
+  const SCEV *Count = getUMinFromMismatchedTypes(InferCountColl);
+
+  return {getCouldNotCompute(), Count, Count, /*MaxOrZero=*/false};
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
+                                                     const Loop *L) {
+  // Placeholder cache keys; the memory-access limit depends only on L.
+  constexpr Value *ExitCond = nullptr;
+  constexpr const bool ExitIfTrue = true;
+  constexpr const bool ControlsOnlyExit = true;
+  constexpr const bool AllowPredicates = true;
+
+  if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
+                                AllowPredicates))
+    return *MaybeEL;
+
+  ExitLimit EL = computeExitLimitFromMemAccessImpl(L);
+  Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
+  return EL;
 }
 
 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
@@ -8946,6 +8953,16 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
   if (!Latch || !DT.dominates(ExitingBlock, Latch))
     return getCouldNotCompute();
 
+  // FIXME: To make the case more typical, we only analyze loops that have one
+  // exiting block and the block must be the latch. It is easier to capture
+  // loops with memory access that will be executed in every iteration.
+  const SCEV *PotentiallyBetterConstantMax = getCouldNotCompute();
+  if (UseMemoryAccessUBForBEInference && Latch == L->getExitingBlock()) {
+    assert(Latch == ExitingBlock);
+    auto EL = computeExitLimitFromMemAccess(L);
+    PotentiallyBetterConstantMax = EL.ConstantMaxNotTaken;
+  }
+
   bool IsOnlyExit = (L->getExitingBlock() != nullptr);
   Instruction *Term = ExitingBlock->getTerminator();
   if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
@@ -8954,9 +8971,13 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
     assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
            "It should have one successor in loop and one exit block!");
     // Proceed to the next level to examine the exit condition expression.
-    return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
-                                    /*ControlsOnlyExit=*/IsOnlyExit,
-                                    AllowPredicates);
+    ExitLimit EL = computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
+                                            /*ControlsOnlyExit=*/IsOnlyExit,
+                                            AllowPredicates);
+    if (!isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
   }
 
   if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
@@ -8969,9 +8990,13 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
         Exit = SBB;
       }
     assert(Exit && "Exiting block must have at least one exit");
-    return computeExitLimitFromSingleExitSwitch(
-        L, SI, Exit,
-        /*ControlsOnlyExit=*/IsOnlyExit);
+    ExitLimit EL =
+        computeExitLimitFromSingleExitSwitch(L, SI, Exit,
+                                             /*ControlsOnlyExit=*/IsOnlyExit);
+    if (!isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
   }
 
   return getCouldNotCompute();
@@ -8985,6 +9010,13 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
                                         ControlsOnlyExit, AllowPredicates);
 }
 
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccess(const Loop *L) {
+  // A fresh local cache would be discarded on return and never hit, so the
+  // limit is computed directly.
+  return computeExitLimitFromMemAccessImpl(L);
+}
+
 std::optional<ScalarEvolution::ExitLimit>
 ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
                                       bool ExitIfTrue, bool ControlsOnlyExit,
diff --git a/llvm/test/Analysis/ScalarEvolution/infer-trip-count-idx-wrap.ll b/llvm/test/Analysis/ScalarEvolution/infer-trip-count-idx-wrap.ll
index 341090e868cc7f..2971cbd50d94a8 100644
--- a/llvm/test/Analysis/ScalarEvolution/infer-trip-count-idx-wrap.ll
+++ b/llvm/test/Analysis/ScalarEvolution/infer-trip-count-idx-wrap.ll
@@ -5,12 +5,12 @@ define void @ComputeMaxTripCountFromArrayIdxWrap(i32 signext %len) {
 ; CHECK-LABEL: 'ComputeMaxTripCountFromArrayIdxWrap'
 ; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromArrayIdxWrap
 ; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + %len)
-; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 255
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %len)
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
-; CHECK-NEXT:  Loop %for.body: Small constant max trip is 255
+; CHECK-NEXT:  Loop %for.body: Small constant max trip is 256
 ;
 entry:
   %a = alloca [256 x i32], align 4
@@ -41,12 +41,12 @@ define void @ComputeMaxTripCountFromArrayIdxWrap2(i32 signext %len) {
 ; CHECK-LABEL: 'ComputeMaxTripCountFromArrayIdxWrap2'
 ; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromArrayIdxWrap2
 ; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + %len)
-; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 127
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %len)
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
-; CHECK-NEXT:  Loop %for.body: Small constant max trip is 127
+; CHECK-NEXT:  Loop %for.body: Small constant max trip is 128
 ;
 entry:
   %a = alloca [127 x i32], align 4
@@ -77,12 +77,12 @@ define void @ComputeMaxTripCountFromArrayIdxWrap3(i32 signext %len) {
 ; CHECK-LABEL: 'ComputeMaxTripCountFromArrayIdxWrap3'
 ; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromArrayIdxWrap3
 ; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + %len)
-; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 20
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %len)
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
-; CHECK-NEXT:  Loop %for.body: Small constant max trip is 20
+; CHECK-NEXT:  Loop %for.body: Small constant max trip is 21
 ;
 entry:
   %a = alloca [20 x i32], align 4
diff --git a/llvm/test/Analysis/ScalarEvolution/infer-trip-count.ll b/llvm/test/Analysis/ScalarEvolution/infer-trip-count.ll
index c8fe81a01b2c46..7c52385adae754 100644
--- a/llvm/test/Analysis/ScalarEvolution/infer-trip-count.ll
+++ b/llvm/test/Analysis/ScalarEvolution/infer-trip-count.ll
@@ -5,12 +5,12 @@ define void @ComputeMaxTripCountFromArrayNormal(i32 signext %len) {
 ; CHECK-LABEL: 'ComputeMaxTripCountFromArrayNormal'
 ; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromArrayNormal
 ; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + %len)
-; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 7
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %len)
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
-; CHECK-NEXT:  Loop %for.body: Small constant max trip is 7
+; CHECK-NEXT:  Loop %for.body: Small constant max trip is 8
 ;
 entry:
   %a = alloca [7 x i32], align 4
@@ -41,12 +41,12 @@ define void @ComputeMaxTripCountFromZeroArray(i32 signext %len) {
 ; CHECK-LABEL: 'ComputeMaxTripCountFromZeroArray'
 ; CHECK-NEXT:  Determining loop execution counts for: @ComputeMaxTripCountFromZeroArray
 ; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + %len)
-; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 0
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %len)
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
-; CHECK-NEXT:  Loop %for.body: Small constant max trip couldn't be computed.
+; CHECK-NEXT:  Loop %for.body: Small constant max trip is 1
 ;
 entry:
   %a = alloca [0 x i32], align 4
@@ -81,7 +81,7 @@ define void @ComputeMaxTripCountFromExtremArray(i32 signext %len) {
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %len)
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
-; CHECK-NEXT:  Loop %for.body: Small constant max trip is 2147483646
+; CHECK-NEXT:  Loop %for.body: Small constant max trip is 2147483647
 ;
 entry:
   %a = alloca [4294967295 x i1], align 4



More information about the llvm-commits mailing list