[llvm] [SCEV] Use Step and Start to check if SCEVWrapPredicate is implied. (PR #118184)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 16 06:02:45 PST 2024
https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/118184
>From 632fe58f2b39972659878dd9bf2ac865f75a6880 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 29 Nov 2024 20:30:45 +0000
Subject: [PATCH 1/2] [SCEV] Use Step and Start to check if SCEVWrapPredicate
is implied.
A SCEVWrapPredicate A implies B if
* they have the same flag,
* both steps are positive and
* B's start and step are less than or equal to A's (ULE for NUSW, SLE for NSSW).
See https://alive2.llvm.org/ce/z/n2T4ss (first pair with known constants
as strides, second pair with variable strides).
Note that this is limited to steps of the same size, due to NUSW having
slightly different semantics from regular NUW. We should be able to
remove this restriction for NSSW (which matches NSW) in the future.
---
llvm/include/llvm/Analysis/ScalarEvolution.h | 13 +--
llvm/lib/Analysis/ScalarEvolution.cpp | 92 ++++++++++++++-----
.../memcheck-wrapping-pointers.ll | 12 +--
.../nssw-predicate-implied.ll | 6 +-
4 files changed, 84 insertions(+), 39 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index de74524c4b6fe4..7879622473ad8b 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -241,7 +241,7 @@ class SCEVPredicate : public FoldingSetNode {
virtual bool isAlwaysTrue() const = 0;
/// Returns true if this predicate implies \p N.
- virtual bool implies(const SCEVPredicate *N) const = 0;
+ virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const = 0;
/// Prints a textual representation of this predicate with an indentation of
/// \p Depth.
@@ -286,7 +286,7 @@ class SCEVComparePredicate final : public SCEVPredicate {
const SCEV *LHS, const SCEV *RHS);
/// Implementation of the SCEVPredicate interface
- bool implies(const SCEVPredicate *N) const override;
+ bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
void print(raw_ostream &OS, unsigned Depth = 0) const override;
bool isAlwaysTrue() const override;
@@ -393,7 +393,7 @@ class SCEVWrapPredicate final : public SCEVPredicate {
/// Implementation of the SCEVPredicate interface
const SCEVAddRecExpr *getExpr() const;
- bool implies(const SCEVPredicate *N) const override;
+ bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
void print(raw_ostream &OS, unsigned Depth = 0) const override;
bool isAlwaysTrue() const override;
@@ -418,16 +418,17 @@ class SCEVUnionPredicate final : public SCEVPredicate {
SmallVector<const SCEVPredicate *, 16> Preds;
/// Adds a predicate to this union.
- void add(const SCEVPredicate *N);
+ void add(const SCEVPredicate *N, ScalarEvolution &SE);
public:
- SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds);
+ SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
+ ScalarEvolution &SE);
ArrayRef<const SCEVPredicate *> getPredicates() const { return Preds; }
/// Implementation of the SCEVPredicate interface
bool isAlwaysTrue() const override;
- bool implies(const SCEVPredicate *N) const override;
+ bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
void print(raw_ostream &OS, unsigned Depth) const override;
/// We estimate the complexity of a union predicate as the size number of
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e18133971f5bf0..decf55003033c5 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
return true;
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
- if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
- !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
+ if (Expr1 != Expr2 &&
+ !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
+ !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
return false;
return true;
};
@@ -14857,7 +14858,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
bool addOverflowAssumption(const SCEVPredicate *P) {
if (!NewPreds) {
// Check if we've already made this assumption.
- return Pred && Pred->implies(P);
+ return Pred && Pred->implies(P, SE);
}
NewPreds->push_back(P);
return true;
@@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
assert(LHS != RHS && "LHS and RHS are the same SCEV");
}
-bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
+bool SCEVComparePredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVComparePredicate>(N);
if (!Op)
@@ -14968,10 +14970,52 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
-bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+ if (!Op)
+ return false;
+
+ if (setFlags(Flags, Op->Flags) != Flags)
+ return false;
+
+ if (Op->AR == AR)
+ return true;
+
+ if (Flags != SCEVWrapPredicate::IncrementNSSW &&
+ Flags != SCEVWrapPredicate::IncrementNUSW)
+ return false;
- return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+ bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
+
+ // If both steps are positive, this implies N, if N's start and step are
+ // ULE/SLE (for NSUW/NSSW) than this'.
+ if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) {
+ const SCEV *OpStart = Op->AR->getStart();
+ const SCEV *Start = AR->getStart();
+ if (SE.getTypeSizeInBits(Step->getType()) >
+ SE.getTypeSizeInBits(OpStep->getType())) {
+ OpStep = SE.getZeroExtendExpr(OpStep, Step->getType());
+ } else {
+ Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType())
+ : SE.getNoopOrSignExtend(Step, OpStep->getType());
+ }
+ if (SE.getTypeSizeInBits(Start->getType()) >
+ SE.getTypeSizeInBits(OpStart->getType())) {
+ OpStart = IsNUW ? SE.getZeroExtendExpr(OpStart, Start->getType())
+ : SE.getSignExtendExpr(OpStart, Start->getType());
+ } else {
+ Start = IsNUW ? SE.getNoopOrZeroExtend(Start, OpStart->getType())
+ : SE.getNoopOrSignExtend(Start, OpStart->getType());
+ }
+
+ CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
+ return SE.isKnownPredicate(Pred, OpStep, Step) &&
+ SE.isKnownPredicate(Pred, OpStart, Start);
+ }
+ return false;
}
bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -15015,10 +15059,11 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
}
/// Union predicates don't get cached so create a dummy set ID for it.
-SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
- : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
+SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
+ ScalarEvolution &SE)
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
for (const auto *P : Preds)
- add(P);
+ add(P, SE);
}
bool SCEVUnionPredicate::isAlwaysTrue() const {
@@ -15026,13 +15071,15 @@ bool SCEVUnionPredicate::isAlwaysTrue() const {
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
}
-bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
- return all_of(Set->Preds,
- [this](const SCEVPredicate *I) { return this->implies(I); });
+ return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
+ return this->implies(I, SE);
+ });
return any_of(Preds,
- [N](const SCEVPredicate *I) { return I->implies(N); });
+ [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
}
void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
@@ -15040,15 +15087,15 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
Pred->print(OS, Depth);
}
-void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
for (const auto *Pred : Set->Preds)
- add(Pred);
+ add(Pred, SE);
return;
}
// Only add predicate if it is not already implied by this union predicate.
- if (!implies(N))
+ if (!implies(N, SE))
Preds.push_back(N);
}
@@ -15056,7 +15103,7 @@ PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
Loop &L)
: SE(SE), L(L) {
SmallVector<const SCEVPredicate*, 4> Empty;
- Preds = std::make_unique<SCEVUnionPredicate>(Empty);
+ Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
}
void ScalarEvolution::registerUser(const SCEV *User,
@@ -15120,12 +15167,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
}
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
- if (Preds->implies(&Pred))
+ if (Preds->implies(&Pred, SE))
return;
SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
NewPreds.push_back(&Pred);
- Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
+ Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
updateGeneration();
}
@@ -15192,9 +15239,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
PredicatedScalarEvolution::PredicatedScalarEvolution(
const PredicatedScalarEvolution &Init)
- : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
- Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
- Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
+ Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
+ SE)),
+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
for (auto I : Init.FlagsMap)
FlagsMap.insert(I);
}
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll b/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll
index 6dbb4a0c0129a6..ae10ab841420fd 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll
@@ -29,20 +29,19 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
; CHECK-NEXT: Run-time memory checks:
; CHECK-NEXT: Check 0:
; CHECK-NEXT: Comparing group
-; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
-; CHECK-NEXT: Against group
; CHECK-NEXT: %arrayidx4 = getelementptr inbounds i32, ptr %b, i64 %conv11
+; CHECK-NEXT: Against group
+; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
; CHECK-NEXT: Grouped accesses:
; CHECK-NEXT: Group
-; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
-; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
-; CHECK-NEXT: Group
; CHECK-NEXT: (Low: %b High: ((4 * (1 umax %x)) + %b))
; CHECK-NEXT: Member: {%b,+,4}<%for.body>
+; CHECK-NEXT: Group
+; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
+; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
; CHECK: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
-; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
; CHECK: Expressions re-written:
; CHECK-NEXT: [PSE] %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom:
; CHECK-NEXT: ((4 * (zext i32 {1,+,1}<%for.body> to i64))<nuw><nsw> + %a)<nuw>
@@ -85,7 +84,6 @@ exit:
; CHECK: Memory dependences are safe
; CHECK: SCEV assumptions:
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
-; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
define void @test2(i64 %x, ptr %a) {
entry:
br label %for.body
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll b/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll
index 1a07805c2614f8..4f595b44ae5fde 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll
@@ -3,7 +3,7 @@
target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32"
-; FIXME: {0,+,3} implies {0,+,2}.
+; {0,+,3} [nssw] implies {0,+,2} [nssw]
define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2'
; CHECK-NEXT: loop:
@@ -26,7 +26,6 @@ define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
; CHECK-NEXT: {0,+,3}<%loop> Added Flags: <nssw>
-; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
; CHECK-EMPTY:
; CHECK-NEXT: Expressions re-written:
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:
@@ -59,7 +58,7 @@ exit:
ret void
}
-; FIXME: {2,+,2} implies {0,+,2}.
+; {2,+,2} [nssw] implies {0,+,2} [nssw].
define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %dst, ptr %src) {
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2_different_start'
; CHECK-NEXT: loop:
@@ -82,7 +81,6 @@ define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %d
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
; CHECK-NEXT: {2,+,2}<%loop> Added Flags: <nssw>
-; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
; CHECK-EMPTY:
; CHECK-NEXT: Expressions re-written:
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:
>From 4a10b2b932077fa519b2f7d38d6f0a9258bd7915 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 16 Dec 2024 14:00:28 +0000
Subject: [PATCH 2/2] !fixup address comments, reorder code
---
llvm/lib/Analysis/ScalarEvolution.cpp | 44 ++++++++++-----------------
1 file changed, 16 insertions(+), 28 deletions(-)
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index decf55003033c5..e2c2500052e7d6 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -14973,10 +14973,7 @@ const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
- if (!Op)
- return false;
-
- if (setFlags(Flags, Op->Flags) != Flags)
+ if (!Op || setFlags(Flags, Op->Flags) != Flags)
return false;
if (Op->AR == AR)
@@ -14986,36 +14983,27 @@ bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
Flags != SCEVWrapPredicate::IncrementNUSW)
return false;
- bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
const SCEV *Step = AR->getStepRecurrence(SE);
const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
+ if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
+ return false;
// If both steps are positive, this implies N, if N's start and step are
// ULE/SLE (for NSUW/NSSW) than this'.
- if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) {
- const SCEV *OpStart = Op->AR->getStart();
- const SCEV *Start = AR->getStart();
- if (SE.getTypeSizeInBits(Step->getType()) >
- SE.getTypeSizeInBits(OpStep->getType())) {
- OpStep = SE.getZeroExtendExpr(OpStep, Step->getType());
- } else {
- Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType())
- : SE.getNoopOrSignExtend(Step, OpStep->getType());
- }
- if (SE.getTypeSizeInBits(Start->getType()) >
- SE.getTypeSizeInBits(OpStart->getType())) {
- OpStart = IsNUW ? SE.getZeroExtendExpr(OpStart, Start->getType())
- : SE.getSignExtendExpr(OpStart, Start->getType());
- } else {
- Start = IsNUW ? SE.getNoopOrZeroExtend(Start, OpStart->getType())
- : SE.getNoopOrSignExtend(Start, OpStart->getType());
- }
+ Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
+ Step = SE.getNoopOrZeroExtend(Step, WiderTy);
+ OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
- CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
- return SE.isKnownPredicate(Pred, OpStep, Step) &&
- SE.isKnownPredicate(Pred, OpStart, Start);
- }
- return false;
+ bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
+ const SCEV *OpStart = Op->AR->getStart();
+ const SCEV *Start = AR->getStart();
+ OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
+ : SE.getNoopOrSignExtend(OpStart, WiderTy);
+ Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
+ : SE.getNoopOrSignExtend(Start, WiderTy);
+ CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
+ return SE.isKnownPredicate(Pred, OpStep, Step) &&
+ SE.isKnownPredicate(Pred, OpStart, Start);
}
bool SCEVWrapPredicate::isAlwaysTrue() const {
More information about the llvm-commits
mailing list