[llvm] 4812e9a - [SCEV] Preserve flags in SCEVLoopGuardRewriter for add and mul. (#91472)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 3 05:25:59 PDT 2024
Author: Florian Hahn
Date: 2024-06-03T13:25:55+01:00
New Revision: 4812e9a487735c8e2d86070f335b3364f9847711
URL: https://github.com/llvm/llvm-project/commit/4812e9a487735c8e2d86070f335b3364f9847711
DIFF: https://github.com/llvm/llvm-project/commit/4812e9a487735c8e2d86070f335b3364f9847711.diff
LOG: [SCEV] Preserve flags in SCEVLoopGuardRewriter for add and mul. (#91472)
SCEVLoopGuardRewriter only replaces operands with equivalent values, so
we should be able to transfer the flags from the original expression.
PR: https://github.com/llvm/llvm-project/pull/91472
Added:
Modified:
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/test/Analysis/ScalarEvolution/backedge-taken-count-guard-info.ll
llvm/test/Transforms/IndVarSimplify/trip-count-expansion-loop-guard-preserve-nsw.ll
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e46d7183a2a35..3b9aa9ab623f8 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15030,10 +15030,18 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
const DenseMap<const SCEV *, const SCEV *> &Map;
+ SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
+
public:
SCEVLoopGuardRewriter(ScalarEvolution &SE,
- DenseMap<const SCEV *, const SCEV *> &M)
- : SCEVRewriteVisitor(SE), Map(M) {}
+ DenseMap<const SCEV *, const SCEV *> &M,
+ bool PreserveNUW, bool PreserveNSW)
+ : SCEVRewriteVisitor(SE), Map(M) {
+ if (PreserveNUW)
+ FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
+ if (PreserveNSW)
+ FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
+ }
const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
@@ -15089,6 +15097,36 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
return I->second;
}
+
+ const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
+ SmallVector<const SCEV *, 2> Operands;
+ bool Changed = false;
+ for (const auto *Op : Expr->operands()) {
+ Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
+ Changed |= Op != Operands.back();
+ }
+ // We are only replacing operands with equivalent values, so transfer the
+ // flags from the original expression.
+ return !Changed
+ ? Expr
+ : SE.getAddExpr(Operands, ScalarEvolution::maskFlags(
+ Expr->getNoWrapFlags(), FlagMask));
+ }
+
+ const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
+ SmallVector<const SCEV *, 2> Operands;
+ bool Changed = false;
+ for (const auto *Op : Expr->operands()) {
+ Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
+ Changed |= Op != Operands.back();
+ }
+ // We are only replacing operands with equivalent values, so transfer the
+ // flags from the original expression.
+ return !Changed
+ ? Expr
+ : SE.getMulExpr(Operands, ScalarEvolution::maskFlags(
+ Expr->getNoWrapFlags(), FlagMask));
+ }
};
const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
@@ -15503,6 +15541,17 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
if (RewriteMap.empty())
return Expr;
+ // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
+ // the replacement expressions are contained in the ranges of the replaced
+ // expressions.
+ bool PreserveNUW = true;
+ bool PreserveNSW = true;
+ for (const SCEV *Expr : ExprsToRewrite) {
+ const SCEV *RewriteTo = RewriteMap[Expr];
+ PreserveNUW &= getUnsignedRange(Expr).contains(getUnsignedRange(RewriteTo));
+ PreserveNSW &= getSignedRange(Expr).contains(getSignedRange(RewriteTo));
+ }
+
// Now that all rewrite information is collect, rewrite the collected
// expressions with the information in the map. This applies information to
// sub-expressions.
@@ -15510,11 +15559,11 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
for (const SCEV *Expr : ExprsToRewrite) {
const SCEV *RewriteTo = RewriteMap[Expr];
RewriteMap.erase(Expr);
- SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
+ SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW,
+ PreserveNSW);
RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
}
}
-
- SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
+ SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW, PreserveNSW);
return Rewriter.visit(Expr);
}
diff --git a/llvm/test/Analysis/ScalarEvolution/backedge-taken-count-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/backedge-taken-count-guard-info.ll
index da4487ce9cd48..3c3748a6a5f02 100644
--- a/llvm/test/Analysis/ScalarEvolution/backedge-taken-count-guard-info.ll
+++ b/llvm/test/Analysis/ScalarEvolution/backedge-taken-count-guard-info.ll
@@ -77,13 +77,13 @@ define void @rewrite_preserve_add_nsw(i32 %a) {
; CHECK-NEXT: %add = add nsw i32 %a, 4
; CHECK-NEXT: --> (4 + %a)<nsw> U: [-2147483644,-2147483648) S: [-2147483644,-2147483648)
; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
-; CHECK-NEXT: --> {0,+,1}<nuw><%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (0 smax (4 + %a)<nsw>) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {0,+,1}<nuw><%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (4 + %a)<nsw> LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %iv.next = add i32 %iv, 1
-; CHECK-NEXT: --> {1,+,1}<nuw><%loop> U: [1,-2147483647) S: [1,-2147483647) Exits: (1 + (0 smax (4 + %a)<nsw>))<nuw> LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {1,+,1}<nuw><%loop> U: [1,-2147483647) S: [1,-2147483647) Exits: (5 + %a) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: Determining loop execution counts for: @rewrite_preserve_add_nsw
-; CHECK-NEXT: Loop %loop: backedge-taken count is (0 smax (4 + %a)<nsw>)
+; CHECK-NEXT: Loop %loop: backedge-taken count is (4 + %a)<nsw>
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i32 2147483647
-; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (0 smax (4 + %a)<nsw>)
+; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (4 + %a)<nsw>
; CHECK-NEXT: Loop %loop: Trip multiple is 1
;
entry:
diff --git a/llvm/test/Transforms/IndVarSimplify/trip-count-expansion-loop-guard-preserve-nsw.ll b/llvm/test/Transforms/IndVarSimplify/trip-count-expansion-loop-guard-preserve-nsw.ll
index f86639ea4c50f..945eb82e18595 100644
--- a/llvm/test/Transforms/IndVarSimplify/trip-count-expansion-loop-guard-preserve-nsw.ll
+++ b/llvm/test/Transforms/IndVarSimplify/trip-count-expansion-loop-guard-preserve-nsw.ll
@@ -12,8 +12,7 @@ define void @rewrite_preserve_add_nsw(i32 %a) {
; CHECK-NEXT: [[PRE:%.*]] = icmp sgt i32 [[A]], -4
; CHECK-NEXT: br i1 [[PRE]], label [[LOOP_PREHEADER:%.*]], label [[EXIT:%.*]]
; CHECK: loop.preheader:
-; CHECK-NEXT: [[SMAX:%.*]] = call i32 @llvm.smax.i32(i32 [[ADD]], i32 0)
-; CHECK-NEXT: [[TMP0:%.*]] = add nuw i32 [[SMAX]], 1
+; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[A]], 5
; CHECK-NEXT: br label [[LOOP:%.*]]
; CHECK: loop:
; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[IV_NEXT:%.*]], [[LOOP]] ], [ 0, [[LOOP_PREHEADER]] ]
More information about the llvm-commits mailing list