[llvm] [SCEV] Infer loop max trip count from memory accesses (PR #70361)
Shilei Tian via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 31 08:55:16 PDT 2023
================
@@ -8191,6 +8204,133 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
return getSmallConstantTripMultiple(L, ExitCount);
}
+/// Collect the load and store instructions of loop \p L that execute on every
+/// iteration reaching the latch, appending them to \p MemInsts.
+///
+/// Nothing is collected unless \p L is an innermost loop in loop-simplify form
+/// whose only exiting block is its latch, and its function has none of the
+/// sanitizer attributes tested below.
+static void
+collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
+                               SmallVectorImpl<Instruction *> &MemInsts) {
+  // It is difficult to tell whether a load/store executes on every iteration
+  // of an irregular loop, so only simplified innermost loops are considered.
+  if (!L->isLoopSimplifyForm() || !L->isInnermost())
+    return;
+
+  // FIXME: To keep the analysis simple, only loops whose single exiting block
+  // is the latch are handled. In that shape, any block that dominates the
+  // latch runs in full on every iteration that reaches the latch.
+  const BasicBlock *LoopLatch = L->getLoopLatch();
+  assert(LoopLatch && "normal form loop doesn't have a latch");
+  if (L->getExitingBlock() != LoopLatch)
+    return;
+
+  // Do not analyze sanitized functions. NOTE(review): presumably because a
+  // trip-count bound that assumes accesses stay in bounds could remove the
+  // out-of-bounds accesses the sanitizer is meant to report.
+  const Function *F = LoopLatch->getParent();
+  if (F->hasFnAttribute(Attribute::SanitizeAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeThread) ||
+      F->hasFnAttribute(Attribute::SanitizeMemory) ||
+      F->hasFnAttribute(Attribute::SanitizeHWAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeMemTag))
+    return;
+
+  for (auto *BB : L->getBlocks()) {
+    // Skip blocks that do not dominate the latch: they may be bypassed on some
+    // iterations, so their accesses do not bound how often the latch runs.
+    // MemAccessBB below is such a block:
+    //
+    //            Entry
+    //              |
+    //        +-----v-----+
+    //        |Loop Header|<--------+
+    //        +--+-----+--+         |
+    //           |     |            |
+    //  +--------v--+ +v------+     |
+    //  |MemAccessBB| |OtherBB|     |
+    //  +--------+--+ +-+-----+     |
+    //           |      |           |
+    //        +--v------v--+        |
+    //        | Loop Latch |--------+
+    //        +-----+------+
+    //              |
+    //              v
+    //            Exit
+    if (!DT.dominates(BB, LoopLatch))
+      continue;
+
+    for (Instruction &I : *BB) {
+      if (isa<LoadInst, StoreInst>(I))
+        MemInsts.push_back(&I);
+    }
+  }
+}
+
+/// Return a SCEV constant of type \p RTy holding the object size, as reported
+/// by getObjectSize, of the underlying object that \p V refers to.
+///
+/// Returns nullptr when \p V is not a SCEVUnknown, when getObjectSize cannot
+/// determine the size, or when the size does not fit in the width of \p RTy
+/// (a wrapped size would understate the object and could yield an unsound,
+/// too-small trip-count bound).
+static const SCEV *getCertainSizeOfMem(const SCEV *V, Type *RTy,
+                                       const DataLayout &DL,
+                                       const TargetLibraryInfo &TLI,
+                                       ScalarEvolution *SE) {
+  const auto *PtrBase = dyn_cast<SCEVUnknown>(V);
+  if (!PtrBase)
+    return nullptr;
+
+  uint64_t Size = 0;
+  if (!llvm::getObjectSize(PtrBase->getValue(), Size, DL, &TLI))
+    return nullptr;
+
+  // The constant is materialized at the width of RTy; refuse sizes that do
+  // not fit rather than letting them wrap to a smaller value.
+  uint64_t BitWidth = SE->getTypeSizeInBits(RTy);
+  if (BitWidth < 64 && (Size >> BitWidth) != 0)
+    return nullptr;
+  return SE->getConstant(RTy, Size);
+}
+
+/// Return true if \p Ptr is a GEP none of whose indices is known to self-wrap.
+///
+/// Each index is examined after looking through one zext/sext. An index
+/// whose SCEV is an n-ary expression must carry the no-self-wrap flag; any
+/// other SCEV is accepted. Returns false for non-GEP pointers.
+///
+/// NOTE(review): the no-self-wrap flag on the narrow, pre-extension value does
+/// not show that the index can reach the end of the accessed object before the
+/// narrow value runs out of range (e.g. an i8 IV indexing a 257-element
+/// array), so this check may be insufficient on its own.
+static bool checkIndexWrap(Value *Ptr, ScalarEvolution *SE) {
+  auto *PtrGEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!PtrGEP)
+    return false;
+
+  for (Value *Index : PtrGEP->indices()) {
+    // Look through a single integer extension to the narrower value.
+    Value *V = Index;
+    if (isa<ZExtInst, SExtInst>(V))
+      V = cast<Instruction>(V)->getOperand(0);
+
+    const SCEV *IndexExpr = SE->getSCEV(V);
+    if (isa<SCEVCouldNotCompute>(IndexExpr))
+      return false;
+
+    // Only n-ary expressions (add recurrences, adds, muls, min/max) carry
+    // wrap flags; anything else is accepted as-is.
+    const auto *AryExpr = dyn_cast<SCEVNAryExpr>(IndexExpr);
+    if (AryExpr && !AryExpr->hasNoSelfWrap())
+      return false;
+  }
+  return true;
+}
+
+const SCEV *
+ScalarEvolution::getConstantMaxTripCountFromMemAccess(const Loop *L) {
+ SmallVector<Instruction *, 4> MemInsts;
+ collectExecLoadStoreInsideLoop(L, DT, MemInsts);
+
+ SmallVector<const SCEV *> InferCountColl;
+ const DataLayout &DL = getDataLayout();
+
+ for (Instruction *I : MemInsts) {
+ Value *Ptr = getLoadStorePointerOperand(I);
+ assert(Ptr && "empty pointer operand");
+ if (!checkIndexWrap(Ptr, this))
+ continue;
+ auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(Ptr));
+ if (!AddRec || !AddRec->isAffine())
----------------
shiltian wrote:
Actually, I need to refine this part. Neither what we have now (`checkIndexWrap`) nor what you suggested works for the following case:
```
; %iv is an i8, so it only takes the values 0..255: the store never reaches
; element 256 of the 257-element array %a, and the array size alone cannot
; bound the trip count of this loop.
define void @ComputeMaxTripCountFromArrayIdxWrap(i32 signext %len) {
entry:
  %a = alloca [257 x i32], align 4
  %cmp4 = icmp sgt i32 %len, 0
  br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup

for.body.preheader:
  br label %for.body

for.cond.cleanup.loopexit:
  br label %for.cond.cleanup

for.cond.cleanup:
  ret void

for.body:
  %iv = phi i8 [ %inc, %for.body ], [ 0, %for.body.preheader ]
  %idxprom = zext i8 %iv to i64
  %arrayidx = getelementptr inbounds [257 x i32], ptr %a, i64 0, i64 %idxprom
  store i32 0, ptr %arrayidx, align 4
  %inc = add nuw nsw i8 %iv, 1
  %inc_zext = zext i8 %inc to i32
  %cmp = icmp slt i32 %inc_zext, %len
  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
}
```
Here `%iv` is an `i8`, so it only ranges over 0–255 and the store never reaches the last element of the 257-element array; the max trip count therefore cannot be deduced from this access. None of the wrap flags tells us that `%iv` can actually wrap. We will probably need to adopt the approach used in the original patch:
```
/// Return ceil((UMax - UMin) / Step) for the add recurrence \p V, where UMin
/// and UMax bound its unsigned range: the number of steps between the extreme
/// values it can take. For a unit step this is one less than the number of
/// distinct values, which keeps callers that compare against it conservative.
///
/// Returns SCEVCouldNotCompute unless \p V is an affine add recurrence with a
/// constant step, so the result is always computed from constants (callers
/// appear to read it as a SCEVConstant). NOTE(review): a negative step is
/// divided as a large unsigned value, giving a very small count.
static const SCEV *howManyItersSelfWrap(const SCEV *V, ScalarEvolution *SE) {
  const auto *AddRec = dyn_cast<SCEVAddRecExpr>(V);
  // Non-affine recurrences (whose step is itself a recurrence) and symbolic
  // steps would produce a non-constant count.
  if (!AddRec || !AddRec->isAffine())
    return SE->getCouldNotCompute();
  const SCEV *Step = AddRec->getStepRecurrence(*SE);
  if (!isa<SCEVConstant>(Step))
    return SE->getCouldNotCompute();

  const SCEV *CUpper = SE->getConstant(SE->getUnsignedRangeMax(V));
  const SCEV *CLower = SE->getConstant(SE->getUnsignedRangeMin(V));
  const SCEV *Limit = SE->getMinusSCEV(CUpper, CLower);
  return SE->getUDivCeilSCEV(Limit, Step);
}
/// For every index of the GEP \p Ptr (each looked through one zext/sext),
/// compute howManyItersSelfWrap and return the unsigned minimum of the counts
/// that could be computed, widening them to a common type as needed.
/// Returns SCEVCouldNotCompute when \p Ptr is not a GEP or when no index
/// yields a count.
static const SCEV *getSmallCountOfIdxSelfWrap(Value *Ptr, ScalarEvolution *SE) {
  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (GEP == nullptr)
    return SE->getCouldNotCompute();

  SmallVector<const SCEV *> WrapCounts;
  for (Value *Idx : GEP->indices()) {
    Value *Narrow = isa<ZExtInst, SExtInst>(Idx)
                        ? cast<Instruction>(Idx)->getOperand(0)
                        : Idx;
    const SCEV *Count = howManyItersSelfWrap(SE->getSCEV(Narrow), SE);
    if (isa<SCEVCouldNotCompute>(Count))
      continue;
    WrapCounts.push_back(Count);
  }

  return WrapCounts.empty() ? SE->getCouldNotCompute()
                            : SE->getUMinFromMismatchedTypes(WrapCounts);
}
...
ConstantInt *WrapVC = IdxWrapMap[Rec]->getValue();
ConstantInt *InferVC = InferCount->getValue();
if (InferVC->getValue().getZExtValue() > WrapVC->getValue().getZExtValue())
continue;
```
Or is there anything else that can detect the potential wrap here?
https://github.com/llvm/llvm-project/pull/70361
More information about the llvm-commits
mailing list