[llvm] [SCEV] Preserve flags in SCEVLoopGuardRewriter for add and mul. (PR #91472)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Fri May 10 11:23:49 PDT 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/91472

>From 90a28608da5e74c17ea1d3f2982429c35a1fe76d Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 8 May 2024 14:41:20 +0100
Subject: [PATCH 1/2] [SCEV] Preserve flags in SCEVLoopGuardRewriter for add
 and mul.

SCEVLoopGuardRewriter only replaces operands with equivalent values, so we
should be able to transfer the flags from the original expression.
---
 llvm/lib/Analysis/ScalarEvolution.cpp         | 24 +++++++++++++++++++
 .../backedge-taken-count-guard-info.ll        |  8 +++----
 ...count-expansion-loop-guard-preserve-nsw.ll |  3 +--
 3 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 93f885c5d5ad8..3666a7d4b9b22 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15034,6 +15034,30 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
     return I->second;
   }
+
+  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
+    SmallVector<const SCEV *, 2> Operands;
+    bool Changed = false;
+    for (const auto *Op : Expr->operands()) {
+      Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
+      Changed |= Op != Operands.back();
+    }
+    // We are only replacing operands with equivalent values, so transfer the
+    // flags from the original expression.
+    return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
+  }
+
+  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
+    SmallVector<const SCEV *, 2> Operands;
+    bool Changed = false;
+    for (const auto *Op : Expr->operands()) {
+      Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
+      Changed |= Op != Operands.back();
+    }
+    // We are only replacing operands with equivalent values, so transfer the
+    // flags from the original expression.
+    return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
+  }
 };
 
 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
diff --git a/llvm/test/Analysis/ScalarEvolution/backedge-taken-count-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/backedge-taken-count-guard-info.ll
index da4487ce9cd48..3c3748a6a5f02 100644
--- a/llvm/test/Analysis/ScalarEvolution/backedge-taken-count-guard-info.ll
+++ b/llvm/test/Analysis/ScalarEvolution/backedge-taken-count-guard-info.ll
@@ -77,13 +77,13 @@ define void @rewrite_preserve_add_nsw(i32 %a) {
 ; CHECK-NEXT:    %add = add nsw i32 %a, 4
 ; CHECK-NEXT:    --> (4 + %a)<nsw> U: [-2147483644,-2147483648) S: [-2147483644,-2147483648)
 ; CHECK-NEXT:    %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
-; CHECK-NEXT:    --> {0,+,1}<nuw><%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (0 smax (4 + %a)<nsw>) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT:    --> {0,+,1}<nuw><%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (4 + %a)<nsw> LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT:    %iv.next = add i32 %iv, 1
-; CHECK-NEXT:    --> {1,+,1}<nuw><%loop> U: [1,-2147483647) S: [1,-2147483647) Exits: (1 + (0 smax (4 + %a)<nsw>))<nuw> LoopDispositions: { %loop: Computable }
+; CHECK-NEXT:    --> {1,+,1}<nuw><%loop> U: [1,-2147483647) S: [1,-2147483647) Exits: (5 + %a) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT:  Determining loop execution counts for: @rewrite_preserve_add_nsw
-; CHECK-NEXT:  Loop %loop: backedge-taken count is (0 smax (4 + %a)<nsw>)
+; CHECK-NEXT:  Loop %loop: backedge-taken count is (4 + %a)<nsw>
 ; CHECK-NEXT:  Loop %loop: constant max backedge-taken count is i32 2147483647
-; CHECK-NEXT:  Loop %loop: symbolic max backedge-taken count is (0 smax (4 + %a)<nsw>)
+; CHECK-NEXT:  Loop %loop: symbolic max backedge-taken count is (4 + %a)<nsw>
 ; CHECK-NEXT:  Loop %loop: Trip multiple is 1
 ;
 entry:
diff --git a/llvm/test/Transforms/IndVarSimplify/trip-count-expansion-loop-guard-preserve-nsw.ll b/llvm/test/Transforms/IndVarSimplify/trip-count-expansion-loop-guard-preserve-nsw.ll
index f86639ea4c50f..945eb82e18595 100644
--- a/llvm/test/Transforms/IndVarSimplify/trip-count-expansion-loop-guard-preserve-nsw.ll
+++ b/llvm/test/Transforms/IndVarSimplify/trip-count-expansion-loop-guard-preserve-nsw.ll
@@ -12,8 +12,7 @@ define void @rewrite_preserve_add_nsw(i32 %a) {
 ; CHECK-NEXT:    [[PRE:%.*]] = icmp sgt i32 [[A]], -4
 ; CHECK-NEXT:    br i1 [[PRE]], label [[LOOP_PREHEADER:%.*]], label [[EXIT:%.*]]
 ; CHECK:       loop.preheader:
-; CHECK-NEXT:    [[SMAX:%.*]] = call i32 @llvm.smax.i32(i32 [[ADD]], i32 0)
-; CHECK-NEXT:    [[TMP0:%.*]] = add nuw i32 [[SMAX]], 1
+; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[A]], 5
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
 ; CHECK-NEXT:    [[IV:%.*]] = phi i32 [ [[IV_NEXT:%.*]], [[LOOP]] ], [ 0, [[LOOP_PREHEADER]] ]

>From 339f8b927c9583bd12e1b4fa18600d88e02ad0e4 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 10 May 2024 19:22:43 +0100
Subject: [PATCH 2/2] !fixup Preserve NUW/NSW flags only if the replacement ranges are contained.

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 39 ++++++++++++++++++++++-----
 1 file changed, 32 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 8e497460d4dfc..03b79b4d92bcf 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -14973,10 +14973,18 @@ ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
 class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
   const DenseMap<const SCEV *, const SCEV *> &Map;
 
+  SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
+
 public:
   SCEVLoopGuardRewriter(ScalarEvolution &SE,
-                        DenseMap<const SCEV *, const SCEV *> &M)
-      : SCEVRewriteVisitor(SE), Map(M) {}
+                        DenseMap<const SCEV *, const SCEV *> &M,
+                        bool PreserveNUW, bool PreserveNSW)
+      : SCEVRewriteVisitor(SE), Map(M) {
+    if (PreserveNUW)
+      FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
+    if (PreserveNSW)
+      FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
+  }
 
   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
 
