[llvm] b7aec4f - [SCEV] Support rewriting ZExt expressions with loop guard info.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 16 03:16:24 PST 2021


Author: Florian Hahn
Date: 2021-11-16T11:16:07Z
New Revision: b7aec4f08e5e231cda3366af961e227e18c5587b

URL: https://github.com/llvm/llvm-project/commit/b7aec4f08e5e231cda3366af961e227e18c5587b
DIFF: https://github.com/llvm/llvm-project/commit/b7aec4f08e5e231cda3366af961e227e18c5587b.diff

LOG: [SCEV] Support rewriting ZExt expressions with loop guard info.

So far, applying loop guard information has been restricted to
SCEVUnknown. In a few cases, like PR40961 and PR52464, this leads to
SCEV failing to determine tight upper bounds for the backedge taken
count.

This patch adjusts SCEVLoopGuardRewriter and applyLoopGuards to support
re-writing ZExt expressions.

This is a first step towards fixing PR40961 and PR52464.

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D113577

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 5bdbe4029eb42..e8556c2094353 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13694,7 +13694,8 @@ ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
 /// in the map. It skips AddRecExpr because we cannot guarantee that the
 /// replacement is loop invariant in the loop of the AddRec.
 ///
-/// At the moment only rewriting SCEVUnknown is supported.
+/// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is
+/// supported.
 class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
   const DenseMap<const SCEV *, const SCEV *> &Map;
 
@@ -13711,9 +13712,18 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
       return Expr;
     return I->second;
   }
+
+  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+    auto I = Map.find(Expr);
+    if (I == Map.end())
+      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
+          Expr);
+    return I->second;
+  }
 };
 
 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
+  SmallVector<const SCEV *> ExprsToRewrite;
   auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
                               const SCEV *RHS,
                               DenseMap<const SCEV *, const SCEV *>
@@ -13736,6 +13746,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
         if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
           auto Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
           RewriteMap[LHSUnknown] = Multiple;
+          ExprsToRewrite.push_back(LHSUnknown);
           return;
         }
       }
@@ -13749,7 +13760,8 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
     // Check for a condition of the form (-C1 + X < C2).  InstCombine will
     // create this form when combining two checks of the form (X u< C2 + C1) and
     // (X >=u C1).
-    auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap]() {
+    auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
+                                 &ExprsToRewrite]() {
       auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
       if (!AddExpr || AddExpr->getNumOperands() != 2)
         return false;
@@ -13772,21 +13784,35 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
       RewriteMap[LHSUnknown] = getUMaxExpr(
           getConstant(ExactRegion.getUnsignedMin()),
           getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
+      ExprsToRewrite.push_back(LHSUnknown);
       return true;
     };
     if (MatchRangeCheckIdiom())
       return;
 
-    // For now, limit to conditions that provide information about unknown
-    // expressions. RHS also cannot contain add recurrences.
-    auto *LHSUnknown = dyn_cast<SCEVUnknown>(LHS);
-    if (!LHSUnknown || containsAddRecurrence(RHS))
+    // If RHS is SCEVUnknown, make sure the information is applied to it.
+    if (isa<SCEVUnknown>(RHS)) {
+      std::swap(LHS, RHS);
+      Predicate = CmpInst::getSwappedPredicate(Predicate);
+    }
+    // If LHS is a constant, apply information to the other expression.
+    if (isa<SCEVConstant>(LHS)) {
+      std::swap(LHS, RHS);
+      Predicate = CmpInst::getSwappedPredicate(Predicate);
+    }
+    // Do not apply information for constants or if RHS contains an AddRec.
+    if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
+      return;
+
+    // Limit to expressions that can be rewritten.
+    if (!isa<SCEVUnknown>(LHS) && !isa<SCEVZeroExtendExpr>(LHS))
       return;
 
     // Check whether LHS has already been rewritten. In that case we want to
     // chain further rewrites onto the already rewritten value.
     auto I = RewriteMap.find(LHS);
     const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
+
     const SCEV *RewrittenRHS = nullptr;
     switch (Predicate) {
     case CmpInst::ICMP_ULT:
@@ -13830,8 +13856,11 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
       break;
     }
 
-    if (RewrittenRHS)
+    if (RewrittenRHS) {
       RewriteMap[LHS] = RewrittenRHS;
+      if (LHS == RewrittenLHS)
+        ExprsToRewrite.push_back(LHS);
+    }
   };
   // Starting at the loop predecessor, climb up the predecessor chain, as long
   // as there are predecessors that can be found that have unique successors
@@ -13887,6 +13916,19 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
 
   if (RewriteMap.empty())
     return Expr;
+
+  // Now that all rewrite information is collected, rewrite the collected
+  // expressions with the information in the map. This applies information to
+  // sub-expressions.
+  if (ExprsToRewrite.size() > 1) {
+    for (const SCEV *Expr : ExprsToRewrite) {
+      const SCEV *RewriteTo = RewriteMap[Expr];
+      RewriteMap.erase(Expr);
+      SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
+      RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
+    }
+  }
+
   SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
   return Rewriter.visit(Expr);
 }

diff  --git a/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll b/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
index cceb25d5cff6a..de8ca288a9bd0 100644
--- a/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
+++ b/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
@@ -7,7 +7,7 @@
 define void @rewrite_zext(i32 %n) {
 ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext
 ; CHECK-NEXT:  Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
-; CHECK-NEXT:  Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT:  Loop %loop: max backedge-taken count is 2
 ; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
 ; CHECK-NEXT:  Predicates:
 ; CHECK:        Loop %loop: Trip multiple is 1
@@ -36,7 +36,7 @@ exit:
 define i32 @rewrite_zext_min_max(i32 %N, i32* %arr) {
 ; CHECK-LABEL:  Determining loop execution counts for: @rewrite_zext_min_max
 ; CHECK-NEXT:   Loop %loop: backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))<nuw><nsw>)<nsw> /u 4)
-; CHECK-NEXT:   Loop %loop: max backedge-taken count is 4611686018427387903
+; CHECK-NEXT:   Loop %loop: max backedge-taken count is 3
 ; CHECK-NEXT:   Loop %loop: Predicated backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))<nuw><nsw>)<nsw> /u 4)
 ; CHECK-NEXT:   Predicates:
 ; CHECK:         Loop %loop: Trip multiple is 1
@@ -134,7 +134,7 @@ exit:
 define void @rewrite_zext_and_base_1(i32 %n) {
 ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
 ; CHECK-NEXT:  Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
-; CHECK-NEXT:  Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT:  Loop %loop: max backedge-taken count is 3
 ; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
 ; CHECK-NEXT:  Predicates:
 ; CHECK:        Loop %loop: Trip multiple is 1
@@ -168,7 +168,7 @@ exit:
 define void @rewrite_zext_and_base_2(i32 %n) {
 ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
 ; CHECK-NEXT:  Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
-; CHECK-NEXT:  Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT:  Loop %loop: max backedge-taken count is 3
 ; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
 ; CHECK-NEXT:  Predicates:
 ; CHECK:        Loop %loop: Trip multiple is 1


        


More information about the llvm-commits mailing list