[llvm] f6d110e - [LAA] Make getPtrStride return Option instead of overloading zero as error value [nfc]
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 27 15:57:26 PDT 2022
Author: Philip Reames
Date: 2022-09-27T15:55:44-07:00
New Revision: f6d110e26f1bdd3b4462f7fda620e07c425ccf76
URL: https://github.com/llvm/llvm-project/commit/f6d110e26f1bdd3b4462f7fda620e07c425ccf76
DIFF: https://github.com/llvm/llvm-project/commit/f6d110e26f1bdd3b4462f7fda620e07c425ccf76.diff
LOG: [LAA] Make getPtrStride return Option instead of overloading zero as error value [nfc]
This is a purely NFC restructure in advance of a change which actually exposes zero strides. This is mostly because I find this interface confusing each time I look at it.
Added:
Modified:
llvm/include/llvm/Analysis/LoopAccessAnalysis.h
llvm/lib/Analysis/LoopAccessAnalysis.cpp
llvm/lib/Analysis/VectorUtils.cpp
llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
index 4ae55b80294e7..82a4e19a8247c 100644
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -720,7 +720,7 @@ const SCEV *replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
Value *Ptr);
/// If the pointer has a constant stride return it in units of the access type
-/// size. Otherwise return zero.
+/// size. Otherwise return None.
///
/// Ensure that it does not wrap in the address space, assuming the predicate
/// associated with \p PSE is true.
@@ -729,10 +729,11 @@ const SCEV *replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
/// to \p PtrToStride and therefore add further predicates to \p PSE.
/// The \p Assume parameter indicates if we are allowed to make additional
/// run-time assumptions.
-int64_t getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr,
- const Loop *Lp,
- const ValueToValueMap &StridesMap = ValueToValueMap(),
- bool Assume = false, bool ShouldCheckWrap = true);
+Optional<int64_t>
+getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr,
+ const Loop *Lp,
+ const ValueToValueMap &StridesMap = ValueToValueMap(),
+ bool Assume = false, bool ShouldCheckWrap = true);
/// Returns the distance between the pointers \p PtrA and \p PtrB iff they are
/// compatible and it is possible to calculate the distance between them. This
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index f06e412352859..8dec0533d51b3 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -758,7 +758,7 @@ static bool isNoWrap(PredicatedScalarEvolution &PSE,
if (PSE.getSE()->isLoopInvariant(PtrScev, L))
return true;
- int64_t Stride = getPtrStride(PSE, AccessTy, Ptr, L, Strides);
+ int64_t Stride = getPtrStride(PSE, AccessTy, Ptr, L, Strides).value_or(0);
if (Stride == 1 || PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW))
return true;
@@ -1365,17 +1365,18 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR,
}
/// Check whether the access through \p Ptr has a constant stride.
-int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy,
- Value *Ptr, const Loop *Lp,
- const ValueToValueMap &StridesMap, bool Assume,
- bool ShouldCheckWrap) {
+Optional<int64_t>
+llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy,
+ Value *Ptr, const Loop *Lp,
+ const ValueToValueMap &StridesMap, bool Assume,
+ bool ShouldCheckWrap) {
Type *Ty = Ptr->getType();
assert(Ty->isPointerTy() && "Unexpected non-ptr");
if (isa<ScalableVectorType>(AccessTy)) {
LLVM_DEBUG(dbgs() << "LAA: Bad stride - Scalable object: " << *AccessTy
<< "\n");
- return 0;
+ return None;
}
const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
@@ -1387,14 +1388,14 @@ int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy,
if (!AR) {
LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " << *Ptr
<< " SCEV: " << *PtrScev << "\n");
- return 0;
+ return None;
}
// The access function must stride over the innermost loop.
if (Lp != AR->getLoop()) {
LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not striding over innermost loop "
<< *Ptr << " SCEV: " << *AR << "\n");
- return 0;
+ return None;
}
// The address calculation must not wrap. Otherwise, a dependence could be
@@ -1422,7 +1423,7 @@ int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy,
LLVM_DEBUG(
dbgs() << "LAA: Bad stride - Pointer may wrap in the address space "
<< *Ptr << " SCEV: " << *AR << "\n");
- return 0;
+ return None;
}
}
@@ -1434,7 +1435,7 @@ int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy,
if (!C) {
LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not a constant strided " << *Ptr
<< " SCEV: " << *AR << "\n");
- return 0;
+ return None;
}
auto &DL = Lp->getHeader()->getModule()->getDataLayout();
@@ -1444,7 +1445,7 @@ int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy,
// Huge step value - give up.
if (APStepVal.getBitWidth() > 64)
- return 0;
+ return None;
int64_t StepVal = APStepVal.getSExtValue();
@@ -1452,7 +1453,7 @@ int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy,
int64_t Stride = StepVal / Size;
int64_t Rem = StepVal % Size;
if (Rem)
- return 0;
+ return None;
// If the SCEV could wrap but we have an inbounds gep with a unit stride we
// know we can't "wrap around the address space". In case of address space
@@ -1469,7 +1470,7 @@ int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy,
<< "LAA: Added an overflow assumption\n");
PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
} else
- return 0;
+ return None;
}
return Stride;
@@ -1846,9 +1847,9 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
return Dependence::Unknown;
int64_t StrideAPtr =
- getPtrStride(PSE, ATy, APtr, InnermostLoop, Strides, true);
+ getPtrStride(PSE, ATy, APtr, InnermostLoop, Strides, true).value_or(0);
int64_t StrideBPtr =
- getPtrStride(PSE, BTy, BPtr, InnermostLoop, Strides, true);
+ getPtrStride(PSE, BTy, BPtr, InnermostLoop, Strides, true).value_or(0);
const SCEV *Src = PSE.getSCEV(APtr);
const SCEV *Sink = PSE.getSCEV(BPtr);
@@ -2350,7 +2351,7 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI,
bool IsReadOnlyPtr = false;
Type *AccessTy = getLoadStoreType(LD);
if (Seen.insert({Ptr, AccessTy}).second ||
- !getPtrStride(*PSE, LD->getType(), Ptr, TheLoop, SymbolicStrides)) {
+ !getPtrStride(*PSE, LD->getType(), Ptr, TheLoop, SymbolicStrides).value_or(0)) {
++NumReads;
IsReadOnlyPtr = true;
}
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 57a373056b2b9..b4398170a34c5 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1117,8 +1117,9 @@ void InterleavedAccessInfo::collectConstStrideAccesses(
// wrap around the address space we would do a memory access at nullptr
// even without the transformation. The wrapping checks are therefore
// deferred until after we've formed the interleaved groups.
- int64_t Stride = getPtrStride(PSE, ElementTy, Ptr, TheLoop, Strides,
- /*Assume=*/true, /*ShouldCheckWrap=*/false);
+ int64_t Stride =
+ getPtrStride(PSE, ElementTy, Ptr, TheLoop, Strides,
+ /*Assume=*/true, /*ShouldCheckWrap=*/false).value_or(0);
const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
uint64_t Size = DL.getTypeAllocSize(ElementTy);
@@ -1338,7 +1339,7 @@ void InterleavedAccessInfo::analyzeInterleaving(
Value *MemberPtr = getLoadStorePointerOperand(Member);
Type *AccessTy = getLoadStoreType(Member);
if (getPtrStride(PSE, AccessTy, MemberPtr, TheLoop, Strides,
- /*Assume=*/false, /*ShouldCheckWrap=*/true))
+ /*Assume=*/false, /*ShouldCheckWrap=*/true).value_or(0))
return false;
LLVM_DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "
<< FirstOrLast
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 1d6e29510950c..6ea6073909e77 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -2190,7 +2190,7 @@ static bool canTailPredicateLoop(Loop *L, LoopInfo *LI, ScalarEvolution &SE,
if (isa<StoreInst>(I) || isa<LoadInst>(I)) {
Value *Ptr = getLoadStorePointerOperand(&I);
Type *AccessTy = getLoadStoreType(&I);
- int64_t NextStride = getPtrStride(PSE, AccessTy, Ptr, L);
+ int64_t NextStride = getPtrStride(PSE, AccessTy, Ptr, L).value_or(0);
if (NextStride == 1) {
// TODO: for now only allow consecutive strides of 1. We could support
// other strides as long as it is uniform, but let's keep it simple
diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
index 13049c701e68a..8db86e5fbd1da 100644
--- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
@@ -109,8 +109,8 @@ struct StoreToLoadForwardingCandidate {
// Currently we only support accesses with unit stride. FIXME: we should be
// able to handle non unit stirde as well as long as the stride is equal to
// the dependence distance.
- if (getPtrStride(PSE, LoadType, LoadPtr, L) != 1 ||
- getPtrStride(PSE, LoadType, StorePtr, L) != 1)
+ if (getPtrStride(PSE, LoadType, LoadPtr, L).value_or(0) != 1 ||
+ getPtrStride(PSE, LoadType, StorePtr, L).value_or(0) != 1)
return false;
unsigned TypeByteSize = DL.getTypeAllocSize(const_cast<Type *>(LoadType));
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index e1d5dc735203d..b3bea8c5be7e8 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -456,7 +456,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Type *AccessTy,
PGSOQueryType::IRPass);
bool CanAddPredicate = !OptForSize;
int Stride = getPtrStride(PSE, AccessTy, Ptr, TheLoop, Strides,
- CanAddPredicate, false);
+ CanAddPredicate, false).value_or(0);
if (Stride == 1 || Stride == -1)
return Stride;
return 0;
More information about the llvm-commits
mailing list