@@ -15042,7 +15050,10 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
     }
     // We are only replacing operands with equivalent values, so transfer the
     // flags from the original expression.
-    return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
+    return !Changed
+               ? Expr
+               : SE.getAddExpr(Operands, ScalarEvolution::maskFlags(
+                                             Expr->getNoWrapFlags(), FlagMask));
   }
 
   const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
@@ -15054,7 +15065,10 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
     }
     // We are only replacing operands with equivalent values, so transfer the
     // flags from the original expression.
-    return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
+    return !Changed
+               ? Expr
+               : SE.getMulExpr(Operands, ScalarEvolution::maskFlags(
+                                             Expr->getNoWrapFlags(), FlagMask));
   }
 };
 
@@ -15470,6 +15484,17 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
   if (RewriteMap.empty())
     return Expr;
 
+  // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
+  // the replacement expressions are contained in the ranges of the replaced
+  // expressions.
+  bool PreserveNUW = true;
+  bool PreserveNSW = true;
+  for (const SCEV *Expr : ExprsToRewrite) {
+    const SCEV *RewriteTo = RewriteMap[Expr];
+    PreserveNUW &= getUnsignedRange(Expr).contains(getUnsignedRange(RewriteTo));
+    PreserveNSW &= getSignedRange(Expr).contains(getSignedRange(RewriteTo));
+  }
+
   // Now that all rewrite information is collect, rewrite the collected
   // expressions with the information in the map. This applies information to
   // sub-expressions.
@@ -15477,11 +15502,11 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
     for (const SCEV *Expr : ExprsToRewrite) {
       const SCEV *RewriteTo = RewriteMap[Expr];
       RewriteMap.erase(Expr);
-      SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
+      SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW,
+                                     PreserveNSW);
       RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
     }
   }
-
-  SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
+  SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW, PreserveNSW);
   return Rewriter.visit(Expr);
 }



More information about the llvm-commits mailing list