[llvm] 733b8b2 - [LAA] Simplify identification of speculatable strides [nfc]
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Thu May 11 11:48:50 PDT 2023
Author: Philip Reames
Date: 2023-05-11T11:48:21-07:00
New Revision: 733b8b2b4919e76e1768919461fd5eb211e41d79
URL: https://github.com/llvm/llvm-project/commit/733b8b2b4919e76e1768919461fd5eb211e41d79
DIFF: https://github.com/llvm/llvm-project/commit/733b8b2b4919e76e1768919461fd5eb211e41d79.diff
LOG: [LAA] Simplify identification of speculatable strides [nfc]
This mostly avoids the need to keep both Values and SCEVs flowing through with consistent handling. We can do everything in terms of SCEV, aside from the profitability heuristics, which are now isolated in one spot.
Added:
Modified:
llvm/lib/Analysis/LoopAccessAnalysis.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 358f97f83d40..5a0b1abe96d9 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -2610,7 +2610,7 @@ static Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
/// Get the stride of a pointer access in a loop. Looks for symbolic
/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
-static Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
+static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
if (!PtrTy || PtrTy->isAggregateType())
return nullptr;
@@ -2664,28 +2664,27 @@ static Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
}
}
- // Strip off casts.
- Type *StripedOffRecurrenceCast = nullptr;
- if (const SCEVIntegralCastExpr *C = dyn_cast<SCEVIntegralCastExpr>(V)) {
- StripedOffRecurrenceCast = C->getType();
- V = C->getOperand();
- }
+ // Note that the restrictions after this loop invariant check are only
+ // profitability restrictions.
+ if (!SE->isLoopInvariant(V, Lp))
+ return nullptr;
// Look for the loop invariant symbolic value.
const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
- if (!U)
- return nullptr;
+ if (!U) {
+ const auto *C = dyn_cast<SCEVIntegralCastExpr>(V);
+ if (!C)
+ return nullptr;
+ U = dyn_cast<SCEVUnknown>(C->getOperand());
+ if (!U)
+ return nullptr;
- Value *Stride = U->getValue();
- if (!Lp->isLoopInvariant(Stride))
- return nullptr;
-
- // If we have stripped off the recurrence cast we have to make sure that we
- // return the value that is used in this loop so that we can replace it later.
- if (StripedOffRecurrenceCast)
- Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast);
+ // Match legacy behavior - this is not needed for correctness
+ if (!getUniqueCastUse(U->getValue(), Lp, V->getType()))
+ return nullptr;
+ }
- return Stride;
+ return V;
}
void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
@@ -2699,13 +2698,13 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
// computation of an interesting IV - but we chose not to as we
// don't have a cost model here, and broadening the scope exposes
// far too many unprofitable cases.
- Value *Stride = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
- if (!Stride)
+ const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
+ if (!StrideExpr)
return;
LLVM_DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for "
"versioning:");
- LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *Stride << "\n");
+ LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *StrideExpr << "\n");
if (!SpeculateUnitStride) {
LLVM_DEBUG(dbgs() << " Chose not to due to -laa-speculate-unit-stride\n");
@@ -2725,7 +2724,6 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
// of various possible stride specializations, considering the alternatives
// of using gather/scatters (if available).
- const SCEV *StrideExpr = PSE->getSCEV(Stride);
const SCEV *BETakenCount = PSE->getBackedgeTakenCount();
// Match the types so we can compare the stride and the BETakenCount.
@@ -2756,8 +2754,10 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
// Strip back off the integer cast, and check that our result is a
// SCEVUnknown as we expect.
- Value *StrideVal = stripIntegerCast(Stride);
- SymbolicStrides[Ptr] = cast<SCEVUnknown>(PSE->getSCEV(StrideVal));
+ const SCEV *StrideBase = StrideExpr;
+ if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
+ StrideBase = C->getOperand();
+ SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase);
}
LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
More information about the llvm-commits
mailing list