[llvm] [SCEV] Use Step and Start to check if SCEVWrapPredicate is implied. (PR #118184)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Nov 30 11:20:40 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: Florian Hahn (fhahn)
<details>
<summary>Changes</summary>
A SCEVWrapPredicate A implies B if
* they have the same flag,
* both steps are known to be positive, and
* B's start and step are ULE/SLE (for NUSW/NSSW, respectively) A's start and step.
See https://alive2.llvm.org/ce/z/n2T4ss (first pair with known constants as strides, second pair with variable strides).
Note that this is limited to steps of the same size, because NUSW has slightly different semantics from regular NUW. We should be able to lift this restriction for NSSW (which matches NSW) in the future.
---
Full diff: https://github.com/llvm/llvm-project/pull/118184.diff
4 Files Affected:
- (modified) llvm/include/llvm/Analysis/ScalarEvolution.h (+7-6)
- (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+70-22)
- (modified) llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll (+5-7)
- (modified) llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll (+2-4)
``````````diff
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 885c5985f9d23a..27df25cbf2b7d2 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -241,7 +241,7 @@ class SCEVPredicate : public FoldingSetNode {
virtual bool isAlwaysTrue() const = 0;
/// Returns true if this predicate implies \p N.
- virtual bool implies(const SCEVPredicate *N) const = 0;
+ virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const = 0;
/// Prints a textual representation of this predicate with an indentation of
/// \p Depth.
@@ -286,7 +286,7 @@ class SCEVComparePredicate final : public SCEVPredicate {
const SCEV *LHS, const SCEV *RHS);
/// Implementation of the SCEVPredicate interface
- bool implies(const SCEVPredicate *N) const override;
+ bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
void print(raw_ostream &OS, unsigned Depth = 0) const override;
bool isAlwaysTrue() const override;
@@ -393,7 +393,7 @@ class SCEVWrapPredicate final : public SCEVPredicate {
/// Implementation of the SCEVPredicate interface
const SCEVAddRecExpr *getExpr() const;
- bool implies(const SCEVPredicate *N) const override;
+ bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
void print(raw_ostream &OS, unsigned Depth = 0) const override;
bool isAlwaysTrue() const override;
@@ -418,16 +418,17 @@ class SCEVUnionPredicate final : public SCEVPredicate {
SmallVector<const SCEVPredicate *, 16> Preds;
/// Adds a predicate to this union.
- void add(const SCEVPredicate *N);
+ void add(const SCEVPredicate *N, ScalarEvolution &SE);
public:
- SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds);
+ SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
+ ScalarEvolution &SE);
ArrayRef<const SCEVPredicate *> getPredicates() const { return Preds; }
/// Implementation of the SCEVPredicate interface
bool isAlwaysTrue() const override;
- bool implies(const SCEVPredicate *N) const override;
+ bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
void print(raw_ostream &OS, unsigned Depth) const override;
/// We estimate the complexity of a union predicate as the size number of
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index c3f296b9ff3347..a98e53450b9520 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5725,8 +5725,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
return true;
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
- if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
- !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
+ if (Expr1 != Expr2 &&
+ !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
+ !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
return false;
return true;
};
@@ -14823,7 +14824,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
bool addOverflowAssumption(const SCEVPredicate *P) {
if (!NewPreds) {
// Check if we've already made this assumption.
- return Pred && Pred->implies(P);
+ return Pred && Pred->implies(P, SE);
}
NewPreds->push_back(P);
return true;
@@ -14904,7 +14905,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
assert(LHS != RHS && "LHS and RHS are the same SCEV");
}
-bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
+bool SCEVComparePredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVComparePredicate>(N);
if (!Op)
@@ -14934,10 +14936,52 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
-bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+ if (!Op)
+ return false;
+
+ if (setFlags(Flags, Op->Flags) != Flags)
+ return false;
+
+ if (Op->AR == AR)
+ return true;
+
+ if (Flags != SCEVWrapPredicate::IncrementNSSW &&
+ Flags != SCEVWrapPredicate::IncrementNUSW)
+ return false;
- return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+ bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
+
+ // If both steps are positive, this implies N, if N's start and step are
+ // ULE/SLE (for NSUW/NSSW) than this'.
+ if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) {
+ const SCEV *OpStart = Op->AR->getStart();
+ const SCEV *Start = AR->getStart();
+ if (SE.getTypeSizeInBits(Step->getType()) >
+ SE.getTypeSizeInBits(OpStep->getType())) {
+ OpStep = SE.getZeroExtendExpr(OpStep, Step->getType());
+ } else {
+ Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType())
+ : SE.getNoopOrSignExtend(Step, OpStep->getType());
+ }
+ if (SE.getTypeSizeInBits(Start->getType()) >
+ SE.getTypeSizeInBits(OpStart->getType())) {
+ OpStart = IsNUW ? SE.getZeroExtendExpr(OpStart, Start->getType())
+ : SE.getSignExtendExpr(OpStart, Start->getType());
+ } else {
+ Start = IsNUW ? SE.getNoopOrZeroExtend(Start, OpStart->getType())
+ : SE.getNoopOrSignExtend(Start, OpStart->getType());
+ }
+
+ CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
+ return SE.isKnownPredicate(Pred, OpStep, Step) &&
+ SE.isKnownPredicate(Pred, OpStart, Start);
+ }
+ return false;
}
bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -14981,10 +15025,11 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
}
/// Union predicates don't get cached so create a dummy set ID for it.
-SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
- : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
+SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
+ ScalarEvolution &SE)
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
for (const auto *P : Preds)
- add(P);
+ add(P, SE);
}
bool SCEVUnionPredicate::isAlwaysTrue() const {
@@ -14992,13 +15037,15 @@ bool SCEVUnionPredicate::isAlwaysTrue() const {
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
}
-bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
- return all_of(Set->Preds,
- [this](const SCEVPredicate *I) { return this->implies(I); });
+ return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
+ return this->implies(I, SE);
+ });
return any_of(Preds,
- [N](const SCEVPredicate *I) { return I->implies(N); });
+ [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
}
void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
@@ -15006,15 +15053,15 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
Pred->print(OS, Depth);
}
-void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
for (const auto *Pred : Set->Preds)
- add(Pred);
+ add(Pred, SE);
return;
}
// Only add predicate if it is not already implied by this union predicate.
- if (!implies(N))
+ if (!implies(N, SE))
Preds.push_back(N);
}
@@ -15022,7 +15069,7 @@ PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
Loop &L)
: SE(SE), L(L) {
SmallVector<const SCEVPredicate*, 4> Empty;
- Preds = std::make_unique<SCEVUnionPredicate>(Empty);
+ Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
}
void ScalarEvolution::registerUser(const SCEV *User,
@@ -15086,12 +15133,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
}
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
- if (Preds->implies(&Pred))
+ if (Preds->implies(&Pred, SE))
return;
SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
NewPreds.push_back(&Pred);
- Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
+ Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
updateGeneration();
}
@@ -15158,9 +15205,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
PredicatedScalarEvolution::PredicatedScalarEvolution(
const PredicatedScalarEvolution &Init)
- : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
- Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
- Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
+ Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
+ SE)),
+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
for (auto I : Init.FlagsMap)
FlagsMap.insert(I);
}
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll b/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll
index 6dbb4a0c0129a6..ae10ab841420fd 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll
@@ -29,20 +29,19 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
; CHECK-NEXT: Run-time memory checks:
; CHECK-NEXT: Check 0:
; CHECK-NEXT: Comparing group
-; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
-; CHECK-NEXT: Against group
; CHECK-NEXT: %arrayidx4 = getelementptr inbounds i32, ptr %b, i64 %conv11
+; CHECK-NEXT: Against group
+; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
; CHECK-NEXT: Grouped accesses:
; CHECK-NEXT: Group
-; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
-; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
-; CHECK-NEXT: Group
; CHECK-NEXT: (Low: %b High: ((4 * (1 umax %x)) + %b))
; CHECK-NEXT: Member: {%b,+,4}<%for.body>
+; CHECK-NEXT: Group
+; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
+; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
; CHECK: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
-; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
; CHECK: Expressions re-written:
; CHECK-NEXT: [PSE] %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom:
; CHECK-NEXT: ((4 * (zext i32 {1,+,1}<%for.body> to i64))<nuw><nsw> + %a)<nuw>
@@ -85,7 +84,6 @@ exit:
; CHECK: Memory dependences are safe
; CHECK: SCEV assumptions:
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
-; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
define void @test2(i64 %x, ptr %a) {
entry:
br label %for.body
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll b/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll
index 1a07805c2614f8..4f595b44ae5fde 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll
@@ -3,7 +3,7 @@
target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32"
-; FIXME: {0,+,3} implies {0,+,2}.
+; {0,+,3} [nssw] implies {0,+,2} [nssw]
define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2'
; CHECK-NEXT: loop:
@@ -26,7 +26,6 @@ define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
; CHECK-NEXT: {0,+,3}<%loop> Added Flags: <nssw>
-; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
; CHECK-EMPTY:
; CHECK-NEXT: Expressions re-written:
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:
@@ -59,7 +58,7 @@ exit:
ret void
}
-; FIXME: {2,+,2} implies {0,+,2}.
+; {2,+,2} [nssw] implies {0,+,2} [nssw].
define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %dst, ptr %src) {
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2_different_start'
; CHECK-NEXT: loop:
@@ -82,7 +81,6 @@ define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %d
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
; CHECK-NEXT: {2,+,2}<%loop> Added Flags: <nssw>
-; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
; CHECK-EMPTY:
; CHECK-NEXT: Expressions re-written:
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:
``````````
</details>
https://github.com/llvm/llvm-project/pull/118184
More information about the llvm-commits
mailing list